From 07b4f5ea993ee19816125dcfcbefe7cbd9ff46be Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Thu, 16 Jul 2020 13:22:13 +0000 Subject: [PATCH] Refactored plot_loss() function --- pytorch/graphics.py | 88 +++++++++++++++++++++++---------------------- 1 file changed, 46 insertions(+), 42 deletions(-) diff --git a/pytorch/graphics.py b/pytorch/graphics.py index d142209..166f3bf 100644 --- a/pytorch/graphics.py +++ b/pytorch/graphics.py @@ -13,6 +13,7 @@ import numpy as np import torch import matplotlib.pyplot as plt import matplotlib.patches as mpatches +import matplotlib.lines as mlines from matplotlib.colors import ListedColormap, BoundaryNorm from matplotlib import cm as colormap @@ -36,6 +37,11 @@ def contrast_stretching(image, alpha=2): return norm +def running_mean(x, w): + cumsum = np.cumsum(np.insert(x, 0, 0)) + return (cumsum[w:] - cumsum[:-w]) / w + + # plot_sample() plots a false color composite of the scene/tile together # with the model prediction and the corresponding ground truth def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10), @@ -164,8 +170,8 @@ def plot_confusion_matrix(cm, labels, normalize=True, return fig, ax -def plot_loss(loss_file, figsize=(10, 10), - colors=['lightgreen', 'skyblue', 'darkgreen', 'steelblue'], +def plot_loss(loss_file, figsize=(10, 10), step=5, + colors=['lightgreen', 'green', 'skyblue', 'steelblue'], outpath=os.path.join(os.getcwd(), '_graphics/')): # load the model loss @@ -176,60 +182,58 @@ def plot_loss(loss_file, figsize=(10, 10), loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in state.items() if k != 'epoch'} + # compute running mean with a window equal to the number of batches in + # an epoch + rm = {k: running_mean(v.flatten('F'), v.shape[0]) for k, v in loss.items()} + # number of epochs trained - epochs = np.arange(0, state['epoch'] + 1) + epochs = np.arange(0, loss['tl'].shape[1]) # instanciate figure fig, ax1 = plt.subplots(1, 1, figsize=figsize) - # plot training and validation mean loss per epoch - [ax1.plot(epochs, v.mean(axis=0), - label=k.capitalize().replace('_', ' '), color=c, lw=2) - for (k, v), c in zip(loss.items(), colors) if v.any() and 'loss' in k] - - # plot training loss per batch - ax2 = ax1.twiny() - [ax2.plot(v.flatten('F'), color=c, alpha=0.5) - for (k, v), c in zip(loss.items(), colors) if 'loss' in k and - 'validation' not in k] - - # plot training and validation mean accuracy per epoch - ax3 = ax1.twinx() - [ax3.plot(epochs, v.mean(axis=0), - label=k.capitalize().replace('_', ' '), color=c, lw=2) - for (k, v), c in zip(loss.items(), colors) if v.any() and 'accuracy' - in k] - - # plot training accuracy per batch - ax4 = ax3.twiny() - [ax4.plot(v.flatten('F'), color=c, alpha=0.5) - for (k, v), c in zip(loss.items(), colors) if 'accuracy' in k and - 'validation' not in k] + # create axes for each parameter to plot + ax2 = ax1.twinx() + ax3 = ax1.twiny() + ax4 = ax2.twiny() + + # list of axes + axes = [ax1, ax2, ax3, ax4] + + # plot running mean loss and accuracy of the training dataset + [ax.plot(v, color=c) for (k, v), ax, c in zip(rm.items(), axes, colors) + if v.any()] # axes properties and labels - for ax in [ax2, ax4]: + nbatches = loss['tl'].shape[0] + for ax in [ax3, ax4]: ax.set(xticks=[], xticklabels=[]) - ax1.set(xlabel='Epoch', + ax1.set(xticks=np.arange(0, nbatches * epochs[-1] + 1, nbatches * step), + xticklabels=epochs[::step], + xlabel='Epoch', ylabel='Loss', ylim=(0, 1)) - ax3.set(ylabel='Accuracy', - ylim=(0, 1)) + ax2.set(ylabel='Accuracy', + ylim=(0.5, 1)) # compute early stopping point - if loss['validation_accuracy'].any(): - esepoch = np.argmax(loss['validation_accuracy'].mean(axis=0)) - esacc = np.max(loss['validation_accuracy'].mean(axis=0)) + if loss['va'].any(): + esepoch = np.argmax(loss['va'].mean(axis=0)) * nbatches + esacc = np.max(loss['va'].mean(axis=0)) ax1.vlines(esepoch, ymin=ax1.get_ylim()[0], ymax=ax1.get_ylim()[1], ls='--', color='grey') - ax1.text(esepoch - 1, ax1.get_ylim()[0] + 0.01, - 'epoch = {}'.format(esepoch), ha='right', color='grey') - ax1.text(esepoch + 1, ax1.get_ylim()[0] + 0.01, - 'acc = {:.2f}%'.format(esacc * 100), ha='left', - color='grey') - - # add legends - ax1.legend(frameon=False, loc='lower left') - ax3.legend(frameon=False, loc='upper left') + ax1.text(esepoch - nbatches, ax1.get_ylim()[0] + 0.01, + 'epoch = {}, accuracy = {:.1f}%' + .format(int(esepoch / nbatches), esacc * 100), + ha='right', color='grey') + + # create a patch (proxy artist) for every color + ulabels = ['Training loss', 'Training accuracy', + 'Validation loss', 'Validation accuracy'] + patches = [mlines.Line2D([], [], color=c, label=l) for c, l in + zip(colors, ulabels)] + # plot patches as legend + ax1.legend(handles=patches, loc='lower left', frameon=False) # save figure os.makedirs(outpath, exist_ok=True) -- GitLab