diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py
index 0dc3d01c484764b5228ef1228c68663016d98e27..199946c4b71b078d72517531068bf8a5ce598dca 100644
--- a/pysegcnn/core/graphics.py
+++ b/pysegcnn/core/graphics.py
@@ -593,8 +593,9 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5):
     npix_per_class = {k: '{:.2f}M'.format(v.shape[0] * 1e-6) for k, v in
                       cls_ds.items()}
 
-    # labels for the different classes
+    # labels and colors for the different classes
     labels = [ds.labels[cls_id]['label'] for cls_id in cls_ds.keys()]
+    colors = [ds.labels[cls_id]['color'] for cls_id in cls_ds.keys()]
 
     # number of spectral bands in the dataset
     nbands = len(ds.use_bands)
@@ -637,7 +638,7 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5):
                     # patch artists
                     if isinstance(art, matplotlib.patches.Patch):
                         # set the colors of the patches
-                        art.set_facecolor(ds.labels[c]['color'])
+                        art.set_facecolor(colors[c])
                         art.set_alpha(alpha)
 
         # add name of the spectral band to the plot