From 5dacadeaa267bd1f968249fb8c876750163cc81e Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Tue, 26 Jan 2021 17:57:39 +0100
Subject: [PATCH] Implemented the computation of the spectral distribution of
 the different classes.

---
 pysegcnn/core/dataset.py | 45 +++++++++++++++++++++++++++++++++-------
 1 file changed, 38 insertions(+), 7 deletions(-)

diff --git a/pysegcnn/core/dataset.py b/pysegcnn/core/dataset.py
index 5668a52..2366ecf 100644
--- a/pysegcnn/core/dataset.py
+++ b/pysegcnn/core/dataset.py
@@ -32,6 +32,7 @@ import pathlib
 
 # externals
 import numpy as np
+import pandas as pd
 import torch
 from torch.utils.data import Dataset
 
@@ -40,7 +41,8 @@ from pysegcnn.core.constants import (MultiSpectralSensor, Landsat8, Sentinel2,
                                      Label, SparcsLabels, Cloud95Labels,
                                      AlcdLabels)
 from pysegcnn.core.utils import (img2np, is_divisible, tile_topleft_corner,
-                                 parse_landsat_scene, parse_sentinel2_scene)
+                                 parse_landsat_scene, parse_sentinel2_scene,
+                                 array_replace)
 
 # module level logger
 LOGGER = logging.getLogger(__name__)
@@ -210,13 +212,12 @@ class ImageDataset(Dataset):
         # always use the original dataset together with the augmentations
         self.transforms = [None] + self.transforms
 
-        # when padding, add a new "no data" label to the ground truth
-        self.cval = self.label_class.No_data.id
+        # when padding, add a new "padded" label to the ground truth
         if self.pad and sum(self.padding) > 0:
-            LOGGER.info('Adding label "No data" with value={} to ground truth.'
+            self.cval = self.label_class.No_data.id
+            LOGGER.info('Padding to defined tile size. Padding value: {}.'
                         .format(self.cval))
         else:
-            # self._labels.pop(self.cval)
             self.cval = 0
 
         # remove labels to merge from dataset instance labels
@@ -629,6 +630,35 @@ class ImageDataset(Dataset):
         """
         return torch.tensor(np.asarray(x).copy(), dtype=dtype)
 
+    def class_distribution(self):
+
+        # initialize class distribution dataframe
+        columns = [band.capitalize() for band in self.use_bands] + ['Class']
+        cls_df = pd.DataFrame(columns=columns)
+
+        # create the lookup table to replace the class identifiers by their
+        # corresponding labels
+        lookup = np.array(list({k: v['label'] for k, v in self.labels.items()}
+                               .items())).astype(object)
+
+        # iterate over the samples of the dataset
+        for i in range(len(self)):
+            # get the data of the current sample
+            LOGGER.info('Sample: {}/{}'.format(i + 1, len(self)))
+            x, y = self[i]
+
+            # reshape the current sample
+            data = np.hstack([x.flatten(start_dim=1).T, np.expand_dims(
+                array_replace(y.flatten(), lookup), axis=1)])
+
+            # the pixels of the current sample to the dataframe
+            df = pd.DataFrame(data, columns=columns)
+
+            # update class distribution dataframe
+            cls_df = cls_df.append(df)
+
+        return cls_df
+
     def __repr__(self):
         """Dataset representation.
 
@@ -659,8 +689,9 @@ class ImageDataset(Dataset):
         # tiles
         fs += '\n    (tiles):\n        '
         fs += '- number of tiles per scene: {}\n        '.format(self.tiles)
-        fs += '- tile size: {}\n        '.format((self.tile_size,
-                                                  self.tile_size))
+        fs += '- tile size: {}\n        '.format(
+            2 * (self.tile_size, ) if self.tile_size is not None else
+            (self.height, self.width))
         fs += '- number of tiles: {}'.format(len(self.scenes))
 
         # classes of interest
-- 
GitLab