Skip to content
Snippets Groups Projects
Commit 558e5a74 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Added descriptions to class distribution plot.

parent 1e62b4ab
No related branches found
No related tags found
No related merge requests found
......@@ -18,6 +18,7 @@ License
import os
import pathlib
import itertools
import logging
# externals
import numpy as np
......@@ -58,6 +59,9 @@ plt.rc('figure', titlesize=BIG)
# training metrics
METRICS = ['train_loss', 'train_accu', 'valid_loss', 'valid_accu']
# module level logger
LOGGER = logging.getLogger(__name__)
def contrast_stretching(image, alpha=5):
"""Apply `normalization`_ to an image to increase constrast.
......@@ -585,6 +589,10 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5):
# drop classes which are not represented in the dataset
cls_ds = {k: v for k, v in cls_ds.items() if np.any(v)}
# number of pixels of each class
npix_per_class = {k: '{:.2f}M'.format(v.shape[0] * 1e-6) for k, v in
cls_ds.items()}
# labels for the different classes
labels = [ds.labels[cls_id]['label'] for cls_id in cls_ds.keys()]
......@@ -598,6 +606,7 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5):
# iterate over the different bands
for band in range(nbands):
LOGGER.info('Plotting band: {}'.format(ds.use_bands[band]))
# current axis
ax = axes[band]
......@@ -636,13 +645,23 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5):
ax.text(x=0.6, y=0.95, s='({})'.format(ds.use_bands[band]),
ha='left', va='top', weight='bold')
# create a patch (proxy artist) for every class
patches = [mpatches.Patch(color=ds.labels[k]['color'], label=v) for k, v in
npix_per_class.items()]
# add legend with number of pixels per class to plot
axes[-1].legend(handles=patches, loc=2, frameon=False)
# adjust space between subplots
fig.subplots_adjust(hspace=0.075, wspace=0.025)
# remove empty axes
for ax in axes:
if not ax.lines:
fig.delaxes(ax)
# hide axis ticks, labels and text artists
ax.axis('off')
for t in ax.texts:
t.set_visible(False)
return fig
......@@ -777,24 +796,3 @@ class Animate(object):
# save animation to disk
self.animation.save(str(self.path.joinpath(filename)),
writer='imagemagick', **kwargs)
def _plot_composites(ds, path, fmt, dpi=300, alpha=0):
"""Utility function to plot each scene of a dataset."""
# iterate over the scenes of the dataset
for scene in range(len(ds)):
# name of the current scene
scene_id = ds.scenes[scene]['id']
print(scene_id)
# get the data of the current scene
x, y = ds[scene]
# plot the current scene
fig = plot_sample(x, ds.use_bands, ds.labels, y=y, hide_labels=True,
bands=['swir2', 'nir', 'green'], alpha=alpha)
# save the figure as vector graphic
fig.savefig(path.joinpath(scene_id + '.{}'.format(fmt)), dpi=dpi,
bbox_inches='tight', format=fmt)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment