From f8e77f34480e616cd0bcf9126aa0fda33137b453 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Wed, 27 Jan 2021 18:14:21 +0100 Subject: [PATCH] Implemented plot function for class distributions. --- pysegcnn/core/graphics.py | 91 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 88 insertions(+), 3 deletions(-) diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py index de88ee0..9a426b2 100644 --- a/pysegcnn/core/graphics.py +++ b/pysegcnn/core/graphics.py @@ -35,9 +35,9 @@ from pysegcnn.core.utils import accuracy_function, check_filename_length from pysegcnn.main.train_config import HERE # plot font size configuration -SMALL = 10 -MEDIUM = 12 -BIG = 14 +SMALL = 12 +MEDIUM = 14 +BIG = 16 # controls default font size plt.rc('font', size=MEDIUM) @@ -562,6 +562,91 @@ def plot_loss(state_file, figsize=(10, 10), step=5, return fig +def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5): + """Plot the spectral distribution of the different classes in ``ds``. + + Parameters + ---------- + ds : :py:class:`pysegcnn.core.dataset.ImageDataset` + An instance of the dataset to plot the class distribution over the + different spectral bands. Make sure to initialize ``ds`` with the + parameter ``tile_size=None``, which conserves the original size of + each image in the dataset. + + Returns + ------- + cls_df : :py:class:`pandas.DataFrame` + The class distribution DataFrame. + + """ + # compute class distribution + cls_ds = ds.class_distribution() + + # drop classes which are not represented in the dataset + cls_ds = {k: v for k, v in cls_ds.items() if np.any(v)} + + # labels for the different classes + labels = [ds.labels[cls_id]['label'] for cls_id in cls_ds.keys()] + + # number of spectral bands in the dataset + nbands = len(ds.use_bands) + + # create a figure based on the number of spectral bands in the dataset + fig, axes = plt.subplots(min(3, nbands), int(np.ceil(max(1, nbands / 3))), + figsize=figsize, sharex=True, sharey=True) + axes = axes.flatten() + + # iterate over the different bands + for band in range(nbands): + + # current axis + ax = axes[band] + + # get the spectral data for each class + data = [x[:, band] for x in cls_ds.values()] + + # plot spectral distribution of the classes in the current band + bplot = ax.boxplot(data, labels=labels, patch_artist=True, + whis=[5, 95], showfliers=False) + + # set axis y-limits: physical limits are (0, 1) for reflectance data + ax.set_ylim(0, 1) + + # set colors of the boxes for the classes + for k, artists in bplot.items(): + + # the artists to color + if k in ['boxes', 'medians']: + + # iterate over the artists + for c, art in enumerate(artists): + + # line artists + if isinstance(art, matplotlib.lines.Line2D): + # set the colors of the lines in the boxplot + art.set_color(ds.labels[c]['color']) + + # patch artists + elif isinstance(art, matplotlib.patches.Patch): + # set the colors of the patches + art.set_facecolor(ds.labels[c]['color']) + art.set_alpha(alpha) + + # add name of the spectral band to the plot + ax.text(x=0.6, y=0.95, s='({})'.format(ds.use_bands[band]), + ha='left', va='top', weight='bold') + + # adjust space between subplots + fig.subplots_adjust(hspace=0.075, wspace=0.025) + + # remove empty axes + for ax in axes: + if not ax.lines: + fig.delaxes(ax) + + return fig + + class Animate(object): """Easily create animations with :py:mod:`matplotlib`. -- GitLab