diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py index 23a8e9b735b3f3dac983d5fa5c2c0a59c2da87df..d164c4985f7464c8a8b776e8b42a0d62c9362c6f 100644 --- a/pysegcnn/core/graphics.py +++ b/pysegcnn/core/graphics.py @@ -173,8 +173,7 @@ def plot_confusion_matrix(cm, labels, normalize=True, # save figure if state is not None: os.makedirs(outpath, exist_ok=True) - fig.savefig(os.path.join(outpath, state.replace('.pt', '_cm.png')), - dpi=300, bbox_inches='tight') + fig.savefig(os.path.join(outpath, state), dpi=300, bbox_inches='tight') return fig, ax @@ -195,6 +194,9 @@ def plot_loss(loss_file, figsize=(10, 10), step=5, # an epoch rm = {k: running_mean(v.flatten('F'), v.shape[0]) for k, v in loss.items()} + # sort the keys of the dictionary alphabetically + rm = {k: rm[k] for k in sorted(rm)} + # number of epochs trained epochs = np.arange(0, loss['tl'].shape[1])