From e554e6537200cf2717ec529ddb166de3661b0a7a Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Thu, 20 Aug 2020 15:57:43 +0200 Subject: [PATCH] Made the ground truth mask in plot_sample optional --- pysegcnn/core/graphics.py | 80 +++++++++++++++++++++------------------ pysegcnn/core/predict.py | 4 +- 2 files changed, 45 insertions(+), 39 deletions(-) diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py index 5303d97..dfb1034 100644 --- a/pysegcnn/core/graphics.py +++ b/pysegcnn/core/graphics.py @@ -32,8 +32,6 @@ from pysegcnn.core.trainer import accuracy_function from pysegcnn.main.config import HERE -# this function applies percentile stretching at the alpha level -# can be used to increase constrast for visualization def contrast_stretching(image, alpha=5): """Apply percentile stretching to an image to increase constrast. @@ -85,9 +83,7 @@ def running_mean(x, w): return (cumsum[w:] - cumsum[:-w]) / w -# plot_sample() plots a false color composite of the scene/tile together -# with the model prediction and the corresponding ground truth -def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10), +def plot_sample(x, use_bands, labels, y=None, y_pred=None, figsize=(10, 10), bands=['nir', 'red', 'green'], state=None, outpath=os.path.join(HERE, '_samples/'), alpha=0): """Plot false color composite (FCC), ground truth and model prediction. @@ -96,8 +92,6 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10), ---------- x : `numpy.ndarray` or `torch.Tensor`, (b, h, w) Array containing the raw data of the tile, shape=(bands, height, width) - y : `numpy.ndarray` or `torch.Tensor`, (h, w) - Array containing the ground truth of tile ``x``, shape=(height, width) use_bands : `list` of `str` List describing the order of the bands in ``x``. labels : `dict` [`int`, `dict`] @@ -107,9 +101,12 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10), A named color (`str`). ``'label'`` The name of the class label (`str`). - y_pred : `numpy.ndarray` or `None`, optional + y : `numpy.ndarray` or `torch.Tensor` or `None`, optional + Array containing the ground truth of tile ``x``, shape=(height, width). + The default is None. + y_pred : `numpy.ndarray` or `torch.Tensor` or `None`, optional Array containing the prediction for tile ``x``, shape=(height, width). - The default is None, i.e. only FCC and ground truth are plotted. + The default is None. figsize : `tuple`, optional The figure size in centimeters. The default is (10, 10). bands : `list` [`str`], optional @@ -119,7 +116,7 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10), state file ending with '.pt'. The default is None, i.e. plot is not saved to disk. outpath : `str` or `pathlib.Path`, optional - Output path. The default is os.path.join(HERE, '_samples/'). + Output path. The default is 'pysegcnn/main/_samples'. alpha : `int`, optional The level of the percentiles to increase constrast in the FCC. The default is 0, i.e. no stretching. @@ -128,8 +125,8 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10), ------- fig : `matplotlib.figure.Figure` The figure handle. - ax : `matplotlib.axes._subplots.AxesSubplot` - The axes handle. + ax : `numpy.ndarray` [`matplotlib.axes._subplots.AxesSubplot`] + An array of the axes handles. """ # check whether to apply constrast stretching @@ -145,35 +142,46 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10), boundaries = [*labels.keys(), cmap.N] norm = BoundaryNorm(boundaries, cmap.N) - # create figure: check whether to plot model prediction - if y_pred is not None: - - # compute accuracy - acc = accuracy_function(y_pred, y) - - # plot model prediction - fig, ax = plt.subplots(1, 3, figsize=figsize) - ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm) - ax[2].set_title('Prediction ({:.2f}%)'.format(acc * 100), pad=15) + # create a patch (proxy artist) for every color + patches = [mpatches.Patch(color=c, label=l) for c, l in + zip(colors, ulabels)] - else: - fig, ax = plt.subplots(1, 2, figsize=figsize) + # initialize figure + fig, ax = plt.subplots(1, 3, figsize=figsize) # plot false color composite ax[0].imshow(rgb) ax[0].set_title('R = {}, G = {}, B = {}'.format(*bands), pad=15) - # plot ground thruth mask - ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm) - ax[1].set_title('Ground truth', pad=15) + # check whether to plot ground truth + acc = None + if y is None: + # remove axis to plot ground truth from figure + fig.delaxes(ax[1]) + else: + # plot ground thruth mask + ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm) + ax[1].set_title('Ground truth', pad=15) + + # check whether to plot model prediction + if y_pred is None: + # remove axis to plot model prediction from figure + fig.delaxes(ax[2]) + else: + # plot model prediction + ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm) - # create a patch (proxy artist) for every color - patches = [mpatches.Patch(color=c, label=l) for c, l in - zip(colors, ulabels)] + # set title + title = 'Prediction' + if y is not None: + acc = accuracy_function(y_pred, y) + title += ' ({:.2f}%)'.format(acc * 100) + ax[2].set_title(title, pad=15) - # plot patches as legend - plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, - frameon=False) + # if a ground truth or a model prediction is plotted, add legend + if len(fig.axes) > 1: + plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, + frameon=False) # save figure if state is not None: @@ -184,8 +192,6 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10), return fig, ax -# plot_confusion_matrix() plots the confusion matrix of the validation/test -# set returned by the pytorch.predict function def plot_confusion_matrix(cm, labels, normalize=True, figsize=(10, 10), cmap='Blues', state=None, outpath=os.path.join(HERE, '_graphics/')): @@ -213,7 +219,7 @@ def plot_confusion_matrix(cm, labels, normalize=True, state file ending with '.pt'. The default is None, i.e. plot is not saved to disk. outpath : `str` or `pathlib.Path`, optional - Output path. The default is os.path.join(HERE, '_graphics/'). + Output path. The default is 'pysegcnn/main/_graphics/'. Returns ------- @@ -306,7 +312,7 @@ def plot_loss(state_file, figsize=(10, 10), step=5, A list of four named colors supported by `matplotlib`. The default is ['lightgreen', 'green', 'skyblue', 'steelblue']. outpath : `str` or `pathlib.Path`, optional - Output path. The default is os.path.join(HERE, '_graphics/'). + Output path. The default is 'pysegcnn/main/_graphics/'. Returns ------- diff --git a/pysegcnn/core/predict.py b/pysegcnn/core/predict.py index 9405d7c..6403a7a 100644 --- a/pysegcnn/core/predict.py +++ b/pysegcnn/core/predict.py @@ -168,9 +168,9 @@ def predict_samples(ds, model, cm=False, plot=False, **kwargs): # plot inputs, ground truth and model predictions sname = fname + '_{}_{}.pt'.format(ds.name, batch) fig, ax = plot_sample(inputs.numpy().clip(0, 1), - labels, ds.dataset.use_bands, ds.dataset.labels, + y=labels, y_pred=prd, state=sname, **kwargs) @@ -298,9 +298,9 @@ def predict_scenes(ds, model, scene_id=None, cm=False, plot=False, **kwargs): # plot current scene if plot: fig, ax = plot_sample(inputs.clip(0, 1), - labels, ds.dataset.use_bands, ds.dataset.labels, + y=labels, y_pred=prdtcn, state=sname, **kwargs) -- GitLab