Skip to content
Snippets Groups Projects
Commit 9e8bf5d7 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Fixed a bug: plots of loss and confusion matrix were overwritten: same filename.

parent 63b1485b
No related branches found
No related tags found
No related merge requests found
...@@ -193,7 +193,7 @@ def plot_sample(x, use_bands, labels, y=None, y_pred=None, figsize=(10, 10), ...@@ -193,7 +193,7 @@ def plot_sample(x, use_bands, labels, y=None, y_pred=None, figsize=(10, 10),
def plot_confusion_matrix(cm, labels, normalize=True, def plot_confusion_matrix(cm, labels, normalize=True,
figsize=(10, 10), cmap='Blues', state=None, figsize=(10, 10), cmap='Blues', state_file=None,
outpath=os.path.join(HERE, '_graphics/')): outpath=os.path.join(HERE, '_graphics/')):
"""Plot the confusion matrix ``cm``. """Plot the confusion matrix ``cm``.
...@@ -214,7 +214,7 @@ def plot_confusion_matrix(cm, labels, normalize=True, ...@@ -214,7 +214,7 @@ def plot_confusion_matrix(cm, labels, normalize=True,
The figure size in centimeters. The default is (10, 10). The figure size in centimeters. The default is (10, 10).
cmap : `str`, optional cmap : `str`, optional
A colormap in `matplotlib.pyplot.colormaps()`. The default is 'Blues'. A colormap in `matplotlib.pyplot.colormaps()`. The default is 'Blues'.
state : `str` or `None`, optional state_file : `str` or `None` or `pathlib.Path`, optional
Filename to save the plot to. ``state`` should be an existing model Filename to save the plot to. ``state`` should be an existing model
state file ending with '.pt'. The default is None, i.e. plot is not state file ending with '.pt'. The default is None, i.e. plot is not
saved to disk. saved to disk.
...@@ -286,9 +286,11 @@ def plot_confusion_matrix(cm, labels, normalize=True, ...@@ -286,9 +286,11 @@ def plot_confusion_matrix(cm, labels, normalize=True,
fig.colorbar(im, cax=cax) fig.colorbar(im, cax=cax)
# save figure # save figure
if state is not None: if state_file is not None:
os.makedirs(outpath, exist_ok=True) os.makedirs(outpath, exist_ok=True)
fig.savefig(os.path.join(outpath, state), dpi=300, bbox_inches='tight') fig.savefig(os.path.join(
outpath, os.path.basename(state_file).replace('.pt', '_cm.png')),
dpi=300, bbox_inches='tight')
return fig, ax return fig, ax
...@@ -387,7 +389,7 @@ def plot_loss(state_file, figsize=(10, 10), step=5, ...@@ -387,7 +389,7 @@ def plot_loss(state_file, figsize=(10, 10), step=5,
# save figure # save figure
os.makedirs(outpath, exist_ok=True) os.makedirs(outpath, exist_ok=True)
fig.savefig(os.path.join( fig.savefig(os.path.join(
outpath, os.path.basename(state_file).replace('.pt', '.png')), outpath, os.path.basename(state_file).replace('.pt', '_loss.png')),
dpi=300, bbox_inches='tight') dpi=300, bbox_inches='tight')
return fig return fig
...@@ -77,5 +77,5 @@ if __name__ == '__main__': ...@@ -77,5 +77,5 @@ if __name__ == '__main__':
# whether to plot the confusion matrix # whether to plot the confusion matrix
if ec.cm: if ec.cm:
plot_confusion_matrix(cm, ds.dataset.labels, plot_confusion_matrix(cm, ds.dataset.labels,
state=ec.state_file.name.replace('.pt', '.png'), state_file=ec.state_file,
outpath=ec.perfmc_path) outpath=ec.perfmc_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment