diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py index 0dc3d01c484764b5228ef1228c68663016d98e27..199946c4b71b078d72517531068bf8a5ce598dca 100644 --- a/pysegcnn/core/graphics.py +++ b/pysegcnn/core/graphics.py @@ -593,8 +593,9 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5): npix_per_class = {k: '{:.2f}M'.format(v.shape[0] * 1e-6) for k, v in cls_ds.items()} - # labels for the different classes + # labels and colors for the different classes labels = [ds.labels[cls_id]['label'] for cls_id in cls_ds.keys()] + colors = [ds.labels[cls_id]['color'] for cls_id in cls_ds.keys()] # number of spectral bands in the dataset nbands = len(ds.use_bands) @@ -637,7 +638,7 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5): # patch artists if isinstance(art, matplotlib.patches.Patch): # set the colors of the patches - art.set_facecolor(ds.labels[c]['color']) + art.set_facecolor(colors[c]) art.set_alpha(alpha) # add name of the spectral band to the plot