From 0747b10a447d664aea9e4fc8c23da96564f8f093 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Tue, 14 Jul 2020 14:17:47 +0200
Subject: [PATCH] Major code refactor: Added distinct graphics module

---
 main/eval.py        |  22 ++-
 pytorch/dataset.py  | 416 ++++++++++++--------------------------------
 pytorch/graphics.py | 240 +++++++++++++++++++++++++
 3 files changed, 369 insertions(+), 309 deletions(-)
 create mode 100644 pytorch/graphics.py

diff --git a/main/eval.py b/main/eval.py
index 9a830cb..365173f 100755
--- a/main/eval.py
+++ b/main/eval.py
@@ -13,6 +13,7 @@ sys.path.append('..')
 
 # local modules
 from pytorch.trainer import NetworkTrainer
+from pytorch.graphics import plot_confusion_matrix, plot_loss, plot_sample
 from main.config import config
 
 
@@ -31,11 +32,12 @@ if __name__ == '__main__':
               'accuracy of {:.2f}%  on the validation set!'
               .format(trainer.model.epoch, acc * 100))
 
-        # plot confusion matrix
-        trainer.dataset.plot_confusion_matrix(cm, state=trainer.state_file)
+        # plot confusion matrix: labels of the dataset
+        labels = [label['label'] for label in trainer.dataset.labels.values()]
+        plot_confusion_matrix(cm, labels, state=trainer.state_file)
 
     # plot loss and accuracy
-    trainer.dataset.plot_loss(trainer.loss_state)
+    plot_loss(trainer.loss_state)
 
     # whether to plot the samples of the validation dataset
     if trainer.plot_samples:
@@ -73,8 +75,12 @@ if __name__ == '__main__':
 
             # plot inputs, ground truth and model predictions
             sname = fname + '_sample_{}.pt'.format(sample)
-            fig, ax = trainer.dataset.plot_sample(inputs, labels, y_pred,
-                                                  bands=trainer.plot_bands,
-                                                  state=sname,
-                                                  stretch=True,
-                                                  alpha=5)
+            fig, ax = plot_sample(inputs,
+                                  labels,
+                                  trainer.dataset.use_bands,
+                                  trainer.dataset.labels,
+                                  y_pred=y_pred,
+                                  bands=trainer.plot_bands,
+                                  state=sname,
+                                  stretch=True,
+                                  alpha=5)
diff --git a/pytorch/dataset.py b/pytorch/dataset.py
index d8b4a01..01da32b 100644
--- a/pytorch/dataset.py
+++ b/pytorch/dataset.py
@@ -13,6 +13,8 @@ your custom dataset.
 
 # builtins
 import os
+import re
+import sys
 import csv
 import glob
 import itertools
@@ -22,11 +24,15 @@ import gdal
 import numpy as np
 import torch
 import matplotlib.pyplot as plt
-import matplotlib.patches as mpatches
-from matplotlib.colors import ListedColormap, BoundaryNorm
-from matplotlib import cm as colormap
 from torch.utils.data import Dataset
 
+# append path to local files to the python search path
+sys.path.append('..')
+
+# locals
+from pytorch.constants import (Landsat8, Sentinel2, SparcsLabels,
+                               Cloud95Labels, ProSnowLabels)
+from pytorch.graphics import plot_sample
 
 # generic image dataset class
 class ImageDataset(Dataset):
@@ -145,7 +151,6 @@ class ImageDataset(Dataset):
         raise NotImplementedError('Inherit the ImageDataset class and '
                                   'implement the method.')
 
-
     # _read_scene() reads all the bands and the ground truth mask in a
     # scene/tile to a numpy array and returns a dictionary with
     # (key, value) = ('band_name', np.ndarray(band_data))
@@ -179,6 +184,9 @@ class ImageDataset(Dataset):
     def img2np(self, path, tile_size=None, tile=None):
 
         # open the tif file
+        if path is None:
+            print('Path is of NoneType, returning.')
+            return
         img = gdal.Open(path)
 
         # check whether to read the image in tiles
@@ -298,232 +306,87 @@ class ImageDataset(Dataset):
 
         return indices
 
-    # this function applies percentile stretching at the alpha level
-    # can be used to increase constrast for visualization
-    def contrast_stretching(self, image, alpha=2):
 
-        # compute upper and lower percentiles defining the range of the stretch
-        inf, sup = np.percentile(image, (alpha, 100 - alpha))
+class StandardEoDataset(ImageDataset):
 
-        # normalize image intensity distribution to
-        # (alpha, 100 - alpha) percentiles
-        norm = ((image - inf) * (image.max() - image.min()) /
-                (sup - inf)) + image.min()
+    def __init__(self, root_dir, use_bands, tile_size):
+        # initialize super class ImageDataset
+        super().__init__(root_dir, use_bands, tile_size)
 
-        # clip: values < inf = 0, values > sup = max
-        norm[norm <= image.min()] = image.min()
-        norm[norm >= image.max()] = image.max()
+    # returns the band number of a Landsat8 or Sentinel2 tif file
+    # x: path to a tif file
+    def get_band_number(self, path):
 
-        return norm
+        # check whether the path leads to a tif file
+        if not path.endswith(('tif', 'TIF')):
+            raise ValueError('Expected a path to a tif file.')
 
-    # plot_sample() plots a false color composite of the scene/tile together
-    # with the model prediction and the corresponding ground truth
-    def plot_sample(self, x, y, y_pred=None, figsize=(10, 10),
-                    bands=['red', 'green', 'blue'], stretch=False, state=None,
-                    outpath=os.path.join(os.getcwd(), '_samples/'),  **kwargs):
+        # filename
+        fname = os.path.basename(path)
 
-        # check whether to apply constrast stretching
-        stretch = True if kwargs else False
-        func = self.contrast_stretching if stretch else lambda x: x
+        # search for numbers following a "B" in the filename
+        band = re.search(r'B\dA|B\d{1,2}', fname)[0].replace('B', '')
 
-        # create an rgb stack
-        rgb = np.dstack([func(x[self.use_bands.index(band)],
-                              **kwargs) for band in bands])
+        # try converting to an integer:
+        # raises a ValueError for Sentinel2 8A band
+        try:
+            band = int(band)
+        except ValueError:
+            pass
 
-        # get labels and corresponding colors
-        labels = [label['label'] for label in self.labels.values()]
-        colors = [label['color'] for label in self.labels.values()]
+        return band
 
-        # create a ListedColormap
-        cmap = ListedColormap(colors)
-        boundaries = [*self.labels.keys(), cmap.N]
-        norm = BoundaryNorm(boundaries, cmap.N)
+    # _store_bands() writes the paths to the data of each scene to a dictionary
+    # only the bands of interest are stored
+    def store_bands(self, bands, gt):
 
-        # create figure: check whether to plot model prediction
-        if y_pred is not None:
+        # store the bands of interest in a dictionary
+        scene_data = {}
+        for i, b in enumerate(bands):
+            band = self.bands[self.get_band_number(b)]
+            if band in self.use_bands:
+                scene_data[band] = b
 
-            # compute accuracy
-            acc = (y_pred == y).float().mean()
+        # store ground truth
+        scene_data['gt'] = gt
 
-            # plot model prediction
-            fig, ax = plt.subplots(1, 3, figsize=figsize)
-            ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm)
-            ax[2].set_title('Prediction ({:.2f}%)'.format(acc * 100), pad=15)
+        return scene_data
 
-        else:
-            fig, ax = plt.subplots(1, 2, figsize=figsize)
+    # compose_scenes() creates a list of dictionaries containing the paths
+    # to the tif files of each scene
+    # if the scenes are divided into tiles, each tile has its own entry
+    # with corresponding tile id
+    def compose_scenes(self, pattern='*mask.png'):
 
-        # plot false color composite
-        ax[0].imshow(rgb)
-        ax[0].set_title('R = {}, G = {}, B = {}'.format(*bands), pad=15)
+        # list of all samples in the dataset
+        scenes = []
+        for scene in os.listdir(self.root):
 
-        # plot ground thruth mask
-        ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm)
-        ax[1].set_title('Ground truth', pad=15)
+            # list the spectral bands of the scene
+            bands = glob.glob(os.path.join(self.root, scene, '*B*.tif'))
 
-        # create a patch (proxy artist) for every color
-        patches = [mpatches.Patch(color=c, label=l) for c, l in
-                   zip(colors, labels)]
+            # get the ground truth mask
+            try:
+                gt = glob.glob(os.path.join(self.root, scene, pattern)).pop()
+            except IndexError:
+                gt = None
 
-        # plot patches as legend
-        plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
-                   frameon=False)
+            # iterate over the tiles
+            for tile in range(self.tiles):
 
-        # save figure
-        if state is not None:
-            os.makedirs(outpath, exist_ok=True)
-            fig.savefig(os.path.join(outpath, state.replace('.pt', '.png')),
-                        dpi=300, bbox_inches='tight')
+                # store the bands and the ground truth mask of the tile
+                data = self.store_bands(bands, gt)
 
-        return fig, ax
+                # store tile number
+                data['tile'] = tile
 
-    # plot_confusion_matrix() plots the confusion matrix of the validation/test
-    # set returned by the pytorch.predict function
-    def plot_confusion_matrix(self, cm, labels=None, normalize=True,
-                              figsize=(10, 10), cmap='Blues', state=None,
-                              outpath=os.path.join(os.getcwd(), '_graphics/')):
-
-        # check if labels are provided
-        if labels is None:
-            labels = [label['label'] for _, label in self.labels.items()]
-
-        # number of classes
-        nclasses = len(labels)
-
-        # string format to plot values of confusion matrix
-        fmt = '.0f'
-
-        # minimum and maximum values of the colorbar
-        vmin, vmax = 0, cm.max()
-
-        # check whether to normalize the confusion matrix
-        if normalize:
-            # normalize
-            cm = cm / cm.sum(axis=1, keepdims=True)
-
-            # change string format to floating point
-            fmt = '.2f'
-            vmin, vmax= 0, 1
-
-        # create figure
-        fig, ax = plt.subplots(1, 1, figsize=figsize)
-
-        # get colormap
-        cmap = colormap.get_cmap(cmap, 256)
-
-        # plot confusion matrix
-        im = ax.imshow(cm, cmap=cmap, vmin=vmin, vmax=vmax)
-
-        # threshold determining the color of the values
-        thresh = (cm.max() + cm.min()) / 2
-
-        # brightest/darkest color of current colormap
-        cmap_min, cmap_max = cmap(0), cmap(256)
-
-        # plot values of confusion matrix
-        for i, j in itertools.product(range(nclasses), range(nclasses)):
-            ax.text(j, i, format(cm[i, j], fmt), ha='center', va='center',
-                    color = cmap_max if cm[i, j] < thresh else cmap_min)
-
-        # axes properties and labels
-        ax.set(xticks=np.arange(nclasses),
-               yticks=np.arange(nclasses),
-               xticklabels=labels,
-               yticklabels=labels,
-               ylabel='True',
-               xlabel='Predicted')
-
-        # add colorbar axes
-        cax = fig.add_axes([ax.get_position().x1 + 0.025, ax.get_position().y0,
-                            0.05, ax.get_position().y1 - ax.get_position().y0])
-        fig.colorbar(im, cax=cax)
-
-        # save figure
-        if state is not None:
-            os.makedirs(outpath, exist_ok=True)
-            fig.savefig(os.path.join(outpath, state.replace('.pt', '_cm.png')),
-                        dpi=300, bbox_inches='tight')
-
-        return fig, ax
-
-    def plot_loss(self, loss_file, figsize=(10, 10),
-                  colors=['lightgreen', 'skyblue', 'darkgreen', 'steelblue'],
-                  outpath=os.path.join(os.getcwd(), '_graphics/')):
-
-        # load the model loss
-        state = torch.load(loss_file)
-
-        # get all non-zero elements, i.e. get number of epochs trained before
-        # early stop
-        loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in
-                state.items() if k != 'epoch'}
-
-        # number of epochs trained
-        epochs = np.arange(0, state['epoch'] + 1)
-
-        # instanciate figure
-        fig, ax1 = plt.subplots(1, 1, figsize=figsize)
-
-        # plot training and validation mean loss per epoch
-        [ax1.plot(epochs, v.mean(axis=0),
-                  label=k.capitalize().replace('_', ' '), color=c, lw=2)
-         for (k, v), c in zip(loss.items(), colors) if v.any() and 'loss' in k]
-
-        # plot training loss per batch
-        ax2 = ax1.twiny()
-        [ax2.plot(v.flatten('F'), color=c, alpha=0.5)
-         for (k, v), c in zip(loss.items(), colors) if 'loss' in k and
-         'validation' not in k]
-
-        # plot training and validation mean accuracy per epoch
-        ax3 = ax1.twinx()
-        [ax3.plot(epochs, v.mean(axis=0),
-                  label=k.capitalize().replace('_', ' '), color=c, lw=2)
-         for (k, v), c in zip(loss.items(), colors) if v.any() and 'accuracy'
-         in k]
-
-        # plot training accuracy per batch
-        ax4 = ax3.twiny()
-        [ax4.plot(v.flatten('F'), color=c, alpha=0.5)
-         for (k, v), c in zip(loss.items(), colors) if 'accuracy' in k and
-         'validation' not in k]
-
-        # axes properties and labels
-        for ax in [ax2, ax4]:
-            ax.set(xticks=[], xticklabels=[])
-        ax1.set(xlabel='Epoch',
-                ylabel='Loss',
-                ylim=(0, 1))
-        ax3.set(ylabel='Accuracy',
-                ylim=(0, 1))
-
-        # compute early stopping point
-        if loss['validation_accuracy'].any():
-            esepoch = np.argmax(loss['validation_accuracy'].mean(axis=0))
-            esacc = np.max(loss['validation_accuracy'].mean(axis=0))
-            ax1.vlines(esepoch, ymin=ax1.get_ylim()[0], ymax=ax1.get_ylim()[1],
-                       ls='--', color='grey')
-            ax1.text(esepoch - 1, ax1.get_ylim()[0] + 0.01,
-                     'epoch = {}'.format(esepoch), ha='right', color='grey')
-            ax1.text(esepoch + 1, ax1.get_ylim()[0] + 0.01,
-                     'acc = {:.2f}%'.format(esacc * 100), ha='left',
-                     color='grey')
-
-        # add legends
-        ax1.legend(frameon=False, loc='lower left')
-        ax3.legend(frameon=False, loc='upper left')
-
-        # save figure
-        os.makedirs(outpath, exist_ok=True)
-        fig.savefig(os.path.join(
-            outpath, os.path.basename(loss_file).replace('.pt', '.png')),
-                    dpi=300, bbox_inches='tight')
-
-        return fig, ax
+                # append to list
+                scenes.append(data)
 
+        return scenes
 
 # SparcsDataset class: inherits from the generic ImageDataset class
-class SparcsDataset(ImageDataset):
+class SparcsDataset(StandardEoDataset):
 
     def __init__(self, root_dir, use_bands=['red', 'green', 'blue'],
                  tile_size=None):
@@ -532,9 +395,7 @@ class SparcsDataset(ImageDataset):
 
         # list of all scenes in the root directory
         # each scene is divided into tiles blocks
-        self.scenes = []
-        for scene in os.listdir(self.root):
-            self.scenes += self.compose_scenes(os.path.join(self.root, scene))
+        self.scenes = self.compose_scenes()
 
     # image size of the Sparcs dataset: (height, width)
     def get_size(self):
@@ -542,28 +403,13 @@ class SparcsDataset(ImageDataset):
 
     # Landsat 8 bands of the Sparcs dataset
     def get_bands(self):
-        return {
-            1: 'violet',
-            2: 'blue',
-            3: 'green',
-            4: 'red',
-            5: 'nir',
-            6: 'swir1',
-            7: 'swir2',
-            8: 'pan',
-            9: 'cirrus',
-            10: 'tir'}
+        return {band.value: band.name for band in Landsat8}
 
     # class labels of the Sparcs dataset
     def get_labels(self):
-        labels = ['Shadow', 'Shadow over Water', 'Water', 'Snow', 'Land',
-                  'Cloud', 'Flooded']
-        colors = ['black', 'darkblue', 'blue', 'lightblue', 'grey', 'white',
-                  'yellow']
-        lc = {}
-        for i, (l, c) in enumerate(zip(labels, colors)):
-            lc[i] = {'label': l, 'color': c}
-        return lc
+        return {band.value[0]: {'label': band.name.replace('_', ' '),
+                                'color': band.value[1]}
+                for band in SparcsLabels}
 
     # preprocessing of the Sparcs dataset
     def preprocess(self, data, gt):
@@ -572,60 +418,45 @@ class SparcsDataset(ImageDataset):
 
         # convert to torch tensors
         x = torch.tensor(data, dtype=torch.float32)
-        y = torch.tensor(gt, dtype=torch.uint8)
+        y = torch.tensor(gt, dtype=torch.uint8) if gt is not None else gt
         return x, y
 
-    # _compose_scenes() creates a list of dictionaries containing the paths
-    # to the files of each scene
-    # if the scenes are divided into tiles, each tile has its own entry
-    # with corresponding tile id
-    def compose_scenes(self, scene):
-
-        # list the spectral bands of the scene
-        bands = glob.glob(os.path.join(scene, '*B*.tif'))
 
-        # sort the bands in ascending order
-        bands.sort(key=self._get_band_number)
+class ProSnowDataset(StandardEoDataset):
 
-        # get the ground truth mask
-        gt = glob.glob(os.path.join(scene, '*mask.png')).pop()
+    def __init__(self, root_dir, use_bands, tile_size):
+        super().__init__(root_dir, use_bands, tile_size)
 
-        # create an entry for each scene/tile
-        scene_data = []
+        # list of samples in the dataset
+        self.scenes = self.compose_scenes()
 
-        # iterate over the tiles
-        for tile in range(self.tiles):
+    # Sentinel 2 bands
+    def get_bands(self):
+        return {band.value: band.name for band in Sentinel2}
 
-            # store the bands and the ground truth mask of the tile
-            data = self._store_bands(bands, gt)
+    # class labels of the ProSnow dataset
+    def get_labels(self):
+        return {band.value[0]: {'label': band.name, 'color': band.value[1]}
+               for band in ProSnowLabels}
 
-            # store tile number
-            data['tile'] = tile
+    # preprocessing of the ProSnow dataset
+    def preprocess(self, data, gt):
 
-            # append to list
-            scene_data.append(data)
+        # if the preprocessing is not done externally, implement it here
 
-        return scene_data
+        # convert to torch tensors
+        x = torch.tensor(data, dtype=torch.float32)
+        y = torch.tensor(gt, dtype=torch.uint8) if gt is not None else gt
+        return x, y
 
-    # returns the band number of the preprocessed Sparcs Tiff files
-    def _get_band_number(self, x):
-        return int(os.path.basename(x).split('_')[2].replace('B', ''))
 
-    # _store_bands() writes the paths to the data of each scene to a dictionary
-    # only the bands of interest are stored
-    def _store_bands(self, bands, gt):
+class ProSnowGarmisch(ProSnowDataset):
 
-        # store the bands of interest in a dictionary
-        scene_data = {}
-        for i, b in enumerate(bands):
-            band = self.bands[self._get_band_number(b)]
-            if band in self.use_bands:
-                scene_data[band] = b
-
-        # store ground truth
-        scene_data['gt'] = gt
+    def __init__(self, root_dir, use_bands=[], tile_size=None):
+        super().__init__(root_dir, use_bands, tile_size)
 
-        return scene_data
+    def get_size(self):
+        return (615, 543)
 
 
 class Cloud95Dataset(ImageDataset):
@@ -641,7 +472,7 @@ class Cloud95Dataset(ImageDataset):
 
         # list of all scenes in the root directory
         # each scene is divided into tiles blocks
-        self.scenes = self.compose_scenes(self.root)
+        self.scenes = self.compose_scenes()
 
     # image size of the Cloud-95 dataset: (height, width)
     def get_size(self):
@@ -649,12 +480,12 @@ class Cloud95Dataset(ImageDataset):
 
     # Landsat 8 bands in the Cloud-95 dataset
     def get_bands(self):
-        return {1: 'red', 2: 'green', 3: 'blue', 4: 'nir'}
+        return {band.value: band.name for band in Landsat8}
 
     # class labels of the Cloud-95 dataset
     def get_labels(self):
-        return {0: {'label': 'Clear', 'color': 'skyblue'},
-                1: {'label': 'Cloud', 'color': 'white'}}
+        return {band.value[0]: {'label': band.name, 'color': band.value[1]}
+                for band in Cloud95Labels}
 
     # preprocess Cloud-95 dataset
     def preprocess(self, data, gt):
@@ -667,8 +498,7 @@ class Cloud95Dataset(ImageDataset):
 
         return x, y
 
-
-    def compose_scenes(self, root_dir):
+    def compose_scenes(self):
 
         # whether to exclude patches with more than 80% black pixels
         ipatches = []
@@ -678,10 +508,10 @@ class Cloud95Dataset(ImageDataset):
                 # list of informative patches
                 ipatches = list(itertools.chain.from_iterable(reader))
 
-        # get the names of the directories containing the TIFF files of
+        # get the names of the directories containing the tif files of
         # the bands of interest
         band_dirs = {}
-        for dirpath, dirname, files in os.walk(root_dir):
+        for dirpath, dirname, files in os.walk(self.root):
             # check if the current directory path includes the name of a band
             # or the name of the ground truth mask
             cband = [band for band in self.use_bands + ['gt'] if
@@ -713,7 +543,7 @@ class Cloud95Dataset(ImageDataset):
 
                 # iterate over the bands of interest
                 for band in band_dirs.keys():
-                    # save path to current band TIFF file to dictionary
+                    # save path to current band tif file to dictionary
                     scene[band] = os.path.join(band_dirs[band],
                                                file.replace(biter, band))
 
@@ -730,8 +560,8 @@ if __name__ == '__main__':
 
     # define path to working directory
     # wd = '//projectdata.eurac.edu/projects/cci_snow/dfrisinghelli/'
-    wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli'
-    # wd = 'C:/Eurac/2020/'
+    # wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli'
+    wd = 'C:/Eurac/2020/'
 
     # path to the preprocessed sparcs dataset
     sparcs_path = os.path.join(wd, '_Datasets/Sparcs')
@@ -739,6 +569,9 @@ if __name__ == '__main__':
     # path to the Cloud-95 dataset
     cloud_path = os.path.join(wd, '_Datasets/Cloud95/Training')
 
+    # path to the ProSnow dataset
+    prosnow_path = os.path.join(wd, '_Datasets/ProSnow/Garmisch')
+
     # the csv file containing the names of the informative patches
     patches = 'training_patches_95-cloud_nonempty.csv'
 
@@ -749,24 +582,5 @@ if __name__ == '__main__':
     sparcs_dataset = SparcsDataset(sparcs_path, tile_size=None,
                                    use_bands=['nir', 'red', 'green'])
 
-    # a sample from the sparcs dataset
-    sample_s = np.random.randint(len(sparcs_dataset), size=1).item()
-    s_x, s_y = sparcs_dataset[sample_s]
-    fig, ax = sparcs_dataset.plot_sample(s_x, s_y,
-                                         bands=['nir', 'red', 'green'])
-
-    # a sample from the cloud dataset
-    sample_c = np.random.randint(len(cloud_dataset), size=1).item()
-    c_x, c_y = cloud_dataset[sample_c]
-    fig, ax = cloud_dataset.plot_sample(c_x, c_y,
-                                        bands=['nir', 'red', 'green'])
-
-    # print shape of the sample
-    for i, l, d in zip([s_x, c_x], [s_y, c_y],
-                       [sparcs_dataset, cloud_dataset]):
-        print('A sample from the {}:'.format(d.__class__.__name__))
-        print('Shape of input data: {}'.format(i.shape))
-        print('Shape of ground truth: {}'.format(l.shape))
-
-    # show figures
-    plt.show()
+    # instantiate the ProSnow class
+    prosnow_dataset = ProSnowGarmisch(prosnow_path)
diff --git a/pytorch/graphics.py b/pytorch/graphics.py
new file mode 100644
index 0000000..d142209
--- /dev/null
+++ b/pytorch/graphics.py
@@ -0,0 +1,240 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Jul 14 11:04:27 2020
+
+@author: Daniel
+"""
+# builtins
+import os
+import itertools
+
+# externals
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+from matplotlib.colors import ListedColormap, BoundaryNorm
+from matplotlib import cm as colormap
+
+
+# this function applies percentile stretching at the alpha level
+# can be used to increase contrast for visualization
+def contrast_stretching(image, alpha=2):
+
+    # compute upper and lower percentiles defining the range of the stretch
+    inf, sup = np.percentile(image, (alpha, 100 - alpha))
+
+    # normalize image intensity distribution to
+    # (alpha, 100 - alpha) percentiles
+    norm = ((image - inf) * (image.max() - image.min()) /
+            (sup - inf)) + image.min()
+
+    # clip: values < inf = 0, values > sup = max
+    norm[norm <= image.min()] = image.min()
+    norm[norm >= image.max()] = image.max()
+
+    return norm
+
+
+# plot_sample() plots a false color composite of the scene/tile together
+# with the model prediction and the corresponding ground truth
+def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
+                bands=['red', 'green', 'blue'], stretch=False, state=None,
+                outpath=os.path.join(os.getcwd(), '_samples/'),  **kwargs):
+
+    # check whether to apply contrast stretching
+    stretch = True if kwargs else False
+    func = contrast_stretching if stretch else lambda x: x
+
+    # create an rgb stack
+    rgb = np.dstack([func(x[use_bands.index(band)], **kwargs)
+                     for band in bands])
+
+    # get labels and corresponding colors
+    ulabels = [label['label'] for label in labels.values()]
+    colors = [label['color'] for label in labels.values()]
+
+    # create a ListedColormap
+    cmap = ListedColormap(colors)
+    boundaries = [*labels.keys(), cmap.N]
+    norm = BoundaryNorm(boundaries, cmap.N)
+
+    # create figure: check whether to plot model prediction
+    if y_pred is not None:
+
+        # compute accuracy
+        acc = (y_pred == y).float().mean()
+
+        # plot model prediction
+        fig, ax = plt.subplots(1, 3, figsize=figsize)
+        ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm)
+        ax[2].set_title('Prediction ({:.2f}%)'.format(acc * 100), pad=15)
+
+    else:
+        fig, ax = plt.subplots(1, 2, figsize=figsize)
+
+    # plot false color composite
+    ax[0].imshow(rgb)
+    ax[0].set_title('R = {}, G = {}, B = {}'.format(*bands), pad=15)
+
+    # plot ground truth mask
+    ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm)
+    ax[1].set_title('Ground truth', pad=15)
+
+    # create a patch (proxy artist) for every color
+    patches = [mpatches.Patch(color=c, label=l) for c, l in
+               zip(colors, ulabels)]
+
+    # plot patches as legend
+    plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
+               frameon=False)
+
+    # save figure
+    if state is not None:
+        os.makedirs(outpath, exist_ok=True)
+        fig.savefig(os.path.join(outpath, state.replace('.pt', '.png')),
+                    dpi=300, bbox_inches='tight')
+
+    return fig, ax
+
+
+# plot_confusion_matrix() plots the confusion matrix of the validation/test
+# set returned by the pytorch.predict function
+def plot_confusion_matrix(cm, labels, normalize=True,
+                          figsize=(10, 10), cmap='Blues', state=None,
+                          outpath=os.path.join(os.getcwd(), '_graphics/')):
+
+    # number of classes
+    nclasses = len(labels)
+
+    # string format to plot values of confusion matrix
+    fmt = '.0f'
+
+    # minimum and maximum values of the colorbar
+    vmin, vmax = 0, cm.max()
+
+    # check whether to normalize the confusion matrix
+    if normalize:
+        # normalize
+        cm = cm / cm.sum(axis=1, keepdims=True)
+
+        # change string format to floating point
+        fmt = '.2f'
+        vmin, vmax = 0, 1
+
+    # create figure
+    fig, ax = plt.subplots(1, 1, figsize=figsize)
+
+    # get colormap
+    cmap = colormap.get_cmap(cmap, 256)
+
+    # plot confusion matrix
+    im = ax.imshow(cm, cmap=cmap, vmin=vmin, vmax=vmax)
+
+    # threshold determining the color of the values
+    thresh = (cm.max() + cm.min()) / 2
+
+    # brightest/darkest color of current colormap
+    cmap_min, cmap_max = cmap(0), cmap(256)
+
+    # plot values of confusion matrix
+    for i, j in itertools.product(range(nclasses), range(nclasses)):
+        ax.text(j, i, format(cm[i, j], fmt), ha='center', va='center',
+                color=cmap_max if cm[i, j] < thresh else cmap_min)
+
+    # axes properties and labels
+    ax.set(xticks=np.arange(nclasses),
+           yticks=np.arange(nclasses),
+           xticklabels=labels,
+           yticklabels=labels,
+           ylabel='True',
+           xlabel='Predicted')
+
+    # add colorbar axes
+    cax = fig.add_axes([ax.get_position().x1 + 0.025, ax.get_position().y0,
+                        0.05, ax.get_position().y1 - ax.get_position().y0])
+    fig.colorbar(im, cax=cax)
+
+    # save figure
+    if state is not None:
+        os.makedirs(outpath, exist_ok=True)
+        fig.savefig(os.path.join(outpath, state.replace('.pt', '_cm.png')),
+                    dpi=300, bbox_inches='tight')
+
+    return fig, ax
+
+
+def plot_loss(loss_file, figsize=(10, 10),
+              colors=['lightgreen', 'skyblue', 'darkgreen', 'steelblue'],
+              outpath=os.path.join(os.getcwd(), '_graphics/')):
+
+    # load the model loss
+    state = torch.load(loss_file)
+
+    # get all non-zero elements, i.e. get number of epochs trained before
+    # early stop
+    loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in
+            state.items() if k != 'epoch'}
+
+    # number of epochs trained
+    epochs = np.arange(0, state['epoch'] + 1)
+
+    # instantiate figure
+    fig, ax1 = plt.subplots(1, 1, figsize=figsize)
+
+    # plot training and validation mean loss per epoch
+    [ax1.plot(epochs, v.mean(axis=0),
+              label=k.capitalize().replace('_', ' '), color=c, lw=2)
+     for (k, v), c in zip(loss.items(), colors) if v.any() and 'loss' in k]
+
+    # plot training loss per batch
+    ax2 = ax1.twiny()
+    [ax2.plot(v.flatten('F'), color=c, alpha=0.5)
+     for (k, v), c in zip(loss.items(), colors) if 'loss' in k and
+     'validation' not in k]
+
+    # plot training and validation mean accuracy per epoch
+    ax3 = ax1.twinx()
+    [ax3.plot(epochs, v.mean(axis=0),
+              label=k.capitalize().replace('_', ' '), color=c, lw=2)
+     for (k, v), c in zip(loss.items(), colors) if v.any() and 'accuracy'
+     in k]
+
+    # plot training accuracy per batch
+    ax4 = ax3.twiny()
+    [ax4.plot(v.flatten('F'), color=c, alpha=0.5)
+     for (k, v), c in zip(loss.items(), colors) if 'accuracy' in k and
+     'validation' not in k]
+
+    # axes properties and labels
+    for ax in [ax2, ax4]:
+        ax.set(xticks=[], xticklabels=[])
+    ax1.set(xlabel='Epoch',
+            ylabel='Loss',
+            ylim=(0, 1))
+    ax3.set(ylabel='Accuracy',
+            ylim=(0, 1))
+
+    # compute early stopping point
+    if loss['validation_accuracy'].any():
+        esepoch = np.argmax(loss['validation_accuracy'].mean(axis=0))
+        esacc = np.max(loss['validation_accuracy'].mean(axis=0))
+        ax1.vlines(esepoch, ymin=ax1.get_ylim()[0], ymax=ax1.get_ylim()[1],
+                   ls='--', color='grey')
+        ax1.text(esepoch - 1, ax1.get_ylim()[0] + 0.01,
+                 'epoch = {}'.format(esepoch), ha='right', color='grey')
+        ax1.text(esepoch + 1, ax1.get_ylim()[0] + 0.01,
+                 'acc = {:.2f}%'.format(esacc * 100), ha='left',
+                 color='grey')
+
+    # add legends
+    ax1.legend(frameon=False, loc='lower left')
+    ax3.legend(frameon=False, loc='upper left')
+
+    # save figure
+    os.makedirs(outpath, exist_ok=True)
+    fig.savefig(os.path.join(
+        outpath, os.path.basename(loss_file).replace('.pt', '.png')),
+                dpi=300, bbox_inches='tight')
+
+    return fig, ax1
-- 
GitLab