Commit 07b4f5ea authored by Frisinghelli Daniel

Refactored plot_loss() function

parent dfd3057a
@@ -13,6 +13,7 @@ import numpy as np
 import torch
 import matplotlib.pyplot as plt
 import matplotlib.patches as mpatches
+import matplotlib.lines as mlines
 from matplotlib.colors import ListedColormap, BoundaryNorm
 from matplotlib import cm as colormap
@@ -36,6 +37,11 @@ def contrast_stretching(image, alpha=2):
     return norm


+def running_mean(x, w):
+    cumsum = np.cumsum(np.insert(x, 0, 0))
+    return (cumsum[w:] - cumsum[:-w]) / w
+
+
 # plot_sample() plots a false color composite of the scene/tile together
 # with the model prediction and the corresponding ground truth
 def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
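The new running_mean() helper smooths a 1-D series with the cumulative-sum trick: a leading zero is prepended to the cumulative sum, and differencing it at lag w yields every window sum in a single vectorised pass. A minimal standalone sketch of the same computation (the input array and window size below are illustrative, not taken from the commit):

import numpy as np

def running_mean(x, w):
    # cumulative sum with a leading zero, so cumsum[i] holds the sum of x[:i]
    cumsum = np.cumsum(np.insert(x, 0, 0))
    # the difference at lag w is the sum of each window of length w
    return (cumsum[w:] - cumsum[:-w]) / w

# illustrative batch-wise loss values and a window of 3 "batches"
x = np.array([1.0, 0.8, 0.9, 0.6, 0.5, 0.4])
print(running_mean(x, 3))  # -> [0.9, 0.7667, 0.6667, 0.5]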
@@ -164,8 +170,8 @@ def plot_confusion_matrix(cm, labels, normalize=True,
     return fig, ax


-def plot_loss(loss_file, figsize=(10, 10),
-              colors=['lightgreen', 'skyblue', 'darkgreen', 'steelblue'],
+def plot_loss(loss_file, figsize=(10, 10), step=5,
+              colors=['lightgreen', 'green', 'skyblue', 'steelblue'],
               outpath=os.path.join(os.getcwd(), '_graphics/')):

     # load the model loss
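The new step argument controls how often the x axis is labelled: one tick every step epochs, with a default of 5. A hedged usage sketch of the refactored signature; the checkpoint file name and output directory are placeholders, not part of the commit:

# hypothetical loss/accuracy checkpoint written by the training script
loss_file = 'UNet_loss_state.pt'

# label every 10th epoch and write the figure to a custom folder;
# figsize and colors keep their defaults
plot_loss(loss_file, step=10, outpath='./figures/')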
@@ -176,60 +182,58 @@ def plot_loss(loss_file, figsize=(10, 10),
     loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in
             state.items() if k != 'epoch'}

+    # compute running mean with a window equal to the number of batches in
+    # an epoch
+    rm = {k: running_mean(v.flatten('F'), v.shape[0]) for k, v in loss.items()}
+
     # number of epochs trained
-    epochs = np.arange(0, state['epoch'] + 1)
+    epochs = np.arange(0, loss['tl'].shape[1])

     # instanciate figure
     fig, ax1 = plt.subplots(1, 1, figsize=figsize)

-    # plot training and validation mean loss per epoch
-    [ax1.plot(epochs, v.mean(axis=0),
-              label=k.capitalize().replace('_', ' '), color=c, lw=2)
-     for (k, v), c in zip(loss.items(), colors) if v.any() and 'loss' in k]
-
-    # plot training loss per batch
-    ax2 = ax1.twiny()
-    [ax2.plot(v.flatten('F'), color=c, alpha=0.5)
-     for (k, v), c in zip(loss.items(), colors) if 'loss' in k and
-     'validation' not in k]
-
-    # plot training and validation mean accuracy per epoch
-    ax3 = ax1.twinx()
-    [ax3.plot(epochs, v.mean(axis=0),
-              label=k.capitalize().replace('_', ' '), color=c, lw=2)
-     for (k, v), c in zip(loss.items(), colors) if v.any() and 'accuracy'
-     in k]
-
-    # plot training accuracy per batch
-    ax4 = ax3.twiny()
-    [ax4.plot(v.flatten('F'), color=c, alpha=0.5)
-     for (k, v), c in zip(loss.items(), colors) if 'accuracy' in k and
-     'validation' not in k]
+    # create axes for each parameter to plot
+    ax2 = ax1.twinx()
+    ax3 = ax1.twiny()
+    ax4 = ax2.twiny()
+
+    # list of axes
+    axes = [ax1, ax2, ax3, ax4]
+
+    # plot running mean loss and accuracy of the training dataset
+    [ax.plot(v, color=c) for (k, v), ax, c in zip(rm.items(), axes, colors)
+     if v.any()]

     # axes properties and labels
-    for ax in [ax2, ax4]:
+    nbatches = loss['tl'].shape[0]
+    for ax in [ax3, ax4]:
         ax.set(xticks=[], xticklabels=[])
-    ax1.set(xlabel='Epoch',
+    ax1.set(xticks=np.arange(0, nbatches * epochs[-1] + 1, nbatches * step),
+            xticklabels=epochs[::step],
+            xlabel='Epoch',
             ylabel='Loss',
             ylim=(0, 1))
-    ax3.set(ylabel='Accuracy',
-            ylim=(0, 1))
+    ax2.set(ylabel='Accuracy',
+            ylim=(0.5, 1))

     # compute early stopping point
-    if loss['validation_accuracy'].any():
-        esepoch = np.argmax(loss['validation_accuracy'].mean(axis=0))
-        esacc = np.max(loss['validation_accuracy'].mean(axis=0))
+    if loss['va'].any():
+        esepoch = np.argmax(loss['va'].mean(axis=0)) * nbatches
+        esacc = np.max(loss['va'].mean(axis=0))
         ax1.vlines(esepoch, ymin=ax1.get_ylim()[0], ymax=ax1.get_ylim()[1],
                    ls='--', color='grey')
-        ax1.text(esepoch - 1, ax1.get_ylim()[0] + 0.01,
-                 'epoch = {}'.format(esepoch), ha='right', color='grey')
-        ax1.text(esepoch + 1, ax1.get_ylim()[0] + 0.01,
-                 'acc = {:.2f}%'.format(esacc * 100), ha='left',
-                 color='grey')
-
-    # add legends
-    ax1.legend(frameon=False, loc='lower left')
-    ax3.legend(frameon=False, loc='upper left')
+        ax1.text(esepoch - nbatches, ax1.get_ylim()[0] + 0.01,
+                 'epoch = {}, accuracy = {:.1f}%'
+                 .format(int(esepoch / nbatches), esacc * 100),
+                 ha='right', color='grey')
+
+    # create a patch (proxy artist) for every color
+    ulabels = ['Training loss', 'Training accuracy',
+               'Validation loss', 'Validation accuracy']
+    patches = [mlines.Line2D([], [], color=c, label=l) for c, l in
+               zip(colors, ulabels)]
+
+    # plot patches as legend
+    ax1.legend(handles=patches, loc='lower left', frameon=False)

     # save figure
     os.makedirs(outpath, exist_ok=True)
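After the refactoring the curves are drawn against a batch index rather than an epoch index, so the x ticks are mapped back to epochs, and a proxy-artist legend is used because the running-mean lines are plotted without labels. A small self-contained sketch of both ideas; the number of batches, epoch count, color and the synthetic loss curve are illustrative, not from the commit:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.lines as mlines

nbatches, nepochs, step = 100, 40, 5          # illustrative sizes
epochs = np.arange(0, nepochs)

# fake batch-wise training loss (one value per batch, for every epoch)
batch_loss = np.exp(-np.linspace(0, 4, nbatches * nepochs))

fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(batch_loss, color='lightgreen')

# place a tick every `step` epochs on the batch-indexed axis:
# tick positions in batch units, tick labels in epoch units
ax.set(xticks=np.arange(0, nbatches * epochs[-1] + 1, nbatches * step),
       xticklabels=epochs[::step],
       xlabel='Epoch', ylabel='Loss', ylim=(0, 1))

# proxy artist: a Line2D that exists only to carry a legend entry
proxies = [mlines.Line2D([], [], color='lightgreen', label='Training loss')]
ax.legend(handles=proxies, loc='lower left', frameon=False)

With nbatches = 100 and step = 5, the ticks land at batch indices 0, 500, 1000, ... and are labelled 0, 5, 10, ... epochs, so both sequences have the same length and line up one-to-one.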