From 405feffbc112e7c4ec2e5cc62a0908a754fa99cb Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 30 Jul 2020 10:06:16 +0200
Subject: [PATCH] Changed default path to save model output

---
 pysegcnn/core/graphics.py | 7 ++++---
 pysegcnn/core/models.py   | 7 ++++---
 pysegcnn/main/eval.py     | 2 +-
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py
index c447cdd..f2fb7b1 100644
--- a/pysegcnn/core/graphics.py
+++ b/pysegcnn/core/graphics.py
@@ -19,6 +19,7 @@ from matplotlib import cm as colormap
 
 # locals
 from pysegcnn.core.trainer import accuracy_function
+from pysegcnn.core.config import HERE
 
 
 # this function applies percentile stretching at the alpha level
@@ -49,7 +50,7 @@ def running_mean(x, w):
 # with the model prediction and the corresponding ground truth
 def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
                 bands=['nir', 'red', 'green'], stretch=False, state=None,
-                outpath=os.path.join(os.getcwd(), '_samples/'),  **kwargs):
+                outpath=os.path.join(HERE, '_samples/'),  **kwargs):
 
     # check whether to apply constrast stretching
     stretch = True if kwargs else stretch
@@ -111,7 +112,7 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
 # set returned by the pytorch.predict function
 def plot_confusion_matrix(cm, labels, normalize=True,
                           figsize=(10, 10), cmap='Blues', state=None,
-                          outpath=os.path.join(os.getcwd(), '_graphics/')):
+                          outpath=os.path.join(HERE, '_graphics/')):
 
     # number of classes
     labels = [label['label'] for label in labels.values()]
@@ -180,7 +181,7 @@ def plot_confusion_matrix(cm, labels, normalize=True,
 
 def plot_loss(loss_file, figsize=(10, 10), step=5,
               colors=['lightgreen', 'green', 'skyblue', 'steelblue'],
-              outpath=os.path.join(os.getcwd(), '_graphics/')):
+              outpath=os.path.join(HERE, '_graphics/')):
 
     # load the model loss
     state = torch.load(loss_file)
diff --git a/pysegcnn/core/models.py b/pysegcnn/core/models.py
index eca1ca4..617252e 100644
--- a/pysegcnn/core/models.py
+++ b/pysegcnn/core/models.py
@@ -16,6 +16,7 @@ import torch.nn as nn
 # locals
 from pysegcnn.core.layers import (Encoder, Decoder, Conv2dPool, Conv2dUnpool,
                                   Conv2dUpsample, Conv2dSame)
+from pysegcnn.main.config import HERE
 
 
 class Network(nn.Module):
@@ -31,8 +32,8 @@ class Network(nn.Module):
         for param in self.parameters():
             param.requires_grad = True
 
-    def save(self, state_file, optimizer, bands,
-             outpath=os.path.join(os.getcwd(), '_models')):
+    def save(self, state_file, optimizer, bands=None,
+             outpath=os.path.join(HERE, '_models/')):
 
         # check if the output path exists and if not, create it
         if not os.path.isdir(outpath):
@@ -70,7 +71,7 @@ class Network(nn.Module):
         return state
 
     def load(self, state_file, optimizer=None,
-             inpath=os.path.join(os.getcwd(), '_models')):
+             inpath=os.path.join(HERE, '_models/')):
 
         # load the model state file
         state = os.path.join(inpath, state_file)
diff --git a/pysegcnn/main/eval.py b/pysegcnn/main/eval.py
index 7375258..f082c35 100644
--- a/pysegcnn/main/eval.py
+++ b/pysegcnn/main/eval.py
@@ -38,7 +38,7 @@ if __name__ == '__main__':
                                     trainer.cm,
                                     trainer.plot_scenes,
                                     bands=trainer.plot_bands,
-                                    outpath=os.path.join(HERE, '_samples/'),
+                                    outpath=os.path.join(HERE, '_scenes/'),
                                     stretch=True,
                                     alpha=5)
 
-- 
GitLab