Skip to content
Snippets Groups Projects
Commit 07b4f5ea authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Refactored plot_loss() function

parent dfd3057a
No related branches found
No related tags found
No related merge requests found
...@@ -13,6 +13,7 @@ import numpy as np ...@@ -13,6 +13,7 @@ import numpy as np
import torch import torch
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.patches as mpatches import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.colors import ListedColormap, BoundaryNorm from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib import cm as colormap from matplotlib import cm as colormap
...@@ -36,6 +37,11 @@ def contrast_stretching(image, alpha=2): ...@@ -36,6 +37,11 @@ def contrast_stretching(image, alpha=2):
return norm return norm
def running_mean(x, w):
cumsum = np.cumsum(np.insert(x, 0, 0))
return (cumsum[w:] - cumsum[:-w]) / w
# plot_sample() plots a false color composite of the scene/tile together # plot_sample() plots a false color composite of the scene/tile together
# with the model prediction and the corresponding ground truth # with the model prediction and the corresponding ground truth
def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10), def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
...@@ -164,8 +170,8 @@ def plot_confusion_matrix(cm, labels, normalize=True, ...@@ -164,8 +170,8 @@ def plot_confusion_matrix(cm, labels, normalize=True,
return fig, ax return fig, ax
def plot_loss(loss_file, figsize=(10, 10), def plot_loss(loss_file, figsize=(10, 10), step=5,
colors=['lightgreen', 'skyblue', 'darkgreen', 'steelblue'], colors=['lightgreen', 'green', 'skyblue', 'steelblue'],
outpath=os.path.join(os.getcwd(), '_graphics/')): outpath=os.path.join(os.getcwd(), '_graphics/')):
# load the model loss # load the model loss
...@@ -176,60 +182,58 @@ def plot_loss(loss_file, figsize=(10, 10), ...@@ -176,60 +182,58 @@ def plot_loss(loss_file, figsize=(10, 10),
loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in
state.items() if k != 'epoch'} state.items() if k != 'epoch'}
# compute running mean with a window equal to the number of batches in
# an epoch
rm = {k: running_mean(v.flatten('F'), v.shape[0]) for k, v in loss.items()}
# number of epochs trained # number of epochs trained
epochs = np.arange(0, state['epoch'] + 1) epochs = np.arange(0, loss['tl'].shape[1])
# instanciate figure # instanciate figure
fig, ax1 = plt.subplots(1, 1, figsize=figsize) fig, ax1 = plt.subplots(1, 1, figsize=figsize)
# plot training and validation mean loss per epoch # create axes for each parameter to plot
[ax1.plot(epochs, v.mean(axis=0), ax2 = ax1.twinx()
label=k.capitalize().replace('_', ' '), color=c, lw=2) ax3 = ax1.twiny()
for (k, v), c in zip(loss.items(), colors) if v.any() and 'loss' in k] ax4 = ax2.twiny()
# plot training loss per batch # list of axes
ax2 = ax1.twiny() axes = [ax1, ax2, ax3, ax4]
[ax2.plot(v.flatten('F'), color=c, alpha=0.5)
for (k, v), c in zip(loss.items(), colors) if 'loss' in k and # plot running mean loss and accuracy of the training dataset
'validation' not in k] [ax.plot(v, color=c) for (k, v), ax, c in zip(rm.items(), axes, colors)
if v.any()]
# plot training and validation mean accuracy per epoch
ax3 = ax1.twinx()
[ax3.plot(epochs, v.mean(axis=0),
label=k.capitalize().replace('_', ' '), color=c, lw=2)
for (k, v), c in zip(loss.items(), colors) if v.any() and 'accuracy'
in k]
# plot training accuracy per batch
ax4 = ax3.twiny()
[ax4.plot(v.flatten('F'), color=c, alpha=0.5)
for (k, v), c in zip(loss.items(), colors) if 'accuracy' in k and
'validation' not in k]
# axes properties and labels # axes properties and labels
for ax in [ax2, ax4]: nbatches = loss['tl'].shape[0]
for ax in [ax3, ax4]:
ax.set(xticks=[], xticklabels=[]) ax.set(xticks=[], xticklabels=[])
ax1.set(xlabel='Epoch', ax1.set(xticks=np.arange(0, nbatches * epochs[-1] + 1, nbatches * step),
xticklabels=epochs[::step],
xlabel='Epoch',
ylabel='Loss', ylabel='Loss',
ylim=(0, 1)) ylim=(0, 1))
ax3.set(ylabel='Accuracy', ax2.set(ylabel='Accuracy',
ylim=(0, 1)) ylim=(0.5, 1))
# compute early stopping point # compute early stopping point
if loss['validation_accuracy'].any(): if loss['va'].any():
esepoch = np.argmax(loss['validation_accuracy'].mean(axis=0)) esepoch = np.argmax(loss['va'].mean(axis=0)) * nbatches
esacc = np.max(loss['validation_accuracy'].mean(axis=0)) esacc = np.max(loss['va'].mean(axis=0))
ax1.vlines(esepoch, ymin=ax1.get_ylim()[0], ymax=ax1.get_ylim()[1], ax1.vlines(esepoch, ymin=ax1.get_ylim()[0], ymax=ax1.get_ylim()[1],
ls='--', color='grey') ls='--', color='grey')
ax1.text(esepoch - 1, ax1.get_ylim()[0] + 0.01, ax1.text(esepoch - nbatches, ax1.get_ylim()[0] + 0.01,
'epoch = {}'.format(esepoch), ha='right', color='grey') 'epoch = {}, accuracy = {:.1f}%'
ax1.text(esepoch + 1, ax1.get_ylim()[0] + 0.01, .format(int(esepoch / nbatches), esacc * 100),
'acc = {:.2f}%'.format(esacc * 100), ha='left', ha='right', color='grey')
color='grey')
# create a patch (proxy artist) for every color
# add legends ulabels = ['Training loss', 'Training accuracy',
ax1.legend(frameon=False, loc='lower left') 'Validation loss', 'Validation accuracy']
ax3.legend(frameon=False, loc='upper left') patches = [mlines.Line2D([], [], color=c, label=l) for c, l in
zip(colors, ulabels)]
# plot patches as legend
ax1.legend(handles=patches, loc='lower left', frameon=False)
# save figure # save figure
os.makedirs(outpath, exist_ok=True) os.makedirs(outpath, exist_ok=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment