From f8e77f34480e616cd0bcf9126aa0fda33137b453 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 27 Jan 2021 18:14:21 +0100
Subject: [PATCH] Implemented plot function for class distributions.

---
 pysegcnn/core/graphics.py | 91 +++++++++++++++++++++++++++++++++++++--
 1 file changed, 88 insertions(+), 3 deletions(-)

diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py
index de88ee0..9a426b2 100644
--- a/pysegcnn/core/graphics.py
+++ b/pysegcnn/core/graphics.py
@@ -35,9 +35,9 @@ from pysegcnn.core.utils import accuracy_function, check_filename_length
 from pysegcnn.main.train_config import HERE
 
 # plot font size configuration
-SMALL = 10
-MEDIUM = 12
-BIG = 14
+SMALL = 12
+MEDIUM = 14
+BIG = 16
 
 # controls default font size
 plt.rc('font', size=MEDIUM)
@@ -562,6 +562,91 @@ def plot_loss(state_file, figsize=(10, 10), step=5,
     return fig
 
 
+def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5):
+    """Plot the spectral distribution of the different classes in ``ds``.
+
+    Parameters
+    ----------
+    ds : :py:class:`pysegcnn.core.dataset.ImageDataset`
+        An instance of the dataset to plot the class distribution over the
+        different spectral bands. Make sure to initialize ``ds`` with the
+        parameter ``tile_size=None``, which conserves the original size of
+        each image in the dataset.
+
+    Returns
+    -------
+    cls_df : :py:class:`pandas.DataFrame`
+        The class distribution DataFrame.
+
+    """
+    # compute class distribution
+    cls_ds = ds.class_distribution()
+
+    # drop classes which are not represented in the dataset
+    cls_ds = {k: v for k, v in cls_ds.items() if np.any(v)}
+
+    # labels for the different classes
+    labels = [ds.labels[cls_id]['label'] for cls_id in cls_ds.keys()]
+
+    # number of spectral bands in the dataset
+    nbands = len(ds.use_bands)
+
+    # create a figure based on the number of spectral bands in the dataset
+    fig, axes = plt.subplots(min(3, nbands), int(np.ceil(max(1, nbands / 3))),
+                             figsize=figsize, sharex=True, sharey=True)
+    axes = axes.flatten()
+
+    # iterate over the different bands
+    for band in range(nbands):
+
+        # current axis
+        ax = axes[band]
+
+        # get the spectral data for each class
+        data = [x[:, band] for x in cls_ds.values()]
+
+        # plot spectral distribution of the classes in the current band
+        bplot = ax.boxplot(data, labels=labels, patch_artist=True,
+                           whis=[5, 95], showfliers=False)
+
+        # set axis y-limits: physical limits are (0, 1) for reflectance data
+        ax.set_ylim(0, 1)
+
+        # set colors of the boxes for the classes
+        for k, artists in bplot.items():
+
+            # the artists to color
+            if k in ['boxes', 'medians']:
+
+                # iterate over the artists
+                for c, art in enumerate(artists):
+
+                    # line artists
+                    if isinstance(art, matplotlib.lines.Line2D):
+                        # set the colors of the lines in the boxplot
+                        art.set_color(ds.labels[c]['color'])
+
+                    # patch artists
+                    elif isinstance(art, matplotlib.patches.Patch):
+                        # set the colors of the patches
+                        art.set_facecolor(ds.labels[c]['color'])
+                        art.set_alpha(alpha)
+
+        # add name of the spectral band to the plot
+        ax.text(x=0.6, y=0.95, s='({})'.format(ds.use_bands[band]),
+                ha='left', va='top', weight='bold')
+
+    # adjust space between subplots
+    fig.subplots_adjust(hspace=0.075, wspace=0.025)
+
+    # remove empty axes
+    for ax in axes:
+        if not ax.lines:
+            fig.delaxes(ax)
+
+    return fig
+
+
 class Animate(object):
     """Easily create animations with :py:mod:`matplotlib`.
 
-- 
GitLab