From e554e6537200cf2717ec529ddb166de3661b0a7a Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 20 Aug 2020 15:57:43 +0200
Subject: [PATCH] Made the ground truth mask in plot_sample optional

---
 pysegcnn/core/graphics.py | 80 +++++++++++++++++++++------------------
 pysegcnn/core/predict.py  |  4 +-
 2 files changed, 45 insertions(+), 39 deletions(-)

diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py
index 5303d97..dfb1034 100644
--- a/pysegcnn/core/graphics.py
+++ b/pysegcnn/core/graphics.py
@@ -32,8 +32,6 @@ from pysegcnn.core.trainer import accuracy_function
 from pysegcnn.main.config import HERE
 
 
-# this function applies percentile stretching at the alpha level
-# can be used to increase constrast for visualization
 def contrast_stretching(image, alpha=5):
     """Apply percentile stretching to an image to increase constrast.
 
@@ -85,9 +83,7 @@ def running_mean(x, w):
     return (cumsum[w:] - cumsum[:-w]) / w
 
 
-# plot_sample() plots a false color composite of the scene/tile together
-# with the model prediction and the corresponding ground truth
-def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
+def plot_sample(x, use_bands, labels, y=None, y_pred=None, figsize=(10, 10),
                 bands=['nir', 'red', 'green'], state=None,
                 outpath=os.path.join(HERE, '_samples/'), alpha=0):
     """Plot false color composite (FCC), ground truth and model prediction.
@@ -96,8 +92,6 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
     ----------
     x : `numpy.ndarray` or `torch.Tensor`, (b, h, w)
         Array containing the raw data of the tile, shape=(bands, height, width)
-    y : `numpy.ndarray` or `torch.Tensor`, (h, w)
-        Array containing the ground truth of tile ``x``, shape=(height, width)
     use_bands : `list` of `str`
         List describing the order of the bands in ``x``.
     labels : `dict` [`int`, `dict`]
@@ -107,9 +101,12 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
                 A named color (`str`).
             ``'label'``
                 The name of the class label (`str`).
-    y_pred : `numpy.ndarray` or `None`, optional
+    y : `numpy.ndarray` or `torch.Tensor` or `None`, optional
+        Array containing the ground truth of tile ``x``, shape=(height, width).
+        The default is None.
+    y_pred : `numpy.ndarray` or `torch.Tensor` or `None`, optional
         Array containing the prediction for tile ``x``, shape=(height, width).
-        The default is None, i.e. only FCC and ground truth are plotted.
+        The default is None.
     figsize : `tuple`, optional
         The figure size in centimeters. The default is (10, 10).
     bands : `list` [`str`], optional
@@ -119,7 +116,7 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
         state file ending with '.pt'. The default is None, i.e. plot is not
         saved to disk.
     outpath : `str` or `pathlib.Path`, optional
-        Output path. The default is os.path.join(HERE, '_samples/').
+        Output path. The default is 'pysegcnn/main/_samples'.
     alpha : `int`, optional
         The level of the percentiles to increase constrast in the FCC.
         The default is 0, i.e. no stretching.
@@ -128,8 +125,8 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
     -------
     fig : `matplotlib.figure.Figure`
         The figure handle.
-    ax : `matplotlib.axes._subplots.AxesSubplot`
-        The axes handle.
+    ax : `numpy.ndarray` [`matplotlib.axes._subplots.AxesSubplot`]
+        An array of the axes handles.
 
     """
     # check whether to apply constrast stretching
@@ -145,35 +142,46 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
     boundaries = [*labels.keys(), cmap.N]
     norm = BoundaryNorm(boundaries, cmap.N)
 
-    # create figure: check whether to plot model prediction
-    if y_pred is not None:
-
-        # compute accuracy
-        acc = accuracy_function(y_pred, y)
-
-        # plot model prediction
-        fig, ax = plt.subplots(1, 3, figsize=figsize)
-        ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm)
-        ax[2].set_title('Prediction ({:.2f}%)'.format(acc * 100), pad=15)
+    # create a patch (proxy artist) for every color
+    patches = [mpatches.Patch(color=c, label=l) for c, l in
+               zip(colors, ulabels)]
 
-    else:
-        fig, ax = plt.subplots(1, 2, figsize=figsize)
+    # initialize figure
+    fig, ax = plt.subplots(1, 3, figsize=figsize)
 
     # plot false color composite
     ax[0].imshow(rgb)
     ax[0].set_title('R = {}, G = {}, B = {}'.format(*bands), pad=15)
 
-    # plot ground thruth mask
-    ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm)
-    ax[1].set_title('Ground truth', pad=15)
+    # check whether to plot ground truth
+    acc = None
+    if y is None:
+        # remove axis to plot ground truth from figure
+        fig.delaxes(ax[1])
+    else:
+        # plot ground thruth mask
+        ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm)
+        ax[1].set_title('Ground truth', pad=15)
+
+    # check whether to plot model prediction
+    if y_pred is None:
+        # remove axis to plot model prediction from figure
+        fig.delaxes(ax[2])
+    else:
+        # plot model prediction
+        ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm)
 
-    # create a patch (proxy artist) for every color
-    patches = [mpatches.Patch(color=c, label=l) for c, l in
-               zip(colors, ulabels)]
+        # set title
+        title = 'Prediction'
+        if y is not None:
+            acc = accuracy_function(y_pred, y)
+            title += ' ({:.2f}%)'.format(acc * 100)
+        ax[2].set_title(title, pad=15)
 
-    # plot patches as legend
-    plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
-               frameon=False)
+    # if a ground truth or a model prediction is plotted, add legend
+    if len(fig.axes) > 1:
+        plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
+                   frameon=False)
 
     # save figure
     if state is not None:
@@ -184,8 +192,6 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
     return fig, ax
 
 
-# plot_confusion_matrix() plots the confusion matrix of the validation/test
-# set returned by the pytorch.predict function
 def plot_confusion_matrix(cm, labels, normalize=True,
                           figsize=(10, 10), cmap='Blues', state=None,
                           outpath=os.path.join(HERE, '_graphics/')):
@@ -213,7 +219,7 @@ def plot_confusion_matrix(cm, labels, normalize=True,
         state file ending with '.pt'. The default is None, i.e. plot is not
         saved to disk.
     outpath : `str` or `pathlib.Path`, optional
-        Output path. The default is os.path.join(HERE, '_graphics/').
+        Output path. The default is 'pysegcnn/main/_graphics/'.
 
     Returns
     -------
@@ -306,7 +312,7 @@ def plot_loss(state_file, figsize=(10, 10), step=5,
         A list of four named colors supported by `matplotlib`.
         The default is ['lightgreen', 'green', 'skyblue', 'steelblue'].
     outpath : `str` or `pathlib.Path`, optional
-        Output path. The default is os.path.join(HERE, '_graphics/').
+        Output path. The default is 'pysegcnn/main/_graphics/'.
 
     Returns
     -------
diff --git a/pysegcnn/core/predict.py b/pysegcnn/core/predict.py
index 9405d7c..6403a7a 100644
--- a/pysegcnn/core/predict.py
+++ b/pysegcnn/core/predict.py
@@ -168,9 +168,9 @@ def predict_samples(ds, model, cm=False, plot=False, **kwargs):
             # plot inputs, ground truth and model predictions
             sname = fname + '_{}_{}.pt'.format(ds.name, batch)
             fig, ax = plot_sample(inputs.numpy().clip(0, 1),
-                                  labels,
                                   ds.dataset.use_bands,
                                   ds.dataset.labels,
+                                  y=labels,
                                   y_pred=prd,
                                   state=sname,
                                   **kwargs)
@@ -298,9 +298,9 @@ def predict_scenes(ds, model, scene_id=None, cm=False, plot=False, **kwargs):
         # plot current scene
         if plot:
             fig, ax = plot_sample(inputs.clip(0, 1),
-                                  labels,
                                   ds.dataset.use_bands,
                                   ds.dataset.labels,
+                                  y=labels,
                                   y_pred=prdtcn,
                                   state=sname,
                                   **kwargs)
-- 
GitLab