From 3f14aa4c5b3add474c8566e0715e98fd8f3c5e25 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Mon, 17 Aug 2020 17:20:56 +0200 Subject: [PATCH] Loss file is deprecated; loss is now also stored in state file --- pysegcnn/core/graphics.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py index 5366f7d..3969d7f 100644 --- a/pysegcnn/core/graphics.py +++ b/pysegcnn/core/graphics.py @@ -174,17 +174,17 @@ def plot_confusion_matrix(cm, labels, normalize=True, return fig, ax -def plot_loss(loss_file, figsize=(10, 10), step=5, +def plot_loss(state_file, figsize=(10, 10), step=5, colors=['lightgreen', 'green', 'skyblue', 'steelblue'], outpath=os.path.join(HERE, '_graphics/')): - # load the model loss - state = torch.load(loss_file) + # load the model state + model_state = torch.load(state_file) # get all non-zero elements, i.e. get number of epochs trained before # early stop loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in - state.items() if k != 'epoch'} + model_state['state'].items()} # compute running mean with a window equal to the number of batches in # an epoch @@ -245,7 +245,7 @@ def plot_loss(loss_file, figsize=(10, 10), step=5, # save figure os.makedirs(outpath, exist_ok=True) fig.savefig(os.path.join( - outpath, os.path.basename(loss_file).replace('.pt', '.png')), + outpath, os.path.basename(state_file).replace('.pt', '.png')), dpi=300, bbox_inches='tight') return fig -- GitLab