From 9e8bf5d74699d9dff60d0f6ef5986f1fd00e4b88 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Thu, 27 Aug 2020 17:22:11 +0200 Subject: [PATCH] Fixed a bug: plots of loss and confusion matrix were overwritten: same filename. --- pysegcnn/core/graphics.py | 12 +++++++----- pysegcnn/main/eval.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py index dfb1034..a4f693b 100644 --- a/pysegcnn/core/graphics.py +++ b/pysegcnn/core/graphics.py @@ -193,7 +193,7 @@ def plot_sample(x, use_bands, labels, y=None, y_pred=None, figsize=(10, 10), def plot_confusion_matrix(cm, labels, normalize=True, - figsize=(10, 10), cmap='Blues', state=None, + figsize=(10, 10), cmap='Blues', state_file=None, outpath=os.path.join(HERE, '_graphics/')): """Plot the confusion matrix ``cm``. @@ -214,7 +214,7 @@ def plot_confusion_matrix(cm, labels, normalize=True, The figure size in centimeters. The default is (10, 10). cmap : `str`, optional A colormap in `matplotlib.pyplot.colormaps()`. The default is 'Blues'. - state : `str` or `None`, optional + state_file : `str` or `None` or `pathlib.Path`, optional Filename to save the plot to. ``state`` should be an existing model state file ending with '.pt'. The default is None, i.e. plot is not saved to disk. @@ -286,9 +286,11 @@ def plot_confusion_matrix(cm, labels, normalize=True, fig.colorbar(im, cax=cax) # save figure - if state is not None: + if state_file is not None: os.makedirs(outpath, exist_ok=True) - fig.savefig(os.path.join(outpath, state), dpi=300, bbox_inches='tight') + fig.savefig(os.path.join( + outpath, os.path.basename(state_file).replace('.pt', '_cm.png')), + dpi=300, bbox_inches='tight') return fig, ax @@ -387,7 +389,7 @@ def plot_loss(state_file, figsize=(10, 10), step=5, # save figure os.makedirs(outpath, exist_ok=True) fig.savefig(os.path.join( - outpath, os.path.basename(state_file).replace('.pt', '.png')), + outpath, os.path.basename(state_file).replace('.pt', '_loss.png')), dpi=300, bbox_inches='tight') return fig diff --git a/pysegcnn/main/eval.py b/pysegcnn/main/eval.py index dee7f91..9197a88 100644 --- a/pysegcnn/main/eval.py +++ b/pysegcnn/main/eval.py @@ -77,5 +77,5 @@ if __name__ == '__main__': # whether to plot the confusion matrix if ec.cm: plot_confusion_matrix(cm, ds.dataset.labels, - state=ec.state_file.name.replace('.pt', '.png'), + state_file=ec.state_file, outpath=ec.perfmc_path) -- GitLab