From 0747b10a447d664aea9e4fc8c23da96564f8f093 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Tue, 14 Jul 2020 14:17:47 +0200
Subject: [PATCH] Major code refactor: Added distinct graphics module

---
 main/eval.py        |  22 ++-
 pytorch/dataset.py  | 416 ++++++++++++--------------------------------
 pytorch/graphics.py | 240 +++++++++++++++++++++++++
 3 files changed, 369 insertions(+), 309 deletions(-)
 create mode 100644 pytorch/graphics.py

diff --git a/main/eval.py b/main/eval.py
index 9a830cb..365173f 100755
--- a/main/eval.py
+++ b/main/eval.py
@@ -13,6 +13,7 @@ sys.path.append('..')
 
 # local modules
 from pytorch.trainer import NetworkTrainer
+from pytorch.graphics import plot_confusion_matrix, plot_loss, plot_sample
 from main.config import config
 
 
@@ -31,11 +32,12 @@ if __name__ == '__main__':
           'accuracy of {:.2f}% on the validation set!'
           .format(trainer.model.epoch, acc * 100))
 
-    # plot confusion matrix
-    trainer.dataset.plot_confusion_matrix(cm, state=trainer.state_file)
+    # plot confusion matrix: class labels of the dataset
+    class_labels = [label['label'] for label in trainer.dataset.labels.values()]
+    plot_confusion_matrix(cm, class_labels, state=trainer.state_file)
 
     # plot loss and accuracy
-    trainer.dataset.plot_loss(trainer.loss_state)
+    plot_loss(trainer.loss_state)
 
     # whether to plot the samples of the validation dataset
     if trainer.plot_samples:
@@ -73,8 +75,12 @@ if __name__ == '__main__':
 
             # plot inputs, ground truth and model predictions
             sname = fname + '_sample_{}.pt'.format(sample)
-            fig, ax = trainer.dataset.plot_sample(inputs, labels, y_pred,
-                                                  bands=trainer.plot_bands,
-                                                  state=sname,
-                                                  stretch=True,
-                                                  alpha=5)
+            fig, ax = plot_sample(inputs,
+                                  labels,
+                                  trainer.dataset.use_bands,
+                                  trainer.dataset.labels,
+                                  y_pred=y_pred,
+                                  bands=trainer.plot_bands,
+                                  state=sname,
+                                  stretch=True,
+                                  alpha=5)
diff --git a/pytorch/dataset.py b/pytorch/dataset.py
index d8b4a01..01da32b 100644
--- a/pytorch/dataset.py
+++ b/pytorch/dataset.py
@@ -13,6 +13,8 @@ your custom dataset.
 # builtins
 import os
+import re
+import sys
 import csv
 import glob
 import itertools
@@ -22,11 +24,15 @@
 import gdal
 import numpy as np
 import torch
 import matplotlib.pyplot as plt
-import matplotlib.patches as mpatches
-from matplotlib.colors import ListedColormap, BoundaryNorm
-from matplotlib import cm as colormap
 from torch.utils.data import Dataset
 
+# append path to local files to the python search path
+sys.path.append('..')
+
+# locals
+from pytorch.constants import (Landsat8, Sentinel2, SparcsLabels,
+                               Cloud95Labels, ProSnowLabels)
+from pytorch.graphics import plot_sample
 
 # generic image dataset class
 class ImageDataset(Dataset):
@@ -145,7 +151,6 @@ class ImageDataset(Dataset):
         raise NotImplementedError('Inherit the ImageDataset class and '
                                   'implement the method.')
 
-
     # _read_scene() reads all the bands and the ground truth mask in a
     # scene/tile to a numpy array and returns a dictionary with
     # (key, value) = ('band_name', np.ndarray(band_data))
@@ -179,6 +184,9 @@ class ImageDataset(Dataset):
     def img2np(self, path, tile_size=None, tile=None):
 
         # open the tif file
+        if path is None:
+            print('Path is of NoneType, returning.')
+            return
         img = gdal.Open(path)
 
         # check whether to read the image in tiles
@@ -298,232 +306,87 @@ class ImageDataset(Dataset):
 
         return indices
 
-    # this function applies percentile stretching at the alpha level
-    # can be used to increase constrast for visualization
-    def contrast_stretching(self, image, alpha=2):
-        # compute upper and lower percentiles defining the range of the stretch
-        inf, sup = np.percentile(image, (alpha, 100 - alpha))
 
+class StandardEoDataset(ImageDataset):
 
-        # normalize image intensity distribution to
-        # (alpha, 100 - alpha) percentiles
-        norm = ((image - inf) * (image.max() - image.min()) /
-                (sup - inf)) + image.min()
+    def __init__(self, root_dir, use_bands, tile_size):
+        # initialize super class ImageDataset
+        super().__init__(root_dir, use_bands, tile_size)
 
-        # clip: values < inf = 0, values > sup = max
-        norm[norm <= image.min()] = image.min()
-        norm[norm >= image.max()] = image.max()
+    # returns the band number of a Landsat8 or Sentinel2 tif file
+    # path: path to a tif file
+    def get_band_number(self, path):
 
-        return norm
+        # check whether the path leads to a tif file
+        if not path.endswith(('tif', 'TIF')):
+            raise ValueError('Expected a path to a tif file.')
 
-    # plot_sample() plots a false color composite of the scene/tile together
-    # with the model prediction and the corresponding ground truth
-    def plot_sample(self, x, y, y_pred=None, figsize=(10, 10),
-                    bands=['red', 'green', 'blue'], stretch=False, state=None,
-                    outpath=os.path.join(os.getcwd(), '_samples/'), **kwargs):
+        # filename
+        fname = os.path.basename(path)
 
-        # check whether to apply constrast stretching
-        stretch = True if kwargs else False
-        func = self.contrast_stretching if stretch else lambda x: x
+        # search for numbers following a "B" in the filename
+        band = re.search(r'B\dA|B\d{1,2}', fname)[0].replace('B', '')
 
-        # create an rgb stack
-        rgb = np.dstack([func(x[self.use_bands.index(band)],
-                              **kwargs) for band in bands])
+        # try converting to an integer:
+        # raises a ValueError for Sentinel2 8A band
+        try:
+            band = int(band)
+        except ValueError:
+            pass
 
-        # get labels and corresponding colors
-        labels = [label['label'] for label in self.labels.values()]
-        colors = [label['color'] for label in self.labels.values()]
+        return band
 
-        # create a ListedColormap
-        cmap = ListedColormap(colors)
-        boundaries = [*self.labels.keys(), cmap.N]
-        norm = BoundaryNorm(boundaries, cmap.N)
+    # store_bands() writes the paths to the data of each scene to a dictionary
+    # only the bands of interest are stored
+    def store_bands(self, bands, gt):
 
-        # create figure: check whether to plot model prediction
-        if y_pred is not None:
+        # store the bands of interest in a dictionary
+        scene_data = {}
+        for b in bands:
+            band = self.bands[self.get_band_number(b)]
+            if band in self.use_bands:
+                scene_data[band] = b
 
-            # compute accuracy
-            acc = (y_pred == y).float().mean()
+        # store ground truth
+        scene_data['gt'] = gt
 
-            # plot model prediction
-            fig, ax = plt.subplots(1, 3, figsize=figsize)
-            ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm)
-            ax[2].set_title('Prediction ({:.2f}%)'.format(acc * 100), pad=15)
+        return scene_data
 
-        else:
-            fig, ax = plt.subplots(1, 2, figsize=figsize)
+    # compose_scenes() creates a list of dictionaries containing the paths
+    # to the tif files of each scene
+    # if the scenes are divided into tiles, each tile has its own entry
+    # with corresponding tile id
+    def compose_scenes(self, pattern='*mask.png'):
 
-        # plot false color composite
-        ax[0].imshow(rgb)
-        ax[0].set_title('R = {}, G = {}, B = {}'.format(*bands), pad=15)
+        # list of all samples in the dataset
+        scenes = []
+        for scene in os.listdir(self.root):
 
-        # plot ground thruth mask
-        ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm)
-        ax[1].set_title('Ground truth', pad=15)
+            # list the spectral bands of the scene
+            bands = glob.glob(os.path.join(self.root, scene, '*B*.tif'))
 
-        # create a patch (proxy artist) for every color
-        patches = [mpatches.Patch(color=c, label=l) for c, l in
-                   zip(colors, labels)]
+            # get the ground truth mask
+            try:
+                gt = glob.glob(os.path.join(self.root, scene, pattern)).pop()
+            except IndexError:
+                gt = None
 
-        # plot patches as legend
-        plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
-                   frameon=False)
+            # iterate over the tiles
+            for tile in range(self.tiles):
 
-        # save figure
-        if state is not None:
-            os.makedirs(outpath, exist_ok=True)
-            fig.savefig(os.path.join(outpath, state.replace('.pt', '.png')),
-                        dpi=300, bbox_inches='tight')
+                # store the bands and the ground truth mask of the tile
+                data = self.store_bands(bands, gt)
 
-        return fig, ax
+                # store tile number
+                data['tile'] = tile
 
-    # plot_confusion_matrix() plots the confusion matrix of the validation/test
-    # set returned by the pytorch.predict function
-    def plot_confusion_matrix(self, cm, labels=None, normalize=True,
-                              figsize=(10, 10), cmap='Blues', state=None,
-                              outpath=os.path.join(os.getcwd(), '_graphics/')):
-
-        # check if labels are provided
-        if labels is None:
-            labels = [label['label'] for _, label in self.labels.items()]
-
-        # number of classes
-        nclasses = len(labels)
-
-        # string format to plot values of confusion matrix
-        fmt = '.0f'
-
-        # minimum and maximum values of the colorbar
-        vmin, vmax = 0, cm.max()
-
-        # check whether to normalize the confusion matrix
-        if normalize:
-            # normalize
-            cm = cm / cm.sum(axis=1, keepdims=True)
-
-            # change string format to floating point
-            fmt = '.2f'
-            vmin, vmax = 0, 1
-
-        # create figure
-        fig, ax = plt.subplots(1, 1, figsize=figsize)
-
-        # get colormap
-        cmap = colormap.get_cmap(cmap, 256)
-
-        # plot confusion matrix
-        im = ax.imshow(cm, cmap=cmap, vmin=vmin, vmax=vmax)
-
-        # threshold determining the color of the values
-        thresh = (cm.max() + cm.min()) / 2
-
-        # brightest/darkest color of current colormap
-        cmap_min, cmap_max = cmap(0), cmap(256)
-
-        # plot values of confusion matrix
-        for i, j
in itertools.product(range(nclasses), range(nclasses)): - ax.text(j, i, format(cm[i, j], fmt), ha='center', va='center', - color = cmap_max if cm[i, j] < thresh else cmap_min) - - # axes properties and labels - ax.set(xticks=np.arange(nclasses), - yticks=np.arange(nclasses), - xticklabels=labels, - yticklabels=labels, - ylabel='True', - xlabel='Predicted') - - # add colorbar axes - cax = fig.add_axes([ax.get_position().x1 + 0.025, ax.get_position().y0, - 0.05, ax.get_position().y1 - ax.get_position().y0]) - fig.colorbar(im, cax=cax) - - # save figure - if state is not None: - os.makedirs(outpath, exist_ok=True) - fig.savefig(os.path.join(outpath, state.replace('.pt', '_cm.png')), - dpi=300, bbox_inches='tight') - - return fig, ax - - def plot_loss(self, loss_file, figsize=(10, 10), - colors=['lightgreen', 'skyblue', 'darkgreen', 'steelblue'], - outpath=os.path.join(os.getcwd(), '_graphics/')): - - # load the model loss - state = torch.load(loss_file) - - # get all non-zero elements, i.e. get number of epochs trained before - # early stop - loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in - state.items() if k != 'epoch'} - - # number of epochs trained - epochs = np.arange(0, state['epoch'] + 1) - - # instanciate figure - fig, ax1 = plt.subplots(1, 1, figsize=figsize) - - # plot training and validation mean loss per epoch - [ax1.plot(epochs, v.mean(axis=0), - label=k.capitalize().replace('_', ' '), color=c, lw=2) - for (k, v), c in zip(loss.items(), colors) if v.any() and 'loss' in k] - - # plot training loss per batch - ax2 = ax1.twiny() - [ax2.plot(v.flatten('F'), color=c, alpha=0.5) - for (k, v), c in zip(loss.items(), colors) if 'loss' in k and - 'validation' not in k] - - # plot training and validation mean accuracy per epoch - ax3 = ax1.twinx() - [ax3.plot(epochs, v.mean(axis=0), - label=k.capitalize().replace('_', ' '), color=c, lw=2) - for (k, v), c in zip(loss.items(), colors) if v.any() and 'accuracy' - in k] - - # plot training accuracy per batch - ax4 = ax3.twiny() - [ax4.plot(v.flatten('F'), color=c, alpha=0.5) - for (k, v), c in zip(loss.items(), colors) if 'accuracy' in k and - 'validation' not in k] - - # axes properties and labels - for ax in [ax2, ax4]: - ax.set(xticks=[], xticklabels=[]) - ax1.set(xlabel='Epoch', - ylabel='Loss', - ylim=(0, 1)) - ax3.set(ylabel='Accuracy', - ylim=(0, 1)) - - # compute early stopping point - if loss['validation_accuracy'].any(): - esepoch = np.argmax(loss['validation_accuracy'].mean(axis=0)) - esacc = np.max(loss['validation_accuracy'].mean(axis=0)) - ax1.vlines(esepoch, ymin=ax1.get_ylim()[0], ymax=ax1.get_ylim()[1], - ls='--', color='grey') - ax1.text(esepoch - 1, ax1.get_ylim()[0] + 0.01, - 'epoch = {}'.format(esepoch), ha='right', color='grey') - ax1.text(esepoch + 1, ax1.get_ylim()[0] + 0.01, - 'acc = {:.2f}%'.format(esacc * 100), ha='left', - color='grey') - - # add legends - ax1.legend(frameon=False, loc='lower left') - ax3.legend(frameon=False, loc='upper left') - - # save figure - os.makedirs(outpath, exist_ok=True) - fig.savefig(os.path.join( - outpath, os.path.basename(loss_file).replace('.pt', '.png')), - dpi=300, bbox_inches='tight') - - return fig, ax + # append to list + scenes.append(data) + return scenes # SparcsDataset class: inherits from the generic ImageDataset class -class SparcsDataset(ImageDataset): +class SparcsDataset(StandardEoDataset): def __init__(self, root_dir, use_bands=['red', 'green', 'blue'], tile_size=None): @@ -532,9 +395,7 @@ class SparcsDataset(ImageDataset): # list of all scenes in the 
root directory # each scene is divided into tiles blocks - self.scenes = [] - for scene in os.listdir(self.root): - self.scenes += self.compose_scenes(os.path.join(self.root, scene)) + self.scenes = self.compose_scenes() # image size of the Sparcs dataset: (height, width) def get_size(self): @@ -542,28 +403,13 @@ class SparcsDataset(ImageDataset): # Landsat 8 bands of the Sparcs dataset def get_bands(self): - return { - 1: 'violet', - 2: 'blue', - 3: 'green', - 4: 'red', - 5: 'nir', - 6: 'swir1', - 7: 'swir2', - 8: 'pan', - 9: 'cirrus', - 10: 'tir'} + return {band.value: band.name for band in Landsat8} # class labels of the Sparcs dataset def get_labels(self): - labels = ['Shadow', 'Shadow over Water', 'Water', 'Snow', 'Land', - 'Cloud', 'Flooded'] - colors = ['black', 'darkblue', 'blue', 'lightblue', 'grey', 'white', - 'yellow'] - lc = {} - for i, (l, c) in enumerate(zip(labels, colors)): - lc[i] = {'label': l, 'color': c} - return lc + return {band.value[0]: {'label': band.name.replace('_', ' '), + 'color': band.value[1]} + for band in SparcsLabels} # preprocessing of the Sparcs dataset def preprocess(self, data, gt): @@ -572,60 +418,45 @@ class SparcsDataset(ImageDataset): # convert to torch tensors x = torch.tensor(data, dtype=torch.float32) - y = torch.tensor(gt, dtype=torch.uint8) + y = torch.tensor(gt, dtype=torch.uint8) if gt is not None else gt return x, y - # _compose_scenes() creates a list of dictionaries containing the paths - # to the files of each scene - # if the scenes are divided into tiles, each tile has its own entry - # with corresponding tile id - def compose_scenes(self, scene): - - # list the spectral bands of the scene - bands = glob.glob(os.path.join(scene, '*B*.tif')) - # sort the bands in ascending order - bands.sort(key=self._get_band_number) +class ProSnowDataset(StandardEoDataset): - # get the ground truth mask - gt = glob.glob(os.path.join(scene, '*mask.png')).pop() + def __init__(self, root_dir, use_bands, tile_size): + super().__init__(root_dir, use_bands, tile_size) - # create an entry for each scene/tile - scene_data = [] + # list of samples in the dataset + self.scenes = self.compose_scenes() - # iterate over the tiles - for tile in range(self.tiles): + # Sentinel 2 bands + def get_bands(self): + return {band.value: band.name for band in Sentinel2} - # store the bands and the ground truth mask of the tile - data = self._store_bands(bands, gt) + # class labels of the ProSnow dataset + def get_labels(self): + return {band.value[0]: {'label': band.name, 'color': band.value[1]} + for band in ProSnowLabels} - # store tile number - data['tile'] = tile + # preprocessing of the ProSnow dataset + def preprocess(self, data, gt): - # append to list - scene_data.append(data) + # if the preprocessing is not done externally, implement it here - return scene_data + # convert to torch tensors + x = torch.tensor(data, dtype=torch.float32) + y = torch.tensor(gt, dtype=torch.uint8) if gt is not None else gt + return x, y - # returns the band number of the preprocessed Sparcs Tiff files - def _get_band_number(self, x): - return int(os.path.basename(x).split('_')[2].replace('B', '')) - # _store_bands() writes the paths to the data of each scene to a dictionary - # only the bands of interest are stored - def _store_bands(self, bands, gt): +class ProSnowGarmisch(ProSnowDataset): - # store the bands of interest in a dictionary - scene_data = {} - for i, b in enumerate(bands): - band = self.bands[self._get_band_number(b)] - if band in self.use_bands: - scene_data[band] = b - - 
# store ground truth - scene_data['gt'] = gt + def __init__(self, root_dir, use_bands=[], tile_size=None): + super().__init__(root_dir, use_bands, tile_size) - return scene_data + def get_size(self): + return (615, 543) class Cloud95Dataset(ImageDataset): @@ -641,7 +472,7 @@ class Cloud95Dataset(ImageDataset): # list of all scenes in the root directory # each scene is divided into tiles blocks - self.scenes = self.compose_scenes(self.root) + self.scenes = self.compose_scenes() # image size of the Cloud-95 dataset: (height, width) def get_size(self): @@ -649,12 +480,12 @@ class Cloud95Dataset(ImageDataset): # Landsat 8 bands in the Cloud-95 dataset def get_bands(self): - return {1: 'red', 2: 'green', 3: 'blue', 4: 'nir'} + return {band.value: band.name for band in Landsat8} # class labels of the Cloud-95 dataset def get_labels(self): - return {0: {'label': 'Clear', 'color': 'skyblue'}, - 1: {'label': 'Cloud', 'color': 'white'}} + return {band.value[0]: {'label': band.name, 'color': band.value[1]} + for band in Cloud95Labels} # preprocess Cloud-95 dataset def preprocess(self, data, gt): @@ -667,8 +498,7 @@ class Cloud95Dataset(ImageDataset): return x, y - - def compose_scenes(self, root_dir): + def compose_scenes(self): # whether to exclude patches with more than 80% black pixels ipatches = [] @@ -678,10 +508,10 @@ class Cloud95Dataset(ImageDataset): # list of informative patches ipatches = list(itertools.chain.from_iterable(reader)) - # get the names of the directories containing the TIFF files of + # get the names of the directories containing the tif files of # the bands of interest band_dirs = {} - for dirpath, dirname, files in os.walk(root_dir): + for dirpath, dirname, files in os.walk(self.root): # check if the current directory path includes the name of a band # or the name of the ground truth mask cband = [band for band in self.use_bands + ['gt'] if @@ -713,7 +543,7 @@ class Cloud95Dataset(ImageDataset): # iterate over the bands of interest for band in band_dirs.keys(): - # save path to current band TIFF file to dictionary + # save path to current band tif file to dictionary scene[band] = os.path.join(band_dirs[band], file.replace(biter, band)) @@ -730,8 +560,8 @@ if __name__ == '__main__': # define path to working directory # wd = '//projectdata.eurac.edu/projects/cci_snow/dfrisinghelli/' - wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli' - # wd = 'C:/Eurac/2020/' + # wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli' + wd = 'C:/Eurac/2020/' # path to the preprocessed sparcs dataset sparcs_path = os.path.join(wd, '_Datasets/Sparcs') @@ -739,6 +569,9 @@ if __name__ == '__main__': # path to the Cloud-95 dataset cloud_path = os.path.join(wd, '_Datasets/Cloud95/Training') + # path to the ProSnow dataset + prosnow_path = os.path.join(wd, '_Datasets/ProSnow/Garmisch') + # the csv file containing the names of the informative patches patches = 'training_patches_95-cloud_nonempty.csv' @@ -749,24 +582,5 @@ if __name__ == '__main__': sparcs_dataset = SparcsDataset(sparcs_path, tile_size=None, use_bands=['nir', 'red', 'green']) - # a sample from the sparcs dataset - sample_s = np.random.randint(len(sparcs_dataset), size=1).item() - s_x, s_y = sparcs_dataset[sample_s] - fig, ax = sparcs_dataset.plot_sample(s_x, s_y, - bands=['nir', 'red', 'green']) - - # a sample from the cloud dataset - sample_c = np.random.randint(len(cloud_dataset), size=1).item() - c_x, c_y = cloud_dataset[sample_c] - fig, ax = cloud_dataset.plot_sample(c_x, c_y, - bands=['nir', 'red', 'green']) - - # print shape of the 
sample
-    for i, l, d in zip([s_x, c_x], [s_y, c_y],
-                       [sparcs_dataset, cloud_dataset]):
-        print('A sample from the {}:'.format(d.__class__.__name__))
-        print('Shape of input data: {}'.format(i.shape))
-        print('Shape of ground truth: {}'.format(l.shape))
-
-    # show figures
-    plt.show()
+    # instantiate the ProSnow class
+    prosnow_dataset = ProSnowGarmisch(prosnow_path)
diff --git a/pytorch/graphics.py b/pytorch/graphics.py
new file mode 100644
index 0000000..d142209
--- /dev/null
+++ b/pytorch/graphics.py
@@ -0,0 +1,240 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Jul 14 11:04:27 2020
+
+@author: Daniel
+"""
+# builtins
+import os
+import itertools
+
+# externals
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+from matplotlib.colors import ListedColormap, BoundaryNorm
+from matplotlib import cm as colormap
+
+
+# this function applies percentile stretching at the alpha level
+# can be used to increase contrast for visualization
+def contrast_stretching(image, alpha=2):
+
+    # compute upper and lower percentiles defining the range of the stretch
+    inf, sup = np.percentile(image, (alpha, 100 - alpha))
+
+    # normalize image intensity distribution to
+    # (alpha, 100 - alpha) percentiles
+    norm = ((image - inf) * (image.max() - image.min()) /
+            (sup - inf)) + image.min()
+
+    # clip: values < inf = 0, values > sup = max
+    norm[norm <= image.min()] = image.min()
+    norm[norm >= image.max()] = image.max()
+
+    return norm
+
+
+# plot_sample() plots a false color composite of the scene/tile together
+# with the model prediction and the corresponding ground truth
+def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
+                bands=['red', 'green', 'blue'], stretch=False, state=None,
+                outpath=os.path.join(os.getcwd(), '_samples/'), **kwargs):
+
+    # apply contrast stretching if requested or if kwargs are passed
+    stretch = stretch or bool(kwargs)
+    func = contrast_stretching if stretch else lambda x: x
+
+    # create an rgb stack
+    rgb = np.dstack([func(x[use_bands.index(band)], **kwargs)
+                     for band in bands])
+
+    # get labels and corresponding colors
+    ulabels = [label['label'] for label in labels.values()]
+    colors = [label['color'] for label in labels.values()]
+
+    # create a ListedColormap
+    cmap = ListedColormap(colors)
+    boundaries = [*labels.keys(), cmap.N]
+    norm = BoundaryNorm(boundaries, cmap.N)
+
+    # create figure: check whether to plot model prediction
+    if y_pred is not None:
+
+        # compute accuracy
+        acc = (y_pred == y).float().mean()
+
+        # plot model prediction
+        fig, ax = plt.subplots(1, 3, figsize=figsize)
+        ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm)
+        ax[2].set_title('Prediction ({:.2f}%)'.format(acc * 100), pad=15)
+
+    else:
+        fig, ax = plt.subplots(1, 2, figsize=figsize)
+
+    # plot false color composite
+    ax[0].imshow(rgb)
+    ax[0].set_title('R = {}, G = {}, B = {}'.format(*bands), pad=15)
+
+    # plot ground truth mask
+    ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm)
+    ax[1].set_title('Ground truth', pad=15)
+
+    # create a patch (proxy artist) for every color
+    patches = [mpatches.Patch(color=c, label=l) for c, l in
+               zip(colors, ulabels)]
+
+    # plot patches as legend
+    plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
+               frameon=False)
+
+    # save figure
+    if state is not None:
+        os.makedirs(outpath, exist_ok=True)
+        fig.savefig(os.path.join(outpath, state.replace('.pt', '.png')),
+                    dpi=300, bbox_inches='tight')
+
+    return fig, ax
+
+
+# plot_confusion_matrix() plots the confusion
matrix of the validation/test +# set returned by the pytorch.predict function +def plot_confusion_matrix(cm, labels, normalize=True, + figsize=(10, 10), cmap='Blues', state=None, + outpath=os.path.join(os.getcwd(), '_graphics/')): + + # number of classes + nclasses = len(labels) + + # string format to plot values of confusion matrix + fmt = '.0f' + + # minimum and maximum values of the colorbar + vmin, vmax = 0, cm.max() + + # check whether to normalize the confusion matrix + if normalize: + # normalize + cm = cm / cm.sum(axis=1, keepdims=True) + + # change string format to floating point + fmt = '.2f' + vmin, vmax = 0, 1 + + # create figure + fig, ax = plt.subplots(1, 1, figsize=figsize) + + # get colormap + cmap = colormap.get_cmap(cmap, 256) + + # plot confusion matrix + im = ax.imshow(cm, cmap=cmap, vmin=vmin, vmax=vmax) + + # threshold determining the color of the values + thresh = (cm.max() + cm.min()) / 2 + + # brightest/darkest color of current colormap + cmap_min, cmap_max = cmap(0), cmap(256) + + # plot values of confusion matrix + for i, j in itertools.product(range(nclasses), range(nclasses)): + ax.text(j, i, format(cm[i, j], fmt), ha='center', va='center', + color=cmap_max if cm[i, j] < thresh else cmap_min) + + # axes properties and labels + ax.set(xticks=np.arange(nclasses), + yticks=np.arange(nclasses), + xticklabels=labels, + yticklabels=labels, + ylabel='True', + xlabel='Predicted') + + # add colorbar axes + cax = fig.add_axes([ax.get_position().x1 + 0.025, ax.get_position().y0, + 0.05, ax.get_position().y1 - ax.get_position().y0]) + fig.colorbar(im, cax=cax) + + # save figure + if state is not None: + os.makedirs(outpath, exist_ok=True) + fig.savefig(os.path.join(outpath, state.replace('.pt', '_cm.png')), + dpi=300, bbox_inches='tight') + + return fig, ax + + +def plot_loss(loss_file, figsize=(10, 10), + colors=['lightgreen', 'skyblue', 'darkgreen', 'steelblue'], + outpath=os.path.join(os.getcwd(), '_graphics/')): + + # load the model loss + state = torch.load(loss_file) + + # get all non-zero elements, i.e. 
get number of epochs trained before
+    # early stop
+    loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in
+            state.items() if k != 'epoch'}
+
+    # number of epochs trained
+    epochs = np.arange(0, state['epoch'] + 1)
+
+    # instantiate figure
+    fig, ax1 = plt.subplots(1, 1, figsize=figsize)
+
+    # plot training and validation mean loss per epoch
+    [ax1.plot(epochs, v.mean(axis=0),
+              label=k.capitalize().replace('_', ' '), color=c, lw=2)
+     for (k, v), c in zip(loss.items(), colors) if v.any() and 'loss' in k]
+
+    # plot training loss per batch
+    ax2 = ax1.twiny()
+    [ax2.plot(v.flatten('F'), color=c, alpha=0.5)
+     for (k, v), c in zip(loss.items(), colors) if 'loss' in k and
+     'validation' not in k]
+
+    # plot training and validation mean accuracy per epoch
+    ax3 = ax1.twinx()
+    [ax3.plot(epochs, v.mean(axis=0),
+              label=k.capitalize().replace('_', ' '), color=c, lw=2)
+     for (k, v), c in zip(loss.items(), colors) if v.any() and 'accuracy'
+     in k]
+
+    # plot training accuracy per batch
+    ax4 = ax3.twiny()
+    [ax4.plot(v.flatten('F'), color=c, alpha=0.5)
+     for (k, v), c in zip(loss.items(), colors) if 'accuracy' in k and
+     'validation' not in k]
+
+    # axes properties and labels
+    for ax in [ax2, ax4]:
+        ax.set(xticks=[], xticklabels=[])
+    ax1.set(xlabel='Epoch',
+            ylabel='Loss',
+            ylim=(0, 1))
+    ax3.set(ylabel='Accuracy',
+            ylim=(0, 1))
+
+    # compute early stopping point
+    if loss['validation_accuracy'].any():
+        esepoch = np.argmax(loss['validation_accuracy'].mean(axis=0))
+        esacc = np.max(loss['validation_accuracy'].mean(axis=0))
+        ax1.vlines(esepoch, ymin=ax1.get_ylim()[0], ymax=ax1.get_ylim()[1],
+                   ls='--', color='grey')
+        ax1.text(esepoch - 1, ax1.get_ylim()[0] + 0.01,
+                 'epoch = {}'.format(esepoch), ha='right', color='grey')
+        ax1.text(esepoch + 1, ax1.get_ylim()[0] + 0.01,
+                 'acc = {:.2f}%'.format(esacc * 100), ha='left',
+                 color='grey')
+
+    # add legends
+    ax1.legend(frameon=False, loc='lower left')
+    ax3.legend(frameon=False, loc='upper left')
+
+    # save figure
+    os.makedirs(outpath, exist_ok=True)
+    fig.savefig(os.path.join(
+        outpath, os.path.basename(loss_file).replace('.pt', '.png')),
+        dpi=300, bbox_inches='tight')
+
+    return fig, ax1
-- 
GitLab
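
Note: the patch imports Landsat8, Sentinel2, SparcsLabels, Cloud95Labels and
ProSnowLabels from pytorch/constants.py, a module that is not part of this
diff. The shape those enums must have can be inferred from how dataset.py
consumes them: get_bands() builds {band.value: band.name}, and get_labels()
reads band.value[0] as the class id and band.value[1] as the plot color. The
sketch below is an assumption based on that usage and on the dictionaries
removed from SparcsDataset, not the actual contents of pytorch/constants.py:

    # hypothetical sketch of pytorch/constants.py -- NOT included in the patch
    import enum


    class Landsat8(enum.Enum):
        # consumed as {band.value: band.name}, i.e. band number -> band name,
        # mirroring the dict removed from SparcsDataset.get_bands()
        violet = 1
        blue = 2
        green = 3
        red = 4
        nir = 5
        swir1 = 6
        swir2 = 7
        pan = 8
        cirrus = 9
        tir = 10


    class SparcsLabels(enum.Enum):
        # consumed as {band.value[0]: {'label': ..., 'color': band.value[1]}},
        # with '_' in member names replaced by spaces for display
        Shadow = 0, 'black'
        Shadow_over_Water = 1, 'darkblue'
        Water = 2, 'blue'
        Snow = 3, 'lightblue'
        Land = 4, 'grey'
        Cloud = 5, 'white'
        Flooded = 6, 'yellow'

With the plotting code decoupled from the dataset classes, the new functions
can be exercised without a trained model; a minimal usage sketch (the dataset
path is a placeholder):

    from pytorch.dataset import SparcsDataset
    from pytorch.graphics import plot_sample

    # load a scene and plot a false color composite with its ground truth
    dataset = SparcsDataset('path/to/Sparcs',
                            use_bands=['nir', 'red', 'green'])
    x, y = dataset[0]
    fig, ax = plot_sample(x, y, dataset.use_bands, dataset.labels,
                          bands=['nir', 'red', 'green'])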