diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py index 3f2bfdb24f47f37a62bf6a5e7a824197c910d2d8..0dc3d01c484764b5228ef1228c68663016d98e27 100644 --- a/pysegcnn/core/graphics.py +++ b/pysegcnn/core/graphics.py @@ -600,7 +600,7 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5): nbands = len(ds.use_bands) # create a figure based on the number of spectral bands in the dataset - fig, axes = plt.subplots(min(3, nbands), int(np.ceil(max(1, nbands / 3))), + fig, axes = plt.subplots(int(np.ceil(max(1, nbands / 2))), min(2, nbands), figsize=figsize, sharex=True, sharey=True) axes = axes.flatten() @@ -616,10 +616,14 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5): # plot spectral distribution of the classes in the current band bplot = ax.boxplot(data, labels=labels, patch_artist=True, - whis=[5, 95], showfliers=False) + whis=[5, 95], showfliers=False, showmeans=True, + meanline=True, meanprops={'color': 'black', + 'ls': 'dashed'}, + medianprops={'color': 'black'}) # set axis y-limits: physical limits are (0, 1) for reflectance data - ax.set_ylim(0, 1) + ax.set_ylim(0, 1.2) + ax.set_yticks(np.arange(ax.get_ylim()[0], 1.2, 0.2)) # set colors of the boxes for the classes for k, artists in bplot.items(): @@ -630,30 +634,37 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5): # iterate over the artists for c, art in enumerate(artists): - # line artists - if isinstance(art, matplotlib.lines.Line2D): - # set the colors of the lines in the boxplot - art.set_color(ds.labels[c]['color']) - # patch artists - elif isinstance(art, matplotlib.patches.Patch): + if isinstance(art, matplotlib.patches.Patch): # set the colors of the patches art.set_facecolor(ds.labels[c]['color']) art.set_alpha(alpha) # add name of the spectral band to the plot - ax.text(x=0.6, y=0.95, s='({}: {})'.format( - getattr(ds.sensor, ds.use_bands[band]).value, + ax.text(x=0.6, y=ax.get_ylim()[1] - 0.05, s='({}: {})'.format( + getattr(ds.sensor, ds.use_bands[band]).number, getattr(ds.sensor, ds.use_bands[band]).name), - ha='left', va='top', weight='bold') + ha='left', va='top', weight='bold') + + # add mean value to plot + upper_whisker = [item.get_ydata()[1] for item in + bplot['whiskers'][1::2]] + for i, (d, w) in enumerate(zip(data, upper_whisker)): + # calculate mean + mu = d.mean() + + # add mean value to plot at the top of each whisker + ax.text(x=i + 1, y=w.item() + 0.025, s='{:.2f}'.format(mu), + ha='center', va='bottom') # create a patch (proxy artist) for every class patches = [mpatches.Patch(color=ds.labels[k]['color'], label=v) for k, v in npix_per_class.items()] # add legend with number of pixels per class to plot - axes[-1].legend(handles=patches, loc='upper right', frameon=False, - bbox_to_anchor=(-0.15, -0.15), ncol=len(labels)) + axes[-2].legend(handles=patches, loc='upper left', frameon=False, + bbox_to_anchor=(0, -axes[-2].get_position().y0), + ncol=len(labels)) # adjust space between subplots fig.subplots_adjust(hspace=0.075, wspace=0.025)