From 07b4f5ea993ee19816125dcfcbefe7cbd9ff46be Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 16 Jul 2020 13:22:13 +0000
Subject: [PATCH] Refactored plot_loss() function

---
 pytorch/graphics.py | 88 +++++++++++++++++++++++----------------------
 1 file changed, 46 insertions(+), 42 deletions(-)

diff --git a/pytorch/graphics.py b/pytorch/graphics.py
index d142209..166f3bf 100644
--- a/pytorch/graphics.py
+++ b/pytorch/graphics.py
@@ -13,6 +13,7 @@ import numpy as np
 import torch
 import matplotlib.pyplot as plt
 import matplotlib.patches as mpatches
+import matplotlib.lines as mlines
 from matplotlib.colors import ListedColormap, BoundaryNorm
 from matplotlib import cm as colormap
 
@@ -36,6 +37,11 @@ def contrast_stretching(image, alpha=2):
     return norm
 
 
+def running_mean(x, w):
+    cumsum = np.cumsum(np.insert(x, 0, 0))
+    return (cumsum[w:] - cumsum[:-w]) / w
+
+
 # plot_sample() plots a false color composite of the scene/tile together
 # with the model prediction and the corresponding ground truth
 def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
@@ -164,8 +170,8 @@ def plot_confusion_matrix(cm, labels, normalize=True,
     return fig, ax
 
 
-def plot_loss(loss_file, figsize=(10, 10),
-              colors=['lightgreen', 'skyblue', 'darkgreen', 'steelblue'],
+def plot_loss(loss_file, figsize=(10, 10), step=5,
+              colors=['lightgreen', 'green', 'skyblue', 'steelblue'],
               outpath=os.path.join(os.getcwd(), '_graphics/')):
 
     # load the model loss
@@ -176,60 +182,58 @@ def plot_loss(loss_file, figsize=(10, 10),
     loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in
             state.items() if k != 'epoch'}
 
+    # compute running mean with a window equal to the number of batches in
+    # an epoch
+    rm = {k: running_mean(v.flatten('F'), v.shape[0]) for k, v in loss.items()}
+
     # number of epochs trained
-    epochs = np.arange(0, state['epoch'] + 1)
+    epochs = np.arange(0, loss['tl'].shape[1])
 
     # instanciate figure
     fig, ax1 = plt.subplots(1, 1, figsize=figsize)
 
-    # plot training and validation mean loss per epoch
-    [ax1.plot(epochs, v.mean(axis=0),
-              label=k.capitalize().replace('_', ' '), color=c, lw=2)
-     for (k, v), c in zip(loss.items(), colors) if v.any() and 'loss' in k]
-
-    # plot training loss per batch
-    ax2 = ax1.twiny()
-    [ax2.plot(v.flatten('F'), color=c, alpha=0.5)
-     for (k, v), c in zip(loss.items(), colors) if 'loss' in k and
-     'validation' not in k]
-
-    # plot training and validation mean accuracy per epoch
-    ax3 = ax1.twinx()
-    [ax3.plot(epochs, v.mean(axis=0),
-              label=k.capitalize().replace('_', ' '), color=c, lw=2)
-     for (k, v), c in zip(loss.items(), colors) if v.any() and 'accuracy'
-     in k]
-
-    # plot training accuracy per batch
-    ax4 = ax3.twiny()
-    [ax4.plot(v.flatten('F'), color=c, alpha=0.5)
-     for (k, v), c in zip(loss.items(), colors) if 'accuracy' in k and
-     'validation' not in k]
+    # create axes for each parameter to plot
+    ax2 = ax1.twinx()
+    ax3 = ax1.twiny()
+    ax4 = ax2.twiny()
+
+    # list of axes
+    axes = [ax1, ax2, ax3, ax4]
+
+    # plot running mean loss and accuracy of the training and validation datasets
+    [ax.plot(v, color=c) for (k, v), ax, c in zip(rm.items(), axes, colors)
+     if v.any()]
 
     # axes properties and labels
-    for ax in [ax2, ax4]:
+    nbatches = loss['tl'].shape[0]
+    for ax in [ax3, ax4]:
         ax.set(xticks=[], xticklabels=[])
-    ax1.set(xlabel='Epoch',
+    ax1.set(xticks=np.arange(0, nbatches * epochs[-1] + 1, nbatches * step),
+            xticklabels=epochs[::step],
+            xlabel='Epoch',
             ylabel='Loss',
             ylim=(0, 1))
-    ax3.set(ylabel='Accuracy',
-            ylim=(0, 1))
+    ax2.set(ylabel='Accuracy',
+            ylim=(0.5, 1))
 
     # compute early stopping point
-    if loss['validation_accuracy'].any():
-        esepoch = np.argmax(loss['validation_accuracy'].mean(axis=0))
-        esacc = np.max(loss['validation_accuracy'].mean(axis=0))
+    if loss['va'].any():
+        esepoch = np.argmax(loss['va'].mean(axis=0)) * nbatches
+        esacc = np.max(loss['va'].mean(axis=0))
         ax1.vlines(esepoch, ymin=ax1.get_ylim()[0], ymax=ax1.get_ylim()[1],
                    ls='--', color='grey')
-        ax1.text(esepoch - 1, ax1.get_ylim()[0] + 0.01,
-                 'epoch = {}'.format(esepoch), ha='right', color='grey')
-        ax1.text(esepoch + 1, ax1.get_ylim()[0] + 0.01,
-                 'acc = {:.2f}%'.format(esacc * 100), ha='left',
-                 color='grey')
-
-    # add legends
-    ax1.legend(frameon=False, loc='lower left')
-    ax3.legend(frameon=False, loc='upper left')
+        ax1.text(esepoch - nbatches, ax1.get_ylim()[0] + 0.01,
+                 'epoch = {}, accuracy = {:.1f}%'
+                 .format(int(esepoch / nbatches), esacc * 100),
+                 ha='right', color='grey')
+
+    # create a Line2D proxy artist (legend handle) for every color
+    ulabels = ['Training loss', 'Training accuracy',
+               'Validation loss', 'Validation accuracy']
+    patches = [mlines.Line2D([], [], color=c, label=l) for c, l in
+               zip(colors, ulabels)]
+    # add the proxy artists as legend entries
+    ax1.legend(handles=patches, loc='lower left', frameon=False)
 
     # save figure
     os.makedirs(outpath, exist_ok=True)
-- 
GitLab