Source code for core.graphics

"""Functions to plot multispectral image data and model output.

License
-------

    Copyright (c) 2020 Daniel Frisinghelli

    This source code is licensed under the GNU General Public License v3.

    See the LICENSE file in the repository's root directory.

"""

# !/usr/bin/env python
# -*- coding: utf-8 -*-

# builtins
import os
import itertools

# externals
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib import cm as colormap

# locals
from pysegcnn.core.trainer import accuracy_function
from pysegcnn.main.config import HERE


def contrast_stretching(image, alpha=5):
    """Apply percentile stretching to an image to increase contrast.

    The intensities are rescaled linearly such that the ``alpha``-th and the
    ``(100 - alpha)``-th percentiles are mapped to the minimum and the maximum
    of ``image``, respectively. Values falling outside of this range are
    clipped to the minimum/maximum of ``image``.

    Parameters
    ----------
    image : `numpy.ndarray`
        The input image.
    alpha : `int`, optional
        The level of the percentiles. The default is 5.

    Returns
    -------
    norm : `numpy.ndarray`
        The stretched image. If both percentiles coincide (e.g. a constant
        image or an image dominated by a single fill value), there is no
        range to stretch and an unmodified floating point copy of ``image``
        is returned.

    """
    # compute upper and lower percentiles defining the range of the stretch
    inf, sup = np.percentile(image, (alpha, 100 - alpha))

    # degenerate stretch range: dividing by (sup - inf) == 0 would fill the
    # result with NaN/inf values, hence return the image unstretched
    if sup == inf:
        return np.array(image, dtype=float)

    # range of the input image, which is preserved by the stretch
    lower, upper = image.min(), image.max()

    # normalize image intensity distribution to
    # (alpha, 100 - alpha) percentiles
    norm = ((image - inf) * (upper - lower) / (sup - inf)) + lower

    # clip: values < min = min, values > max = max
    norm[norm <= lower] = lower
    norm[norm >= upper] = upper

    return norm
def running_mean(x, w):
    """Compute a running mean of the input sequence.

    The mean of each window is obtained from the difference of two entries
    of the cumulative sum of ``x`` which lie ``w`` elements apart.

    Parameters
    ----------
    x : array_like
        The sequence to compute a running mean on.
    w : `int`
        The window length of the running mean.

    Returns
    -------
    rm : `numpy.ndarray`
        The running mean of the sequence ``x``.

    """
    # prepend a zero, such that the first window sum is the cumulative sum
    # of the first ``w`` elements
    padded = np.insert(x, 0, 0)
    totals = np.cumsum(padded)

    # sum of each window: cumulative sum at the end of the window minus the
    # cumulative sum just before the start of the window
    window_end = totals[w:]
    window_start = totals[:-w]
    window_sums = window_end - window_start

    return window_sums / w
def plot_sample(x, use_bands, labels, y=None, y_pred=None, figsize=(10, 10),
                bands=['nir', 'red', 'green'], state=None,
                outpath=os.path.join(HERE, '_samples/'), alpha=0):
    """Plot false color composite (FCC), ground truth and model prediction.

    Parameters
    ----------
    x : `numpy.ndarray` or `torch.Tensor`, (b, h, w)
        Array containing the raw data of the tile, shape=(bands, height, width)
    use_bands : `list` of `str`
        List describing the order of the bands in ``x``.
    labels : `dict` [`int`, `dict`]
        The label dictionary. The keys are the values of the class labels
        in the ground truth ``y``. Each nested `dict` should have keys:
            ``'color'``
                A named color (`str`).
            ``'label'``
                The name of the class label (`str`).
    y : `numpy.ndarray` or `torch.Tensor` or `None`, optional
        Array containing the ground truth of tile ``x``, shape=(height, width).
        The default is None.
    y_pred : `numpy.ndarray` or `torch.Tensor` or `None`, optional
        Array containing the prediction for tile ``x``, shape=(height, width).
        The default is None.
    figsize : `tuple`, optional
        The figure size in inches, passed to `matplotlib.pyplot.subplots`.
        The default is (10, 10).
    bands : `list` [`str`], optional
        The bands to build the FCC. The default is ['nir', 'red', 'green'].
    state : `str` or `None`, optional
        Filename to save the plot to. ``state`` should be an existing model
        state file ending with '.pt'; the plot is saved under the same name
        with the extension '.png'. The default is None, i.e. plot is not
        saved to disk.
    outpath : `str` or `pathlib.Path`, optional
        Output path. The default is 'pysegcnn/main/_samples'.
    alpha : `int`, optional
        The level of the percentiles to increase contrast in the FCC.
        The default is 0, i.e. no stretching.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure handle.
    ax : `numpy.ndarray` [`matplotlib.axes._subplots.AxesSubplot`]
        An array of the axes handles.

    """
    # build the false color composite: each band is stretched to the
    # (alpha, 100 - alpha) percentiles, alpha=0 leaves the band unstretched
    rgb = np.dstack([contrast_stretching(x[use_bands.index(band)], alpha)
                     for band in bands])

    # get labels and corresponding colors
    ulabels = [label['label'] for label in labels.values()]
    colors = [label['color'] for label in labels.values()]

    # create a ListedColormap with one bin per class: the class values are
    # the lower bin edges and the last bin is closed just above the largest
    # class value. This equals the number of classes for labels 0, ..., N-1,
    # but also holds for class values which are not consecutive.
    # NOTE: assumes the keys of ``labels`` are listed in ascending order
    cmap = ListedColormap(colors)
    boundaries = [*labels.keys(), max(labels.keys()) + 1]
    norm = BoundaryNorm(boundaries, cmap.N)

    # create a patch (proxy artist) for every color
    patches = [mpatches.Patch(color=c, label=l) for c, l in
               zip(colors, ulabels)]

    # initialize figure
    fig, ax = plt.subplots(1, 3, figsize=figsize)

    # plot false color composite
    ax[0].imshow(rgb)
    ax[0].set_title('R = {}, G = {}, B = {}'.format(*bands), pad=15)

    # check whether to plot ground truth
    acc = None
    if y is None:
        # remove axis to plot ground truth from figure
        fig.delaxes(ax[1])
    else:
        # plot ground truth mask
        ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm)
        ax[1].set_title('Ground truth', pad=15)

    # check whether to plot model prediction
    if y_pred is None:
        # remove axis to plot model prediction from figure
        fig.delaxes(ax[2])
    else:
        # plot model prediction
        ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm)

        # set title: report the accuracy if the ground truth is available
        title = 'Prediction'
        if y is not None:
            acc = accuracy_function(y_pred, y)
            title += ' ({:.2f}%)'.format(acc * 100)
        ax[2].set_title(title, pad=15)

    # if a ground truth or a model prediction is plotted, add legend to the
    # right of the last remaining axes; the axes is addressed explicitly
    # instead of relying on the implicit current axes of pyplot
    if len(fig.axes) > 1:
        fig.axes[-1].legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
                            frameon=False)

    # save figure
    if state is not None:
        os.makedirs(outpath, exist_ok=True)

        # swap only the trailing '.pt' extension for '.png', other
        # occurrences of '.pt' within the filename are left untouched
        root, ext = os.path.splitext(state)
        filename = root + '.png' if ext == '.pt' else state
        fig.savefig(os.path.join(outpath, filename), dpi=300,
                    bbox_inches='tight')

    return fig, ax
def plot_confusion_matrix(cm, labels, normalize=True,
                          figsize=(10, 10), cmap='Blues', state=None,
                          outpath=os.path.join(HERE, '_graphics/')):
    """Plot the confusion matrix ``cm``.

    Parameters
    ----------
    cm : `numpy.ndarray`
        The confusion matrix.
    labels : `dict` [`int`, `dict`]
        The label dictionary. The keys are the values of the class labels
        in the ground truth ``y``. Each nested `dict` should have keys:
            ``'color'``
                A named color (`str`).
            ``'label'``
                The name of the class label (`str`).
    normalize : `bool`, optional
        Whether to normalize the confusion matrix, i.e. divide each row by
        its sum. The default is True.
    figsize : `tuple`, optional
        The figure size in inches, passed to `matplotlib.pyplot.subplots`.
        The default is (10, 10).
    cmap : `str`, optional
        A colormap in `matplotlib.pyplot.colormaps()`. The default is 'Blues'.
    state : `str` or `None`, optional
        Filename to save the plot to. If ``state`` is the name of a model
        state file ending with '.pt', the plot is saved as
        '<state>_cm.png', such that it does not overwrite the loss plot
        written by `plot_loss` to the same default directory. Any other
        filename is used as is. The default is None, i.e. plot is not saved
        to disk.
    outpath : `str` or `pathlib.Path`, optional
        Output path. The default is 'pysegcnn/main/_graphics/'.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure handle.
    ax : `matplotlib.axes._subplots.AxesSubplot`
        The axes handle.

    """
    # names and number of classes
    labels = [label['label'] for label in labels.values()]
    nclasses = len(labels)

    # string format to plot values of confusion matrix
    fmt = '.0f'

    # minimum and maximum values of the colorbar
    vmin, vmax = 0, cm.max()

    # check whether to normalize the confusion matrix
    if normalize:
        # normalize each row by its sum
        norm = cm.sum(axis=1, keepdims=True)

        # check for division by zero: rows without any sample remain zero
        norm[norm == 0] = 1
        cm = cm / norm

        # change string format to floating point
        fmt = '.2f'
        vmin, vmax = 0, 1

    # create figure
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # get colormap: pyplot.get_cmap is used since matplotlib.cm.get_cmap was
    # removed in matplotlib 3.9
    cmap = plt.get_cmap(cmap, 256)

    # plot confusion matrix
    im = ax.imshow(cm, cmap=cmap, vmin=vmin, vmax=vmax)

    # threshold determining the color of the values: the midpoint of the
    # color scale used by imshow, such that the text always contrasts with
    # the background color of its cell
    thresh = (vmin + vmax) / 2

    # first/last color of current colormap
    cmap_min, cmap_max = cmap(0), cmap(cmap.N - 1)

    # plot values of confusion matrix
    for i, j in itertools.product(range(nclasses), range(nclasses)):
        ax.text(j, i, format(cm[i, j], fmt), ha='center', va='center',
                color=cmap_max if cm[i, j] < thresh else cmap_min)

    # axes properties and labels
    ax.set(xticks=np.arange(nclasses),
           yticks=np.arange(nclasses),
           xticklabels=labels,
           yticklabels=labels,
           ylabel='True',
           xlabel='Predicted')

    # add colorbar axes to the right of the confusion matrix, spanning the
    # same height
    pos = ax.get_position()
    cax = fig.add_axes([pos.x1 + 0.025, pos.y0, 0.05, pos.y1 - pos.y0])
    fig.colorbar(im, cax=cax)

    # save figure
    if state is not None:
        os.makedirs(outpath, exist_ok=True)

        # '.pt' is not an image format supported by savefig: derive the
        # name of the plot from the name of the model state file
        root, ext = os.path.splitext(state)
        filename = root + '_cm.png' if ext == '.pt' else state
        fig.savefig(os.path.join(outpath, filename), dpi=300,
                    bbox_inches='tight')

    return fig, ax
def plot_loss(state_file, figsize=(10, 10), step=5,
              colors=['lightgreen', 'green', 'skyblue', 'steelblue'],
              outpath=os.path.join(HERE, '_graphics/')):
    """Plot the observed loss and accuracy of a model run.

    The plot is saved to ``outpath`` under the name of ``state_file``, with
    the extension '.pt' replaced by '.png'.

    Parameters
    ----------
    state_file : `str` or `pathlib.Path`
        The model state file. Model state files are stored in
        pysegcnn/main/_models.
    figsize : `tuple`, optional
        The figure size in inches, passed to `matplotlib.pyplot.subplots`.
        The default is (10, 10).
    step : `int`, optional
        The step of epochs for the x-axis labels. The default is 5, i.e. label
        each fifth epoch.
    colors : `list` [`str`], optional
        A list of four named colors supported by `matplotlib`.
        The default is ['lightgreen', 'green', 'skyblue', 'steelblue'].
    outpath : `str` or `pathlib.Path`, optional
        Output path. The default is 'pysegcnn/main/_graphics/'.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure handle.

    """
    # load the model state: only the training metrics are required, hence
    # map all tensors to the cpu, such that a model trained on a gpu can be
    # plotted on a machine without gpu
    # NOTE: torch.load unpickles the file, only load trusted model states
    model_state = torch.load(state_file, map_location='cpu')

    # get all non-zero elements, i.e. get number of epochs trained before
    # early stop
    # NOTE(review): presumably each metric is an array of shape
    # (batches, epochs) padded with zeros after an early stop; a metric which
    # is exactly zero within a trained epoch would also be dropped here
    loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in
            model_state['state'].items()}

    # compute running mean with a window equal to the number of batches in
    # an epoch
    rm = {k: running_mean(v.flatten('F'), v.shape[0]) for k, v in
          loss.items()}

    # sort the keys of the dictionary alphabetically, such that the order of
    # the metrics matches the order of the axes, colors and legend labels
    rm = {k: rm[k] for k in sorted(rm)}

    # number of epochs trained
    epochs = np.arange(0, loss['tl'].shape[1])

    # instantiate figure
    fig, ax1 = plt.subplots(1, 1, figsize=figsize)

    # create axes for each parameter to plot
    ax2 = ax1.twinx()
    ax3 = ax1.twiny()
    ax4 = ax2.twiny()

    # list of axes
    axes = [ax2, ax1, ax4, ax3]

    # plot the running mean of each metric, skipping metrics which are
    # entirely zero
    for values, axis, color in zip(rm.values(), axes, colors):
        if values.any():
            axis.plot(values, color=color)

    # axes properties and labels
    nbatches = loss['tl'].shape[0]
    ax3.set(xticks=[], xticklabels=[])
    ax4.set(xticks=[], xticklabels=[])
    ax1.set(xticks=np.arange(0, nbatches * epochs[-1] + 1, nbatches * step),
            xticklabels=epochs[::step],
            xlabel='Epoch',
            ylabel='Loss',
            ylim=(0, 1))
    ax2.set(ylabel='Accuracy',
            ylim=(0.5, 1))

    # compute early stopping point: the epoch with the highest mean accuracy
    # on the validation set
    if loss['va'].any():
        epoch_accuracy = loss['va'].mean(axis=0)
        best_epoch = np.argmax(epoch_accuracy)
        esepoch = best_epoch * nbatches + 1
        esacc = epoch_accuracy[best_epoch]
        ymin, ymax = ax1.get_ylim()
        ax1.vlines(esepoch, ymin=ymin, ymax=ymax, ls='--', color='grey')
        ax1.text(esepoch - nbatches, ymin + 0.01,
                 'epoch = {}, accuracy = {:.1f}%'
                 .format(int(esepoch / nbatches), esacc * 100),
                 ha='right', color='grey')

    # create a line (proxy artist) for every color
    ulabels = ['Training accuracy', 'Training loss',
               'Validation accuracy', 'Validation loss']
    patches = [mlines.Line2D([], [], color=c, label=l) for c, l in
               zip(colors, ulabels)]

    # plot proxy artists as legend
    ax1.legend(handles=patches, loc='lower left', frameon=False)

    # save figure: swap only the trailing '.pt' extension for '.png', other
    # occurrences of '.pt' within the filename are left untouched
    os.makedirs(outpath, exist_ok=True)
    name = os.path.basename(state_file)
    root, ext = os.path.splitext(name)
    filename = root + '.png' if ext == '.pt' else name
    fig.savefig(os.path.join(outpath, filename), dpi=300,
                bbox_inches='tight')

    return fig