diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py index 9a426b2e04a6edd6c4158105a3e2b54f346a1c83..368208c8a9918b4c652fb6857e075a00a4696311 100644 --- a/pysegcnn/core/graphics.py +++ b/pysegcnn/core/graphics.py @@ -18,6 +18,7 @@ License import os import pathlib import itertools +import logging # externals import numpy as np @@ -58,6 +59,9 @@ plt.rc('figure', titlesize=BIG) # training metrics METRICS = ['train_loss', 'train_accu', 'valid_loss', 'valid_accu'] +# module level logger +LOGGER = logging.getLogger(__name__) + def contrast_stretching(image, alpha=5): """Apply `normalization`_ to an image to increase constrast. @@ -585,6 +589,10 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5): # drop classes which are not represented in the dataset cls_ds = {k: v for k, v in cls_ds.items() if np.any(v)} + # number of pixels of each class + npix_per_class = {k: '{:.2f}M'.format(v.shape[0] * 1e-6) for k, v in + cls_ds.items()} + # labels for the different classes labels = [ds.labels[cls_id]['label'] for cls_id in cls_ds.keys()] @@ -598,6 +606,7 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5): # iterate over the different bands for band in range(nbands): + LOGGER.info('Plotting band: {}'.format(ds.use_bands[band])) # current axis ax = axes[band] @@ -636,13 +645,23 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5): ax.text(x=0.6, y=0.95, s='({})'.format(ds.use_bands[band]), ha='left', va='top', weight='bold') + # create a patch (proxy artist) for every class + patches = [mpatches.Patch(color=ds.labels[k]['color'], label=v) for k, v in + npix_per_class.items()] + + # add legend with number of pixels per class to plot + axes[-1].legend(handles=patches, loc=2, frameon=False) + # adjust space between subplots fig.subplots_adjust(hspace=0.075, wspace=0.025) # remove empty axes for ax in axes: if not ax.lines: - fig.delaxes(ax) + # hide axis ticks, labels and text artists + ax.axis('off') + for t in ax.texts: + t.set_visible(False) return fig @@ -777,24 +796,3 @@ class Animate(object): # save animation to disk self.animation.save(str(self.path.joinpath(filename)), writer='imagemagick', **kwargs) - - -def _plot_composites(ds, path, fmt, dpi=300, alpha=0): - """Utility function to plot each scene of a dataset.""" - - # iterate over the scenes of the dataset - for scene in range(len(ds)): - # name of the current scene - scene_id = ds.scenes[scene]['id'] - print(scene_id) - - # get the data of the current scene - x, y = ds[scene] - - # plot the current scene - fig = plot_sample(x, ds.use_bands, ds.labels, y=y, hide_labels=True, - bands=['swir2', 'nir', 'green'], alpha=alpha) - - # save the figure as vector graphic - fig.savefig(path.joinpath(scene_id + '.{}'.format(fmt)), dpi=dpi, - bbox_inches='tight', format=fmt)