Skip to content
Snippets Groups Projects
Commit 9374ee3e authored by Frisinghelli Daniel
Browse files

Improved train/valid/test split workflow

parent d9a3b557
No related branches found
No related tags found
No related merge requests found
......@@ -14,6 +14,9 @@ from torch.utils.data.dataset import Subset
# the names of the subsets
SUBSET_NAMES = ['train', 'valid', 'test']
# valid split modes
VALID_SPLIT_MODES = ['random', 'scene', 'date']
# function calculating number of samples in a dataset given a ratio
def _ds_len(ds, ratio):
......@@ -129,7 +132,6 @@ def date_scene_split(ds, date, dateformat='%Y%m%d'):
# store the indices and corresponding tiles of the current subset to
# dictionary
subsets[SUBSET_NAMES[name]] = scenes
# sbst.ids = np.unique([s['id'] for s in scenes.values()])
# check if the splits are disjoint
assert pairwise_disjoint([s.keys() for s in subsets.values()])
......@@ -143,101 +145,130 @@ def pairwise_disjoint(sets):
return n == len(union)
class Split(object):
class SceneSubset(Subset):
# the valid modes
valid_modes = ['random', 'scene', 'date']
def __init__(self, ds, indices, name, scenes, scene_ids):
super().__init__(dataset=ds, indices=indices)
def __init__(self, ds, mode, **kwargs):
# the name of the subset
self.name = name
# check which mode is provided
if mode not in self.valid_modes:
raise ValueError('{} is not supported. Valid modes are {}, see '
'pysegcnn.main.config.py for a description of '
'each mode.'.format(mode, self.valid_modes))
self.mode = mode
# the scene in the subset
self.scenes = scenes
# the dataset to split
self.ds = ds
# the names of the scenes
self.ids = scene_ids
# the keyword arguments
self.kwargs = kwargs
# initialize split
self._init_split()
class RandomSubset(Subset):
def _init_split(self):
def __init__(self, ds, indices, name, scenes, scene_ids):
super().__init__(dataset=ds, indices=indices)
if self.mode == 'random':
self.subset = RandomSubset
self.split_function = random_tile_split
self.allowed_kwargs = ['tvratio', 'ttratio', 'seed']
# the name of the subset
self.name = name
if self.mode == 'scene':
self.subset = SceneSubset
self.split_function = random_scene_split
self.allowed_kwargs = ['tvratio', 'ttratio', 'seed']
# the scene in the subset
self.scenes = scenes
if self.mode == 'date':
self.subset = SceneSubset
self.split_function = date_scene_split
self.allowed_kwargs = ['date', 'dateformat']
self._check_kwargs()
class Split(object):
def _check_kwargs(self):
def __init__(self, ds):
# check if the correct keyword arguments are provided
if not set(self.allowed_kwargs).issubset(self.kwargs.keys()):
raise TypeError('__init__() expecting keyword arguments: {}.'
.format(', '.join(kwa for kwa in
self.allowed_kwargs)))
# select the correct keyword arguments
self.kwargs = {k: self.kwargs[k] for k in self.allowed_kwargs}
# the dataset to split
self.ds = ds
# apply the split
def split(self):
# create the subsets
subsets = self.split_function(self.ds, **self.kwargs)
# build the subsets
ds_split = []
for name, sub in subsets.items():
for name, sub in self.subsets().items():
# the scene identifiers of the current subset
ids = np.unique([s['id'] for s in sub.values()])
# build the subset
subset = self.subset(self.ds, list(sub.keys()), name,
list(sub.values()), ids)
ds_split.append(subset)
sbst = self.subset_type()(self.ds, list(sub.keys()), name,
list(sub.values()), ids)
ds_split.append(sbst)
return ds_split
@property
def subsets(self):
raise NotImplementedError
class SceneSubset(Subset):
def subset_type(self):
raise NotImplementedError
def __init__(self, ds, indices, name, scenes, scene_ids):
super().__init__(dataset=ds, indices=indices)
def __repr__(self):
# the name of the subset
self.name = name
# representation string to print
fs = self.__class__.__name__ + '(\n '
# the scene in the subset
self.scenes = scenes
# dataset split
fs += '\n '.join(
'- {}: {:d} batches ({:.2f}%)'
.format(k, len(v), len(v) * 100 / len(self.ds))
for k, v in self.subsets().items())
fs += '\n)'
return fs
# the names of the scenes
self.ids = scene_ids
class DateSplit(Split):
def __init__(self, ds, date, dateformat):
super().__init__(ds)
class RandomSubset(Subset):
# the date to split the dataset
# before: training set
# after : validation set
self.date = date
def __init__(self, ds, indices, name, scenes):
super().__init__(dataset=ds, indices=indices)
# the format of the date
self.dateformat = dateformat
# the name of the subset
self.name = name
def subsets(self):
return date_scene_split(self.ds, self.date, self.dateformat)
# the scene in the subset
self.scenes = scenes
def subset_type(self):
return SceneSubset
class RandomSplit(Split):
    """Common base of the splits that depend on ratios and a random seed.

    The constructor only stores its arguments on the instance; the concrete
    subclasses decide which split function consumes them.

    Parameters
    ----------
    ds :
        The dataset to split, forwarded unchanged to the ``Split``
        constructor.
    ttratio :
        Split ratio, stored as ``self.ttratio``.
        NOTE(review): presumably the training/test ratio -- confirm against
        the split functions.
    tvratio :
        Split ratio, stored as ``self.tvratio``.
        NOTE(review): presumably the training/validation ratio -- confirm
        against the split functions.
    seed :
        The random seed, stored as ``self.seed``; useful for
        reproducibility.
    """

    def __init__(self, ds, ttratio, tvratio, seed):
        # forward the dataset to the parent constructor
        super().__init__(ds)

        # the training, validation and test set ratios
        self.ttratio, self.tvratio = ttratio, tvratio

        # the random seed: useful for reproducibility
        self.seed = seed
class RandomTileSplit(RandomSplit):
    """Random split that delegates to ``random_tile_split``.

    The constructor is inherited unchanged from ``RandomSplit``
    (``ds, ttratio, tvratio, seed``); an override that merely forwarded the
    very same arguments to ``super().__init__`` was redundant and has been
    dropped.
    """

    def subsets(self):
        """Return the result of ``random_tile_split`` for this dataset.

        The split function receives the dataset, ``tvratio``, ``ttratio`` and
        ``seed`` positionally, in exactly that order.
        """
        return random_tile_split(self.ds, self.tvratio, self.ttratio,
                                 self.seed)

    def subset_type(self):
        """Return the class used to build each subset: ``RandomSubset``."""
        return RandomSubset
class RandomSceneSplit(RandomSplit):
    """Random split that delegates to ``random_scene_split``.

    The constructor is inherited unchanged from ``RandomSplit``
    (``ds, ttratio, tvratio, seed``); an override that merely forwarded the
    very same arguments to ``super().__init__`` was redundant and has been
    dropped.
    """

    def subsets(self):
        """Return the result of ``random_scene_split`` for this dataset.

        The split function receives the dataset, ``tvratio``, ``ttratio`` and
        ``seed`` positionally, in exactly that order.
        """
        return random_scene_split(self.ds, self.tvratio, self.ttratio,
                                  self.seed)

    def subset_type(self):
        """Return the class used to build each subset: ``SceneSubset``."""
        return SceneSubset
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment