Skip to content
Snippets Groups Projects
Commit 2037ea8f authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Moved random seed parameter to split configurations.

parent ad3d88eb
No related branches found
No related tags found
No related merge requests found
...@@ -67,8 +67,6 @@ class ImageDataset(Dataset): ...@@ -67,8 +67,6 @@ class ImageDataset(Dataset):
A regural expression to match the ground truth naming convention. A regural expression to match the ground truth naming convention.
sort : `bool` sort : `bool`
Whether to chronologically sort the samples. Whether to chronologically sort the samples.
seed : `int`
The random seed.
transforms : `list` transforms : `list`
List of :py:class:`pysegcnn.core.transforms.Augment` instances. List of :py:class:`pysegcnn.core.transforms.Augment` instances.
merge_labels : `dict` [`str`, `str`] merge_labels : `dict` [`str`, `str`]
...@@ -113,7 +111,7 @@ class ImageDataset(Dataset): ...@@ -113,7 +111,7 @@ class ImageDataset(Dataset):
""" """
def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[], gt_pattern='(.*)gt\\.tif', sort=False, transforms=[],
merge_labels={}): merge_labels={}):
r"""Initialize. r"""Initialize.
...@@ -140,10 +138,6 @@ class ImageDataset(Dataset): ...@@ -140,10 +138,6 @@ class ImageDataset(Dataset):
sort : `bool`, optional sort : `bool`, optional
Whether to chronologically sort the samples. Useful for time series Whether to chronologically sort the samples. Useful for time series
data. The default is `False`. data. The default is `False`.
seed : `int`, optional
The random seed. Used to split the dataset into training,
validation and test set. Useful for reproducibility. The default is
`0`.
transforms : `list`, optional transforms : `list`, optional
List of :py:class:`pysegcnn.core.transforms.Augment` instances. List of :py:class:`pysegcnn.core.transforms.Augment` instances.
Each item in ``transforms`` generates a distinct transformed Each item in ``transforms`` generates a distinct transformed
...@@ -167,7 +161,6 @@ class ImageDataset(Dataset): ...@@ -167,7 +161,6 @@ class ImageDataset(Dataset):
self.pad = pad self.pad = pad
self.gt_pattern = gt_pattern self.gt_pattern = gt_pattern
self.sort = sort self.sort = sort
self.seed = seed
self.transforms = transforms self.transforms = transforms
self.merge_labels = merge_labels self.merge_labels = merge_labels
...@@ -711,11 +704,11 @@ class StandardEoDataset(ImageDataset): ...@@ -711,11 +704,11 @@ class StandardEoDataset(ImageDataset):
""" """
def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[], gt_pattern='(.*)gt\\.tif', sort=False, transforms=[],
merge_labels={}): merge_labels={}):
# initialize super class ImageDataset # initialize super class ImageDataset
super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern, super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
sort, seed, transforms, merge_labels) sort, transforms, merge_labels)
def _get_band_number(self, path): def _get_band_number(self, path):
"""Return the band number of a scene .tif file. """Return the band number of a scene .tif file.
...@@ -804,8 +797,12 @@ class StandardEoDataset(ImageDataset): ...@@ -804,8 +797,12 @@ class StandardEoDataset(ImageDataset):
def compose_scenes(self): def compose_scenes(self):
"""Build the list of samples of the dataset.""" """Build the list of samples of the dataset."""
# search the root directory
# initialize scene list and counter
scenes = [] scenes = []
nscenes = 0
# search the root directory
for dirpath, dirname, files in os.walk(self.root): for dirpath, dirname, files in os.walk(self.root):
# search for a ground truth in the current directory # search for a ground truth in the current directory
...@@ -855,9 +852,15 @@ class StandardEoDataset(ImageDataset): ...@@ -855,9 +852,15 @@ class StandardEoDataset(ImageDataset):
# store optional transformation # store optional transformation
data['transform'] = transf data['transform'] = transf
# store scene counter
data['scene'] = nscenes
# append to list # append to list
scenes.append(data) scenes.append(data)
# advance scene counter
nscenes += 1
# sort list of scenes and ground truths in chronological order # sort list of scenes and ground truths in chronological order
if self.sort: if self.sort:
scenes.sort(key=lambda k: k['date']) scenes.sort(key=lambda k: k['date'])
...@@ -878,11 +881,11 @@ class SparcsDataset(StandardEoDataset): ...@@ -878,11 +881,11 @@ class SparcsDataset(StandardEoDataset):
""" """
def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[], gt_pattern='(.*)gt\\.tif', sort=False, transforms=[],
merge_labels={}): merge_labels={}):
# initialize super class StandardEoDataset # initialize super class StandardEoDataset
super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern, super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
sort, seed, transforms, merge_labels) sort, transforms, merge_labels)
@staticmethod @staticmethod
def get_size(): def get_size():
...@@ -951,11 +954,11 @@ class AlcdDataset(StandardEoDataset): ...@@ -951,11 +954,11 @@ class AlcdDataset(StandardEoDataset):
""" """
def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[], gt_pattern='(.*)gt\\.tif', sort=False, transforms=[],
merge_labels={}): merge_labels={}):
# initialize super class StandardEoDataset # initialize super class StandardEoDataset
super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern, super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
sort, seed, transforms, merge_labels) sort, transforms, merge_labels)
@staticmethod @staticmethod
def get_size(): def get_size():
...@@ -1024,11 +1027,11 @@ class Cloud95Dataset(ImageDataset): ...@@ -1024,11 +1027,11 @@ class Cloud95Dataset(ImageDataset):
""" """
def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[], gt_pattern='(.*)gt\\.tif', sort=False, transforms=[],
merge_labels={}): merge_labels={}):
# initialize super class StandardEoDataset # initialize super class StandardEoDataset
super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern, super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
sort, seed, transforms, merge_labels) sort, transforms, merge_labels)
# the csv file containing the names of the informative patches # the csv file containing the names of the informative patches
# patches with more than 80% black pixels, i.e. patches resulting from # patches with more than 80% black pixels, i.e. patches resulting from
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment