From 2037ea8f778e7a8498520f483e39649d0c81feac Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Fri, 22 Jan 2021 17:41:56 +0100 Subject: [PATCH] Moved random seed parameter to split configurations. --- pysegcnn/core/dataset.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/pysegcnn/core/dataset.py b/pysegcnn/core/dataset.py index 596b731..5668a52 100644 --- a/pysegcnn/core/dataset.py +++ b/pysegcnn/core/dataset.py @@ -67,8 +67,6 @@ class ImageDataset(Dataset): A regural expression to match the ground truth naming convention. sort : `bool` Whether to chronologically sort the samples. - seed : `int` - The random seed. transforms : `list` List of :py:class:`pysegcnn.core.transforms.Augment` instances. merge_labels : `dict` [`str`, `str`] @@ -113,7 +111,7 @@ class ImageDataset(Dataset): """ def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, - gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[], + gt_pattern='(.*)gt\\.tif', sort=False, transforms=[], merge_labels={}): r"""Initialize. @@ -140,10 +138,6 @@ class ImageDataset(Dataset): sort : `bool`, optional Whether to chronologically sort the samples. Useful for time series data. The default is `False`. - seed : `int`, optional - The random seed. Used to split the dataset into training, - validation and test set. Useful for reproducibility. The default is - `0`. transforms : `list`, optional List of :py:class:`pysegcnn.core.transforms.Augment` instances. Each item in ``transforms`` generates a distinct transformed @@ -167,7 +161,6 @@ class ImageDataset(Dataset): self.pad = pad self.gt_pattern = gt_pattern self.sort = sort - self.seed = seed self.transforms = transforms self.merge_labels = merge_labels @@ -711,11 +704,11 @@ class StandardEoDataset(ImageDataset): """ def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, - gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[], + gt_pattern='(.*)gt\\.tif', sort=False, transforms=[], merge_labels={}): # initialize super class ImageDataset super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern, - sort, seed, transforms, merge_labels) + sort, transforms, merge_labels) def _get_band_number(self, path): """Return the band number of a scene .tif file. @@ -804,8 +797,12 @@ class StandardEoDataset(ImageDataset): def compose_scenes(self): """Build the list of samples of the dataset.""" - # search the root directory + + # initialize scene list and counter scenes = [] + nscenes = 0 + + # search the root directory for dirpath, dirname, files in os.walk(self.root): # search for a ground truth in the current directory @@ -855,9 +852,15 @@ class StandardEoDataset(ImageDataset): # store optional transformation data['transform'] = transf + # store scene counter + data['scene'] = nscenes + # append to list scenes.append(data) + # advance scene counter + nscenes += 1 + # sort list of scenes and ground truths in chronological order if self.sort: scenes.sort(key=lambda k: k['date']) @@ -878,11 +881,11 @@ class SparcsDataset(StandardEoDataset): """ def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, - gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[], + gt_pattern='(.*)gt\\.tif', sort=False, transforms=[], merge_labels={}): # initialize super class StandardEoDataset super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern, - sort, seed, transforms, merge_labels) + sort, transforms, merge_labels) @staticmethod def get_size(): @@ -951,11 +954,11 @@ class AlcdDataset(StandardEoDataset): """ def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, - gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[], + gt_pattern='(.*)gt\\.tif', sort=False, transforms=[], merge_labels={}): # initialize super class StandardEoDataset super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern, - sort, seed, transforms, merge_labels) + sort, transforms, merge_labels) @staticmethod def get_size(): @@ -1024,11 +1027,11 @@ class Cloud95Dataset(ImageDataset): """ def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, - gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[], + gt_pattern='(.*)gt\\.tif', sort=False, transforms=[], merge_labels={}): # initialize super class StandardEoDataset super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern, - sort, seed, transforms, merge_labels) + sort, transforms, merge_labels) # the csv file containing the names of the informative patches # patches with more than 80% black pixels, i.e. patches resulting from -- GitLab