Skip to content
Snippets Groups Projects
Commit 2037ea8f authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Moved random seed parameter to split configurations.

parent ad3d88eb
No related branches found
No related tags found
No related merge requests found
......@@ -67,8 +67,6 @@ class ImageDataset(Dataset):
A regural expression to match the ground truth naming convention.
sort : `bool`
Whether to chronologically sort the samples.
seed : `int`
The random seed.
transforms : `list`
List of :py:class:`pysegcnn.core.transforms.Augment` instances.
merge_labels : `dict` [`str`, `str`]
......@@ -113,7 +111,7 @@ class ImageDataset(Dataset):
"""
def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[],
gt_pattern='(.*)gt\\.tif', sort=False, transforms=[],
merge_labels={}):
r"""Initialize.
......@@ -140,10 +138,6 @@ class ImageDataset(Dataset):
sort : `bool`, optional
Whether to chronologically sort the samples. Useful for time series
data. The default is `False`.
seed : `int`, optional
The random seed. Used to split the dataset into training,
validation and test set. Useful for reproducibility. The default is
`0`.
transforms : `list`, optional
List of :py:class:`pysegcnn.core.transforms.Augment` instances.
Each item in ``transforms`` generates a distinct transformed
......@@ -167,7 +161,6 @@ class ImageDataset(Dataset):
self.pad = pad
self.gt_pattern = gt_pattern
self.sort = sort
self.seed = seed
self.transforms = transforms
self.merge_labels = merge_labels
......@@ -711,11 +704,11 @@ class StandardEoDataset(ImageDataset):
"""
def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[],
gt_pattern='(.*)gt\\.tif', sort=False, transforms=[],
merge_labels={}):
# initialize super class ImageDataset
super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
sort, seed, transforms, merge_labels)
sort, transforms, merge_labels)
def _get_band_number(self, path):
"""Return the band number of a scene .tif file.
......@@ -804,8 +797,12 @@ class StandardEoDataset(ImageDataset):
def compose_scenes(self):
"""Build the list of samples of the dataset."""
# search the root directory
# initialize scene list and counter
scenes = []
nscenes = 0
# search the root directory
for dirpath, dirname, files in os.walk(self.root):
# search for a ground truth in the current directory
......@@ -855,9 +852,15 @@ class StandardEoDataset(ImageDataset):
# store optional transformation
data['transform'] = transf
# store scene counter
data['scene'] = nscenes
# append to list
scenes.append(data)
# advance scene counter
nscenes += 1
# sort list of scenes and ground truths in chronological order
if self.sort:
scenes.sort(key=lambda k: k['date'])
......@@ -878,11 +881,11 @@ class SparcsDataset(StandardEoDataset):
"""
def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[],
gt_pattern='(.*)gt\\.tif', sort=False, transforms=[],
merge_labels={}):
# initialize super class StandardEoDataset
super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
sort, seed, transforms, merge_labels)
sort, transforms, merge_labels)
@staticmethod
def get_size():
......@@ -951,11 +954,11 @@ class AlcdDataset(StandardEoDataset):
"""
def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[],
gt_pattern='(.*)gt\\.tif', sort=False, transforms=[],
merge_labels={}):
# initialize super class StandardEoDataset
super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
sort, seed, transforms, merge_labels)
sort, transforms, merge_labels)
@staticmethod
def get_size():
......@@ -1024,11 +1027,11 @@ class Cloud95Dataset(ImageDataset):
"""
def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
gt_pattern='(.*)gt\\.tif', sort=False, seed=0, transforms=[],
gt_pattern='(.*)gt\\.tif', sort=False, transforms=[],
merge_labels={}):
# initialize super class StandardEoDataset
super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
sort, seed, transforms, merge_labels)
sort, transforms, merge_labels)
# the csv file containing the names of the informative patches
# patches with more than 80% black pixels, i.e. patches resulting from
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment