Skip to content
Snippets Groups Projects
Commit ef5cae60 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Fixed plot order in plot_loss; increased flexibility for plot_confusion_matrix

parent aae0cde5
No related branches found
No related tags found
No related merge requests found
...@@ -173,8 +173,7 @@ def plot_confusion_matrix(cm, labels, normalize=True, ...@@ -173,8 +173,7 @@ def plot_confusion_matrix(cm, labels, normalize=True,
# save figure # save figure
if state is not None: if state is not None:
os.makedirs(outpath, exist_ok=True) os.makedirs(outpath, exist_ok=True)
fig.savefig(os.path.join(outpath, state.replace('.pt', '_cm.png')), fig.savefig(os.path.join(outpath, state), dpi=300, bbox_inches='tight')
dpi=300, bbox_inches='tight')
return fig, ax return fig, ax
...@@ -195,6 +194,9 @@ def plot_loss(loss_file, figsize=(10, 10), step=5, ...@@ -195,6 +194,9 @@ def plot_loss(loss_file, figsize=(10, 10), step=5,
# an epoch # an epoch
rm = {k: running_mean(v.flatten('F'), v.shape[0]) for k, v in loss.items()} rm = {k: running_mean(v.flatten('F'), v.shape[0]) for k, v in loss.items()}
# sort the keys of the dictionary alphabetically
rm = {k: rm[k] for k in sorted(rm)}
# number of epochs trained # number of epochs trained
epochs = np.arange(0, loss['tl'].shape[1]) epochs = np.arange(0, loss['tl'].shape[1])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment