Skip to content
Snippets Groups Projects
Commit 6c3c00d5 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Improved plot_loss() function. Now includes early stopping epoch.

parent f6315012
No related branches found
No related tags found
No related merge requests found
......@@ -447,7 +447,7 @@ class ImageDataset(Dataset):
return fig, ax
def plot_loss(self, state_file, figsize=(10, 10),
colors=['lightgreen', 'darkgreen', 'skyblue', 'steelblue'],
colors=['lightgreen', 'skyblue', 'darkgreen', 'steelblue'],
outpath=os.path.join(os.getcwd(), '_graphics/')):
# load the model loss
......@@ -459,7 +459,7 @@ class ImageDataset(Dataset):
state.items() if k != 'epoch'}
# number of epochs trained
epochs = np.arange(0, state['epoch'])
epochs = np.arange(0, state['epoch'] + 1)
# instanciate figure
fig, ax1 = plt.subplots(1, 1, figsize=figsize)
......@@ -469,10 +469,11 @@ class ImageDataset(Dataset):
label=k.capitalize().replace('_', ' '), color=c, lw=2)
for (k, v), c in zip(loss.items(), colors) if 'loss' in k]
# plot training and validation loss per batch
# plot training loss per batch
ax2 = ax1.twiny()
[ax2.plot(v.flatten('F'), color=c, alpha=0.5)
for (k, v), c in zip(loss.items(), colors) if 'loss' in k]
for (k, v), c in zip(loss.items(), colors) if 'loss' in k and
'validation' not in k]
# plot training and validation mean accuracy per epoch
ax3 = ax1.twinx()
......@@ -480,27 +481,39 @@ class ImageDataset(Dataset):
label=k.capitalize().replace('_', ' '), color=c, lw=2)
for (k, v), c in zip(loss.items(), colors) if 'accuracy' in k]
# plot training and validation accuracy per batch
# plot training accuracy per batch
ax4 = ax3.twiny()
[ax4.plot(v.flatten('F'), color=c, alpha=0.5)
for (k, v), c in zip(loss.items(), colors) if 'accuracy' in k]
# plot early stopping point
# ax2.vlines(loss['validation_accuracy'].mean(axis=0).max())
ax1.legend(frameon=False, loc='lower right')
ax3.legend(frameon=False, loc='upper left')
for (k, v), c in zip(loss.items(), colors) if 'accuracy' in k and
'validation' not in k]
# axes properties and labels
for ax in [ax2, ax4]:
ax.set(xticks=[], xticklabels=[])
ax1.set(xlabel='Epoch',
ylabel='Loss')
ylabel='Loss',
ylim=(0, 1))
ax3.set(ylabel='Accuracy',
ylim=(0, 1))
# compute early stopping point
esepoch = np.argmax(loss['validation_accuracy'].mean(axis=0))
esacc = np.max(loss['validation_accuracy'].mean(axis=0))
ax1.vlines(esepoch, ymin=ax1.get_ylim()[0], ymax=ax1.get_ylim()[1],
ls='--', color='grey')
ax1.text(esepoch - 1, ax1.get_ylim()[0] + 0.01,
'epoch = {}'.format(esepoch), ha='right', color='grey')
ax1.text(esepoch + 1, ax1.get_ylim()[0] + 0.01,
'acc = {:.2f}%'.format(esacc * 100), ha='left', color='grey')
# add legends
ax1.legend(frameon=False, loc='lower left')
ax3.legend(frameon=False, loc='upper left')
# save figure
os.makedirs(outpath, exist_ok=True)
fig.savefig(os.path.join(outpath, state_file.replace('.pt', '.png')),
fig.savefig(os.path.join(
outpath, os.path.basename(state_file).replace('.pt', '.png')),
dpi=300, bbox_inches='tight')
return fig, ax
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment