From 2037ea8f778e7a8498520f483e39649d0c81feac Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Fri, 22 Jan 2021 17:41:56 +0100
Subject: [PATCH] Moved random seed parameter to split configurations.

---
 pysegcnn/core/dataset.py | 37 ++++++++++++++++++++-----------------
 1 file changed, 20 insertions(+), 17 deletions(-)

diff --git a/pysegcnn/core/dataset.py b/pysegcnn/core/dataset.py
index 596b731..5668a52 100644
--- a/pysegcnn/core/dataset.py
+++ b/pysegcnn/core/dataset.py
@@ -67,8 +67,6 @@ class ImageDataset(Dataset):
         A regural expression to match the ground truth naming convention.
     sort : `bool`
         Whether to chronologically sort the samples.
-    seed : `int`
-        The random seed.
     transforms : `list`
         List of :py:class:`pysegcnn.core.transforms.Augment` instances.
     merge_labels : `dict` [`str`, `str`]
@@ -113,7 +111,7 @@ class ImageDataset(Dataset):
     """
 
     def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
-                 gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[],
+                 gt_pattern='(.*)gt\\.tif', sort=False, transforms=[],
                  merge_labels={}):
         r"""Initialize.
 
@@ -140,10 +138,6 @@ class ImageDataset(Dataset):
         sort : `bool`, optional
             Whether to chronologically sort the samples. Useful for time series
             data. The default is `False`.
-        seed : `int`, optional
-            The random seed. Used to split the dataset into training,
-            validation and test set. Useful for reproducibility. The default is
-            `0`.
         transforms : `list`, optional
             List of :py:class:`pysegcnn.core.transforms.Augment` instances.
             Each item in ``transforms`` generates a distinct transformed
@@ -167,7 +161,6 @@ class ImageDataset(Dataset):
         self.pad = pad
         self.gt_pattern = gt_pattern
         self.sort = sort
-        self.seed = seed
         self.transforms = transforms
         self.merge_labels = merge_labels
 
@@ -711,11 +704,11 @@ class StandardEoDataset(ImageDataset):
     """
 
     def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
-                 gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[],
+                 gt_pattern='(.*)gt\\.tif', sort=False, transforms=[],
                  merge_labels={}):
         # initialize super class ImageDataset
         super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
-                         sort, seed, transforms, merge_labels)
+                         sort, transforms, merge_labels)
 
     def _get_band_number(self, path):
         """Return the band number of a scene .tif file.
@@ -804,8 +797,12 @@ class StandardEoDataset(ImageDataset):
 
     def compose_scenes(self):
         """Build the list of samples of the dataset."""
-        # search the root directory
+
+        # initialize scene list and counter
         scenes = []
+        nscenes = 0
+
+        # search the root directory
         for dirpath, dirname, files in os.walk(self.root):
 
             # search for a ground truth in the current directory
@@ -855,9 +852,15 @@ class StandardEoDataset(ImageDataset):
                         # store optional transformation
                         data['transform'] = transf
 
+                        # store scene counter
+                        data['scene'] = nscenes
+
                         # append to list
                         scenes.append(data)
 
+                # advance scene counter
+                nscenes += 1
+
         # sort list of scenes and ground truths in chronological order
         if self.sort:
             scenes.sort(key=lambda k: k['date'])
@@ -878,11 +881,11 @@ class SparcsDataset(StandardEoDataset):
     """
 
     def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
-                 gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[],
+                 gt_pattern='(.*)gt\\.tif', sort=False, transforms=[],
                  merge_labels={}):
         # initialize super class StandardEoDataset
         super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
-                         sort, seed, transforms, merge_labels)
+                         sort, transforms, merge_labels)
 
     @staticmethod
     def get_size():
@@ -951,11 +954,11 @@ class AlcdDataset(StandardEoDataset):
     """
 
     def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
-                 gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[],
+                 gt_pattern='(.*)gt\\.tif', sort=False, transforms=[],
                  merge_labels={}):
         # initialize super class StandardEoDataset
         super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
-                         sort, seed, transforms, merge_labels)
+                         sort, transforms, merge_labels)
 
     @staticmethod
     def get_size():
@@ -1024,11 +1027,11 @@ class Cloud95Dataset(ImageDataset):
     """
 
     def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
-                 gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[],
+                 gt_pattern='(.*)gt\\.tif', sort=False, transforms=[],
                  merge_labels={}):
         # initialize super class StandardEoDataset
         super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
-                         sort, seed, transforms, merge_labels)
+                         sort, transforms, merge_labels)
 
         # the csv file containing the names of the informative patches
         # patches with more than 80% black pixels, i.e. patches resulting from
-- 
GitLab