From e30dbfac1078d60c231cc7973353f677fa2fba77 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 27 Jan 2021 18:15:04 +0100
Subject: [PATCH] Implemented a more efficient way to calculate class
 distribution.

---
 pysegcnn/core/dataset.py | 38 ++++++++++++++++++++++++++------------
 1 file changed, 26 insertions(+), 12 deletions(-)

diff --git a/pysegcnn/core/dataset.py b/pysegcnn/core/dataset.py
index 2366ecf..c04103c 100644
--- a/pysegcnn/core/dataset.py
+++ b/pysegcnn/core/dataset.py
@@ -632,14 +632,18 @@ class ImageDataset(Dataset):
 
     def class_distribution(self):
 
+        # initialize dictionary of class spectral distribution
+        cls_ds = {k: np.empty(shape=(0, len(self.use_bands))) for k, _ in
+                  self.labels.items()}
+
         # initialize class distribution dataframe
-        columns = [band.capitalize() for band in self.use_bands] + ['Class']
-        cls_df = pd.DataFrame(columns=columns)
+        # columns = [band.capitalize() for band in self.use_bands] + ['Class']
+        # cls_df = pd.DataFrame(columns=columns)
 
         # create the lookup table to replace the class identifiers by their
         # corresponding labels
-        lookup = np.array(list({k: v['label'] for k, v in self.labels.items()}
-                               .items())).astype(object)
+        # lookup = np.array(list({k: v['label'] for k, v in self.labels.items()}
+        #                        .items())).astype(object)
 
         # iterate over the samples of the dataset
         for i in range(len(self)):
@@ -647,17 +651,27 @@ class ImageDataset(Dataset):
             LOGGER.info('Sample: {}/{}'.format(i + 1, len(self)))
             x, y = self[i]
 
-            # reshape the current sample
-            data = np.hstack([x.flatten(start_dim=1).T, np.expand_dims(
-                array_replace(y.flatten(), lookup), axis=1)])
+            # iterate over the different classes
+            for k, v in self.labels.items():
+                # get the indices of the pixels labeled as the current class
+                mask = np.where(y == k)
+
+                # append the input values at those pixels to the class array
+                cls_ds[k] = np.vstack([cls_ds[k], x[:, mask[0], mask[1]].T])
+
+        return cls_ds
+
+        # reshape the current sample
+        # data = np.hstack([x.flatten(start_dim=1).T, np.expand_dims(
+        #     array_replace(y.flatten(), lookup), axis=1)])
 
-            # the pixels of the current sample to the dataframe
-            df = pd.DataFrame(data, columns=columns)
+        # # append the pixels of the current sample to the dataframe
+        # df = pd.DataFrame(data, columns=columns)
 
-            # update class distribution dataframe
-            cls_df = cls_df.append(df)
+        # # update class distribution dataframe
+        # cls_df = cls_df.append(df)
 
-        return cls_df
+        # return cls_df
 
     def __repr__(self):
         """Dataset representation.
-- 
GitLab