From 7e80b2c95f5067274e5464f123770a93ad64e72a Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 3 Feb 2021 17:39:53 +0100
Subject: [PATCH] Fixed a bug: NoData class is now skipped in class
 distribution.

---
 pysegcnn/core/dataset.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/pysegcnn/core/dataset.py b/pysegcnn/core/dataset.py
index b4ed855..b9c2107 100644
--- a/pysegcnn/core/dataset.py
+++ b/pysegcnn/core/dataset.py
@@ -32,7 +32,6 @@ import pathlib
 
 # externals
 import numpy as np
-import pandas as pd
 import torch
 from torch.utils.data import Dataset
 
@@ -41,8 +40,7 @@ from pysegcnn.core.constants import (MultiSpectralSensor, Landsat8, Sentinel2,
                                      Label, SparcsLabels, Cloud95Labels,
                                      AlcdLabels)
 from pysegcnn.core.utils import (img2np, is_divisible, tile_topleft_corner,
-                                 parse_landsat_scene, parse_sentinel2_scene,
-                                 array_replace)
+                                 parse_landsat_scene, parse_sentinel2_scene)
 
 # module level logger
 LOGGER = logging.getLogger(__name__)
@@ -632,10 +630,13 @@ class ImageDataset(Dataset):
 
     def class_distribution(self):
 
-        # initialize dictionary of class spectral distribution
         # exclude NoData class
+        labels = {k: v for k, v in self.labels.items() if
+                  v['label'] != 'No data'}
+
+        # initialize dictionary of class spectral distribution
         cls_ds = {k: np.empty(shape=(0, len(self.use_bands)), dtype=np.float32)
-                  for k, v in self.labels.items() if v['label'] != 'No data'}
+                  for k, v in labels.items()}
 
         # iterate over the samples of the dataset
         for i in range(len(self)):
@@ -644,7 +645,7 @@ class ImageDataset(Dataset):
             x, y = self[i]
 
             # iterate over the different classes
-            for k, v in self.labels.items():
+            for k, v in labels.items():
                 # get values equal to the current class
                 mask = np.where(y == k)
 
-- 
GitLab