From 558e5a74e9fa092efdf9f6303c1838c4e78fb643 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 28 Jan 2021 17:39:45 +0100
Subject: [PATCH] Added descriptions to class distribution plot.

---
 pysegcnn/core/graphics.py | 42 +++++++++++++++++++--------------------
 1 file changed, 20 insertions(+), 22 deletions(-)

diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py
index 9a426b2..368208c 100644
--- a/pysegcnn/core/graphics.py
+++ b/pysegcnn/core/graphics.py
@@ -18,6 +18,7 @@ License
 import os
 import pathlib
 import itertools
+import logging
 
 # externals
 import numpy as np
@@ -58,6 +59,9 @@ plt.rc('figure', titlesize=BIG)
 # training metrics
 METRICS = ['train_loss', 'train_accu', 'valid_loss', 'valid_accu']
 
+# module level logger
+LOGGER = logging.getLogger(__name__)
+
 
 def contrast_stretching(image, alpha=5):
     """Apply `normalization`_ to an image to increase constrast.
@@ -585,6 +589,10 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5):
     # drop classes which are not represented in the dataset
     cls_ds = {k: v for k, v in cls_ds.items() if np.any(v)}
 
+    # number of pixels of each class
+    npix_per_class = {k: '{:.2f}M'.format(v.shape[0] * 1e-6) for k, v in
+                      cls_ds.items()}
+
     # labels for the different classes
     labels = [ds.labels[cls_id]['label'] for cls_id in cls_ds.keys()]
 
@@ -598,6 +606,7 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5):
 
     # iterate over the different bands
     for band in range(nbands):
+        LOGGER.info('Plotting band: {}'.format(ds.use_bands[band]))
 
         # current axis
         ax = axes[band]
@@ -636,13 +645,23 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5):
         ax.text(x=0.6, y=0.95, s='({})'.format(ds.use_bands[band]),
                 ha='left', va='top', weight='bold')
 
+    # create a patch (proxy artist) for every class
+    patches = [mpatches.Patch(color=ds.labels[k]['color'], label=v) for k, v in
+               npix_per_class.items()]
+
+    # add legend with number of pixels per class to plot
+    axes[-1].legend(handles=patches, loc=2, frameon=False)
+
     # adjust space between subplots
     fig.subplots_adjust(hspace=0.075, wspace=0.025)
 
     # remove empty axes
     for ax in axes:
         if not ax.lines:
-            fig.delaxes(ax)
+            # hide axis ticks, labels and text artists
+            ax.axis('off')
+            for t in ax.texts:
+                t.set_visible(False)
 
     return fig
 
@@ -777,24 +796,3 @@ class Animate(object):
         # save animation to disk
         self.animation.save(str(self.path.joinpath(filename)),
                             writer='imagemagick', **kwargs)
-
-
-def _plot_composites(ds, path, fmt, dpi=300, alpha=0):
-    """Utility function to plot each scene of a dataset."""
-
-    # iterate over the scenes of the dataset
-    for scene in range(len(ds)):
-        # name of the current scene
-        scene_id = ds.scenes[scene]['id']
-        print(scene_id)
-
-        # get the data of the current scene
-        x, y = ds[scene]
-
-        # plot the current scene
-        fig = plot_sample(x, ds.use_bands, ds.labels, y=y, hide_labels=True,
-                          bands=['swir2', 'nir', 'green'], alpha=alpha)
-
-        # save the figure as vector graphic
-        fig.savefig(path.joinpath(scene_id + '.{}'.format(fmt)), dpi=dpi,
-                    bbox_inches='tight', format=fmt)
-- 
GitLab