Skip to content
Snippets Groups Projects
Commit 405feffb authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Changed default path to save model output

parent a12d474a
No related branches found
No related tags found
No related merge requests found
...@@ -19,6 +19,7 @@ from matplotlib import cm as colormap ...@@ -19,6 +19,7 @@ from matplotlib import cm as colormap
# locals # locals
from pysegcnn.core.trainer import accuracy_function from pysegcnn.core.trainer import accuracy_function
from pysegcnn.core.config import HERE
# this function applies percentile stretching at the alpha level # this function applies percentile stretching at the alpha level
...@@ -49,7 +50,7 @@ def running_mean(x, w): ...@@ -49,7 +50,7 @@ def running_mean(x, w):
# with the model prediction and the corresponding ground truth # with the model prediction and the corresponding ground truth
def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10), def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
bands=['nir', 'red', 'green'], stretch=False, state=None, bands=['nir', 'red', 'green'], stretch=False, state=None,
outpath=os.path.join(os.getcwd(), '_samples/'), **kwargs): outpath=os.path.join(HERE, '_samples/'), **kwargs):
# check whether to apply constrast stretching # check whether to apply constrast stretching
stretch = True if kwargs else stretch stretch = True if kwargs else stretch
...@@ -111,7 +112,7 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10), ...@@ -111,7 +112,7 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
# set returned by the pytorch.predict function # set returned by the pytorch.predict function
def plot_confusion_matrix(cm, labels, normalize=True, def plot_confusion_matrix(cm, labels, normalize=True,
figsize=(10, 10), cmap='Blues', state=None, figsize=(10, 10), cmap='Blues', state=None,
outpath=os.path.join(os.getcwd(), '_graphics/')): outpath=os.path.join(HERE, '_graphics/')):
# number of classes # number of classes
labels = [label['label'] for label in labels.values()] labels = [label['label'] for label in labels.values()]
...@@ -180,7 +181,7 @@ def plot_confusion_matrix(cm, labels, normalize=True, ...@@ -180,7 +181,7 @@ def plot_confusion_matrix(cm, labels, normalize=True,
def plot_loss(loss_file, figsize=(10, 10), step=5, def plot_loss(loss_file, figsize=(10, 10), step=5,
colors=['lightgreen', 'green', 'skyblue', 'steelblue'], colors=['lightgreen', 'green', 'skyblue', 'steelblue'],
outpath=os.path.join(os.getcwd(), '_graphics/')): outpath=os.path.join(HERE, '_graphics/')):
# load the model loss # load the model loss
state = torch.load(loss_file) state = torch.load(loss_file)
......
...@@ -16,6 +16,7 @@ import torch.nn as nn ...@@ -16,6 +16,7 @@ import torch.nn as nn
# locals # locals
from pysegcnn.core.layers import (Encoder, Decoder, Conv2dPool, Conv2dUnpool, from pysegcnn.core.layers import (Encoder, Decoder, Conv2dPool, Conv2dUnpool,
Conv2dUpsample, Conv2dSame) Conv2dUpsample, Conv2dSame)
from pysegcnn.main.config import HERE
class Network(nn.Module): class Network(nn.Module):
...@@ -31,8 +32,8 @@ class Network(nn.Module): ...@@ -31,8 +32,8 @@ class Network(nn.Module):
for param in self.parameters(): for param in self.parameters():
param.requires_grad = True param.requires_grad = True
def save(self, state_file, optimizer, bands, def save(self, state_file, optimizer, bands=None,
outpath=os.path.join(os.getcwd(), '_models')): outpath=os.path.join(HERE, '_models/')):
# check if the output path exists and if not, create it # check if the output path exists and if not, create it
if not os.path.isdir(outpath): if not os.path.isdir(outpath):
...@@ -70,7 +71,7 @@ class Network(nn.Module): ...@@ -70,7 +71,7 @@ class Network(nn.Module):
return state return state
def load(self, state_file, optimizer=None, def load(self, state_file, optimizer=None,
inpath=os.path.join(os.getcwd(), '_models')): inpath=os.path.join(HERE, '_models/')):
# load the model state file # load the model state file
state = os.path.join(inpath, state_file) state = os.path.join(inpath, state_file)
......
...@@ -38,7 +38,7 @@ if __name__ == '__main__': ...@@ -38,7 +38,7 @@ if __name__ == '__main__':
trainer.cm, trainer.cm,
trainer.plot_scenes, trainer.plot_scenes,
bands=trainer.plot_bands, bands=trainer.plot_bands,
outpath=os.path.join(HERE, '_samples/'), outpath=os.path.join(HERE, '_scenes/'),
stretch=True, stretch=True,
alpha=5) alpha=5)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment