"""Functions to plot multispectral image data and model output.
License
-------
Copyright (c) 2020 Daniel Frisinghelli
This source code is licensed under the GNU General Public License v3.
See the LICENSE file in the repository's root directory.
"""
# !/usr/bin/env python
# -*- coding: utf-8 -*-
# builtins
import os
import itertools
# externals
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib import cm as colormap
# locals
from pysegcnn.core.trainer import accuracy_function
from pysegcnn.main.config import HERE
def contrast_stretching(image, alpha=5):
    """Apply percentile stretching to an image to increase contrast.

    The intensity range between the ``alpha`` and ``100 - alpha`` percentiles
    is mapped linearly onto the full value range of ``image``. Values that
    end up outside of the original value range are clipped to the image
    minimum and maximum, respectively.

    Parameters
    ----------
    image : `numpy.ndarray`
        The input image.
    alpha : `int`, optional
        The level of the percentiles. The default is 5.

    Returns
    -------
    norm : `numpy.ndarray`
        The stretched image. If both percentiles coincide (e.g. for a
        constant image), the stretch is undefined and a floating point copy
        of ``image`` is returned instead.

    """
    image = np.asarray(image)

    # compute upper and lower percentiles defining the range of the stretch
    inf, sup = np.percentile(image, (alpha, 100 - alpha))

    # a degenerate percentile range would lead to a division by zero and
    # fill the whole output with NaNs: return the image unchanged instead
    if sup == inf:
        return image.astype(float)

    # value range of the input image, computed once and reused below
    vmin, vmax = image.min(), image.max()

    # normalize image intensity distribution to
    # (alpha, 100 - alpha) percentiles
    norm = (image - inf) * (vmax - vmin) / (sup - inf) + vmin

    # clip: values < min = min, values > max = max
    return np.clip(norm, vmin, vmax, out=norm)
def running_mean(x, w):
    """Compute the moving average of a sequence over a sliding window.

    Parameters
    ----------
    x : array_like
        The sequence to average. The input is flattened before averaging.
    w : `int`
        The number of consecutive elements in each window.

    Returns
    -------
    rm : `numpy.ndarray`
        The mean of every window of ``w`` consecutive elements of ``x``.

    """
    # cumulative sum of the flattened sequence with a leading zero, such that
    # the difference of two entries ``w`` apart is the sum over one window
    totals = np.insert(x, 0, 0).cumsum()

    # sum over each window of length w
    window_sums = totals[w:] - totals[:-w]

    return window_sums / w
def plot_sample(x, use_bands, labels, y=None, y_pred=None, figsize=(10, 10),
                bands=['nir', 'red', 'green'], state=None,
                outpath=os.path.join(HERE, '_samples/'), alpha=0):
    """Plot false color composite (FCC), ground truth and model prediction.

    Parameters
    ----------
    x : `numpy.ndarray` or `torch.Tensor`, (b, h, w)
        Array containing the raw data of the tile, shape=(bands, height,
        width).
    use_bands : `list` of `str`
        List describing the order of the bands in ``x``.
    labels : `dict` [`int`, `dict`]
        The label dictionary. The keys are the values of the class labels
        in the ground truth ``y``. Each nested `dict` should have keys:
            ``'color'``
                A named color (`str`).
            ``'label'``
                The name of the class label (`str`).
    y : `numpy.ndarray` or `torch.Tensor` or `None`, optional
        Array containing the ground truth of tile ``x``, shape=(height,
        width). The default is None.
    y_pred : `numpy.ndarray` or `torch.Tensor` or `None`, optional
        Array containing the prediction for tile ``x``, shape=(height,
        width). The default is None.
    figsize : `tuple`, optional
        The figure size in inches, passed to `matplotlib.pyplot.subplots`.
        The default is (10, 10).
    bands : `list` [`str`], optional
        The bands to build the FCC. The default is ['nir', 'red', 'green'].
    state : `str` or `None`, optional
        Filename to save the plot to. ``state`` should be an existing model
        state file ending with '.pt'; the plot is saved with the extension
        '.png' instead. The default is None, i.e. plot is not saved to disk.
    outpath : `str` or `pathlib.Path`, optional
        Output path. The default is 'pysegcnn/main/_samples'.
    alpha : `int`, optional
        The level of the percentiles to increase contrast in the FCC.
        The default is 0, i.e. no stretching.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure handle.
    ax : `numpy.ndarray` [`matplotlib.axes._subplots.AxesSubplot`]
        An array of the axes handles.

    """
    # build the false color composite: every band is passed through the
    # contrast stretching, which leaves the values unchanged for alpha=0
    rgb = np.dstack([contrast_stretching(x[use_bands.index(band)], alpha)
                     for band in bands])

    # sort the classes by their label value: the boundaries of the color
    # normalization have to be monotonically increasing
    classes = sorted(labels.items(), key=lambda item: item[0])
    values = [value for value, _ in classes]

    # get labels and corresponding colors
    ulabels = [props['label'] for _, props in classes]
    colors = [props['color'] for _, props in classes]

    # create a ListedColormap
    cmap = ListedColormap(colors)

    # each class value opens a bin: close the last bin just above the
    # largest class value, which also works if the class values are not
    # exactly 0, 1, ..., nclasses - 1
    boundaries = [*values, values[-1] + 1]
    norm = BoundaryNorm(boundaries, cmap.N)

    # create a patch (proxy artist) for every color
    patches = [mpatches.Patch(color=c, label=l) for c, l in
               zip(colors, ulabels)]

    # initialize figure
    fig, ax = plt.subplots(1, 3, figsize=figsize)

    # plot false color composite
    ax[0].imshow(rgb)
    ax[0].set_title('R = {}, G = {}, B = {}'.format(*bands), pad=15)

    # check whether to plot ground truth
    if y is None:
        # remove axis to plot ground truth from figure
        fig.delaxes(ax[1])
    else:
        # plot ground truth mask
        ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm)
        ax[1].set_title('Ground truth', pad=15)

    # check whether to plot model prediction
    if y_pred is None:
        # remove axis to plot model prediction from figure
        fig.delaxes(ax[2])
    else:
        # plot model prediction
        ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm)

        # set title: report the accuracy if the ground truth is available
        title = 'Prediction'
        if y is not None:
            acc = accuracy_function(y_pred, y)
            title += ' ({:.2f}%)'.format(acc * 100)
        ax[2].set_title(title, pad=15)

    # if a ground truth or a model prediction is plotted, add legend:
    # attach it explicitly to the last remaining axes instead of relying on
    # the implicit current axes, which may have been removed above
    if len(fig.axes) > 1:
        fig.axes[-1].legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
                            frameon=False)

    # save figure
    if state is not None:
        os.makedirs(outpath, exist_ok=True)
        fig.savefig(os.path.join(outpath, state.replace('.pt', '.png')),
                    dpi=300, bbox_inches='tight')

    return fig, ax
def plot_confusion_matrix(cm, labels, normalize=True,
                          figsize=(10, 10), cmap='Blues', state=None,
                          outpath=os.path.join(HERE, '_graphics/')):
    """Plot the confusion matrix ``cm``.

    Parameters
    ----------
    cm : `numpy.ndarray`
        The confusion matrix. Rows are labelled as the true classes, columns
        as the predicted classes.
    labels : `dict` [`int`, `dict`]
        The label dictionary. The keys are the values of the class labels
        in the ground truth ``y``. Each nested `dict` should have keys:
            ``'color'``
                A named color (`str`).
            ``'label'``
                The name of the class label (`str`).
    normalize : `bool`, optional
        Whether to normalize the confusion matrix, i.e. divide each row by
        its sum. The default is True.
    figsize : `tuple`, optional
        The figure size in inches, passed to `matplotlib.pyplot.subplots`.
        The default is (10, 10).
    cmap : `str`, optional
        A colormap in `matplotlib.pyplot.colormaps()`. The default is 'Blues'.
    state : `str` or `None`, optional
        Filename to save the plot to. ``state`` should be an existing model
        state file ending with '.pt'; the plot is saved with the extension
        '.png' instead. The default is None, i.e. plot is not saved to disk.
    outpath : `str` or `pathlib.Path`, optional
        Output path. The default is 'pysegcnn/main/_graphics/'.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure handle.
    ax : `matplotlib.axes._subplots.AxesSubplot`
        The axes handle.

    """
    # names and number of classes
    class_names = [label['label'] for label in labels.values()]
    nclasses = len(class_names)

    # string format to plot values of confusion matrix
    fmt = '.0f'

    # minimum and maximum values of the colorbar
    vmin, vmax = 0, cm.max()

    # check whether to normalize the confusion matrix
    if normalize:
        # sum of each row: a new array, the input matrix is not modified
        norm = cm.sum(axis=1, keepdims=True)

        # check for division by zero
        norm[norm == 0] = 1
        cm = cm / norm

        # change string format to floating point
        fmt = '.2f'
        vmin, vmax = 0, 1

    # create figure
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # get colormap: matplotlib.cm.get_cmap was removed in matplotlib 3.9,
    # matplotlib.pyplot.get_cmap is available in old and new versions
    cmap = plt.get_cmap(cmap, 256)

    # plot confusion matrix
    im = ax.imshow(cm, cmap=cmap, vmin=vmin, vmax=vmax)

    # threshold determining the color of the values
    thresh = (cm.max() + cm.min()) / 2

    # first and last color of the colormap: index the last valid entry of
    # the lookup table instead of an out-of-range index
    cmap_min, cmap_max = cmap(0), cmap(cmap.N - 1)

    # plot values of confusion matrix: use the last color of the colormap
    # for values below the threshold and the first color otherwise
    for i, j in itertools.product(range(nclasses), range(nclasses)):
        ax.text(j, i, format(cm[i, j], fmt), ha='center', va='center',
                color=cmap_max if cm[i, j] < thresh else cmap_min)

    # axes properties and labels
    ax.set(xticks=np.arange(nclasses),
           yticks=np.arange(nclasses),
           xticklabels=class_names,
           yticklabels=class_names,
           ylabel='True',
           xlabel='Predicted')

    # add colorbar axes right next to the confusion matrix, same height
    position = ax.get_position()
    cax = fig.add_axes([position.x1 + 0.025, position.y0,
                        0.05, position.y1 - position.y0])
    fig.colorbar(im, cax=cax)

    # save figure: replace the extension of the model state file, since
    # '.pt' is not an image format supported by savefig
    if state is not None:
        os.makedirs(outpath, exist_ok=True)
        fig.savefig(os.path.join(outpath, state.replace('.pt', '.png')),
                    dpi=300, bbox_inches='tight')

    return fig, ax
def plot_loss(state_file, figsize=(10, 10), step=5,
              colors=['lightgreen', 'green', 'skyblue', 'steelblue'],
              outpath=os.path.join(HERE, '_graphics/')):
    """Plot the observed loss and accuracy of a model run.

    The plot is always saved to ``outpath``, using the name of
    ``state_file`` with the extension '.png' instead of '.pt'.

    Parameters
    ----------
    state_file : `str` or `pathlib.Path`
        The model state file. Model state files are stored in
        pysegcnn/main/_models.
    figsize : `tuple`, optional
        The figure size in inches, passed to `matplotlib.pyplot.subplots`.
        The default is (10, 10).
    step : `int`, optional
        The step of epochs for the x-axis labels. The default is 5, i.e. label
        each fifth epoch.
    colors : `list` [`str`], optional
        A list of four named colors supported by `matplotlib`.
        The default is ['lightgreen', 'green', 'skyblue', 'steelblue'].
    outpath : `str` or `pathlib.Path`, optional
        Output path. The default is 'pysegcnn/main/_graphics/'.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure handle.

    """
    # load the model state: only the recorded metrics are read here, hence
    # map everything to the cpu to support states saved on a gpu
    # NOTE: torch.load unpickles the file, only load trusted state files
    model_state = torch.load(state_file, map_location='cpu')

    # get all non-zero elements, i.e. get number of epochs trained before
    # early stop
    loss = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for k, v in
            model_state['state'].items()}

    # compute running mean with a window equal to the number of batches in
    # an epoch
    rm = {k: running_mean(v.flatten('F'), v.shape[0]) for k, v in loss.items()}

    # sort the keys of the dictionary alphabetically: the sorted keys are
    # matched by position to the axes, colors and legend labels below
    # (presumably 'ta', 'tl', 'va', 'vl' -- verify against the trainer)
    rm = {k: rm[k] for k in sorted(rm)}

    # number of epochs trained
    epochs = np.arange(0, loss['tl'].shape[1])

    # instantiate figure
    fig, ax1 = plt.subplots(1, 1, figsize=figsize)

    # create axes for each parameter to plot
    ax2 = ax1.twinx()
    ax3 = ax1.twiny()
    ax4 = ax2.twiny()

    # list of axes
    axes = [ax2, ax1, ax4, ax3]

    # plot running mean loss and accuracy, skipping metrics which are zero
    # everywhere
    for values, ax, color in zip(rm.values(), axes, colors):
        if values.any():
            ax.plot(values, color=color)

    # axes properties and labels
    nbatches = loss['tl'].shape[0]
    ax3.set(xticks=[], xticklabels=[])
    ax4.set(xticks=[], xticklabels=[])
    ax1.set(xticks=np.arange(0, nbatches * epochs[-1] + 1, nbatches * step),
            xticklabels=epochs[::step],
            xlabel='Epoch',
            ylabel='Loss',
            ylim=(0, 1))
    ax2.set(ylabel='Accuracy',
            ylim=(0.5, 1))

    # compute early stopping point: the epoch with the highest mean value of
    # the 'va' metric
    if loss['va'].any():
        mean_va = loss['va'].mean(axis=0)
        esepoch = np.argmax(mean_va) * nbatches + 1
        esacc = np.max(mean_va)
        ax1.vlines(esepoch, ymin=ax1.get_ylim()[0], ymax=ax1.get_ylim()[1],
                   ls='--', color='grey')
        ax1.text(esepoch - nbatches, ax1.get_ylim()[0] + 0.01,
                 'epoch = {}, accuracy = {:.1f}%'
                 .format(int(esepoch / nbatches), esacc * 100),
                 ha='right', color='grey')

    # create a patch (proxy artist) for every color
    ulabels = ['Training accuracy', 'Training loss',
               'Validation accuracy', 'Validation loss']
    patches = [mlines.Line2D([], [], color=c, label=l) for c, l in
               zip(colors, ulabels)]

    # plot patches as legend
    ax1.legend(handles=patches, loc='lower left', frameon=False)

    # save figure
    os.makedirs(outpath, exist_ok=True)
    fig.savefig(os.path.join(
        outpath, os.path.basename(state_file).replace('.pt', '.png')),
        dpi=300, bbox_inches='tight')

    return fig