From 3f14aa4c5b3add474c8566e0715e98fd8f3c5e25 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Mon, 17 Aug 2020 17:20:56 +0200
Subject: [PATCH] Loss file is deprecated; loss is now also stored in state
 file

---
 pysegcnn/core/graphics.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py
index 5366f7d..3969d7f 100644
--- a/pysegcnn/core/graphics.py
+++ b/pysegcnn/core/graphics.py
@@ -174,17 +174,17 @@ def plot_confusion_matrix(cm, labels, normalize=True,
     return fig, ax
 
 
-def plot_loss(loss_file, figsize=(10, 10), step=5,
+def plot_loss(state_file, figsize=(10, 10), step=5,
               colors=['lightgreen', 'green', 'skyblue', 'steelblue'],
               outpath=os.path.join(HERE, '_graphics/')):
 
-    # load the model loss
-    state = torch.load(loss_file)
+    # load the model state
+    model_state = torch.load(state_file)
 
     # get all non-zero elements, i.e. get number of epochs trained before
     # early stop
     loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in
-            state.items() if k != 'epoch'}
+            model_state['state'].items()}
 
     # compute running mean with a window equal to the number of batches in
     # an epoch
@@ -245,7 +245,7 @@ def plot_loss(loss_file, figsize=(10, 10), step=5,
     # save figure
     os.makedirs(outpath, exist_ok=True)
     fig.savefig(os.path.join(
-        outpath, os.path.basename(loss_file).replace('.pt', '.png')),
+        outpath, os.path.basename(state_file).replace('.pt', '.png')),
                 dpi=300, bbox_inches='tight')
 
     return fig
-- 
GitLab