diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py index 5366f7dcaef6d87d96ad8ef44bd2f0cbdd81ce58..3969d7ff7917ff3428f9763a97d4a715d36acd6b 100644 --- a/pysegcnn/core/graphics.py +++ b/pysegcnn/core/graphics.py @@ -174,17 +174,17 @@ def plot_confusion_matrix(cm, labels, normalize=True, return fig, ax -def plot_loss(loss_file, figsize=(10, 10), step=5, +def plot_loss(state_file, figsize=(10, 10), step=5, colors=['lightgreen', 'green', 'skyblue', 'steelblue'], outpath=os.path.join(HERE, '_graphics/')): - # load the model loss - state = torch.load(loss_file) + # load the model state + model_state = torch.load(state_file) # get all non-zero elements, i.e. get number of epochs trained before # early stop loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in - state.items() if k != 'epoch'} + model_state['state'].items()} # compute running mean with a window equal to the number of batches in # an epoch @@ -245,7 +245,7 @@ def plot_loss(loss_file, figsize=(10, 10), step=5, # save figure os.makedirs(outpath, exist_ok=True) fig.savefig(os.path.join( - outpath, os.path.basename(loss_file).replace('.pt', '.png')), + outpath, os.path.basename(state_file).replace('.pt', '.png')), dpi=300, bbox_inches='tight') return fig