Skip to content
Snippets Groups Projects
Commit e554e653 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Made the ground truth mask in plot_sample optional

parent 239f3e57
No related branches found
No related tags found
No related merge requests found
......@@ -32,8 +32,6 @@ from pysegcnn.core.trainer import accuracy_function
from pysegcnn.main.config import HERE
# this function applies percentile stretching at the alpha level
# can be used to increase constrast for visualization
def contrast_stretching(image, alpha=5):
"""Apply percentile stretching to an image to increase constrast.
......@@ -85,9 +83,7 @@ def running_mean(x, w):
return (cumsum[w:] - cumsum[:-w]) / w
# plot_sample() plots a false color composite of the scene/tile together
# with the model prediction and the corresponding ground truth
def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
def plot_sample(x, use_bands, labels, y=None, y_pred=None, figsize=(10, 10),
bands=['nir', 'red', 'green'], state=None,
outpath=os.path.join(HERE, '_samples/'), alpha=0):
"""Plot false color composite (FCC), ground truth and model prediction.
......@@ -96,8 +92,6 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
----------
x : `numpy.ndarray` or `torch.Tensor`, (b, h, w)
Array containing the raw data of the tile, shape=(bands, height, width)
y : `numpy.ndarray` or `torch.Tensor`, (h, w)
Array containing the ground truth of tile ``x``, shape=(height, width)
use_bands : `list` of `str`
List describing the order of the bands in ``x``.
labels : `dict` [`int`, `dict`]
......@@ -107,9 +101,12 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
A named color (`str`).
``'label'``
The name of the class label (`str`).
y_pred : `numpy.ndarray` or `None`, optional
y : `numpy.ndarray` or `torch.Tensor` or `None`, optional
Array containing the ground truth of tile ``x``, shape=(height, width).
The default is None.
y_pred : `numpy.ndarray` or `torch.Tensor` or `None`, optional
Array containing the prediction for tile ``x``, shape=(height, width).
The default is None, i.e. only FCC and ground truth are plotted.
The default is None.
figsize : `tuple`, optional
The figure size in centimeters. The default is (10, 10).
bands : `list` [`str`], optional
......@@ -119,7 +116,7 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
state file ending with '.pt'. The default is None, i.e. plot is not
saved to disk.
outpath : `str` or `pathlib.Path`, optional
Output path. The default is os.path.join(HERE, '_samples/').
Output path. The default is 'pysegcnn/main/_samples'.
alpha : `int`, optional
The level of the percentiles to increase constrast in the FCC.
The default is 0, i.e. no stretching.
......@@ -128,8 +125,8 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
-------
fig : `matplotlib.figure.Figure`
The figure handle.
ax : `matplotlib.axes._subplots.AxesSubplot`
The axes handle.
ax : `numpy.ndarray` [`matplotlib.axes._subplots.AxesSubplot`]
An array of the axes handles.
"""
# check whether to apply constrast stretching
......@@ -145,35 +142,46 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
boundaries = [*labels.keys(), cmap.N]
norm = BoundaryNorm(boundaries, cmap.N)
# create figure: check whether to plot model prediction
if y_pred is not None:
# compute accuracy
acc = accuracy_function(y_pred, y)
# plot model prediction
fig, ax = plt.subplots(1, 3, figsize=figsize)
ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm)
ax[2].set_title('Prediction ({:.2f}%)'.format(acc * 100), pad=15)
# create a patch (proxy artist) for every color
patches = [mpatches.Patch(color=c, label=l) for c, l in
zip(colors, ulabels)]
else:
fig, ax = plt.subplots(1, 2, figsize=figsize)
# initialize figure
fig, ax = plt.subplots(1, 3, figsize=figsize)
# plot false color composite
ax[0].imshow(rgb)
ax[0].set_title('R = {}, G = {}, B = {}'.format(*bands), pad=15)
# plot ground thruth mask
ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm)
ax[1].set_title('Ground truth', pad=15)
# check whether to plot ground truth
acc = None
if y is None:
# remove axis to plot ground truth from figure
fig.delaxes(ax[1])
else:
# plot ground thruth mask
ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm)
ax[1].set_title('Ground truth', pad=15)
# check whether to plot model prediction
if y_pred is None:
# remove axis to plot model prediction from figure
fig.delaxes(ax[2])
else:
# plot model prediction
ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm)
# create a patch (proxy artist) for every color
patches = [mpatches.Patch(color=c, label=l) for c, l in
zip(colors, ulabels)]
# set title
title = 'Prediction'
if y is not None:
acc = accuracy_function(y_pred, y)
title += ' ({:.2f}%)'.format(acc * 100)
ax[2].set_title(title, pad=15)
# plot patches as legend
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
frameon=False)
# if a ground truth or a model prediction is plotted, add legend
if len(fig.axes) > 1:
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
frameon=False)
# save figure
if state is not None:
......@@ -184,8 +192,6 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
return fig, ax
# plot_confusion_matrix() plots the confusion matrix of the validation/test
# set returned by the pytorch.predict function
def plot_confusion_matrix(cm, labels, normalize=True,
figsize=(10, 10), cmap='Blues', state=None,
outpath=os.path.join(HERE, '_graphics/')):
......@@ -213,7 +219,7 @@ def plot_confusion_matrix(cm, labels, normalize=True,
state file ending with '.pt'. The default is None, i.e. plot is not
saved to disk.
outpath : `str` or `pathlib.Path`, optional
Output path. The default is os.path.join(HERE, '_graphics/').
Output path. The default is 'pysegcnn/main/_graphics/'.
Returns
-------
......@@ -306,7 +312,7 @@ def plot_loss(state_file, figsize=(10, 10), step=5,
A list of four named colors supported by `matplotlib`.
The default is ['lightgreen', 'green', 'skyblue', 'steelblue'].
outpath : `str` or `pathlib.Path`, optional
Output path. The default is os.path.join(HERE, '_graphics/').
Output path. The default is 'pysegcnn/main/_graphics/'.
Returns
-------
......
......@@ -168,9 +168,9 @@ def predict_samples(ds, model, cm=False, plot=False, **kwargs):
# plot inputs, ground truth and model predictions
sname = fname + '_{}_{}.pt'.format(ds.name, batch)
fig, ax = plot_sample(inputs.numpy().clip(0, 1),
labels,
ds.dataset.use_bands,
ds.dataset.labels,
y=labels,
y_pred=prd,
state=sname,
**kwargs)
......@@ -298,9 +298,9 @@ def predict_scenes(ds, model, scene_id=None, cm=False, plot=False, **kwargs):
# plot current scene
if plot:
fig, ax = plot_sample(inputs.clip(0, 1),
labels,
ds.dataset.use_bands,
ds.dataset.labels,
y=labels,
y_pred=prdtcn,
state=sname,
**kwargs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment