diff --git a/pysegcnn/core/split.py b/pysegcnn/core/split.py index 78f653a3b41c2d0f552c2485e026d8f604487060..588c2b488e8239e485f7c3f0f1512aed01921996 100644 --- a/pysegcnn/core/split.py +++ b/pysegcnn/core/split.py @@ -14,6 +14,9 @@ from torch.utils.data.dataset import Subset # the names of the subsets SUBSET_NAMES = ['train', 'valid', 'test'] +# valid split modes +VALID_SPLIT_MODES = ['random', 'scene', 'date'] + # function calculating number of samples in a dataset given a ratio def _ds_len(ds, ratio): @@ -129,7 +132,6 @@ def date_scene_split(ds, date, dateformat='%Y%m%d'): # store the indices and corresponding tiles of the current subset to # dictionary subsets[SUBSET_NAMES[name]] = scenes - # sbst.ids = np.unique([s['id'] for s in scenes.values()]) # check if the splits are disjoint assert pairwise_disjoint([s.keys() for s in subsets.values()]) @@ -143,101 +145,130 @@ def pairwise_disjoint(sets): return n == len(union) -class Split(object): +class SceneSubset(Subset): - # the valid modes - valid_modes = ['random', 'scene', 'date'] + def __init__(self, ds, indices, name, scenes, scene_ids): + super().__init__(dataset=ds, indices=indices) - def __init__(self, ds, mode, **kwargs): + # the name of the subset + self.name = name - # check which mode is provided - if mode not in self.valid_modes: - raise ValueError('{} is not supported. Valid modes are {}, see ' - 'pysegcnn.main.config.py for a description of ' - 'each mode.'.format(mode, self.valid_modes)) - self.mode = mode + # the scene in the subset + self.scenes = scenes - # the dataset to split - self.ds = ds + # the names of the scenes + self.ids = scene_ids - # the keyword arguments - self.kwargs = kwargs - # initialize split - self._init_split() +class RandomSubset(Subset): - def _init_split(self): + def __init__(self, ds, indices, name, scenes, scene_ids): + super().__init__(dataset=ds, indices=indices) - if self.mode == 'random': - self.subset = RandomSubset - self.split_function = random_tile_split - self.allowed_kwargs = ['tvratio', 'ttratio', 'seed'] + # the name of the subset + self.name = name - if self.mode == 'scene': - self.subset = SceneSubset - self.split_function = random_scene_split - self.allowed_kwargs = ['tvratio', 'ttratio', 'seed'] + # the scene in the subset + self.scenes = scenes - if self.mode == 'date': - self.subset = SceneSubset - self.split_function = date_scene_split - self.allowed_kwargs = ['date', 'dateformat'] - self._check_kwargs() +class Split(object): - def _check_kwargs(self): + def __init__(self, ds): - # check if the correct keyword arguments are provided - if not set(self.allowed_kwargs).issubset(self.kwargs.keys()): - raise TypeError('__init__() expecting keyword arguments: {}.' - .format(', '.join(kwa for kwa in - self.allowed_kwargs))) - # select the correct keyword arguments - self.kwargs = {k: self.kwargs[k] for k in self.allowed_kwargs} + # the dataset to split + self.ds = ds - # function apply the split def split(self): - # create the subsets - subsets = self.split_function(self.ds, **self.kwargs) - # build the subsets ds_split = [] - for name, sub in subsets.items(): + for name, sub in self.subsets().items(): # the scene identifiers of the current subset ids = np.unique([s['id'] for s in sub.values()]) # build the subset - subset = self.subset(self.ds, list(sub.keys()), name, - list(sub.values()), ids) - ds_split.append(subset) + sbst = self.subset_type()(self.ds, list(sub.keys()), name, + list(sub.values()), ids) + ds_split.append(sbst) return ds_split + @property + def subsets(self): + raise NotImplementedError -class SceneSubset(Subset): + def subset_type(self): + raise NotImplementedError - def __init__(self, ds, indices, name, scenes, scene_ids): - super().__init__(dataset=ds, indices=indices) + def __repr__(self): - # the name of the subset - self.name = name + # representation string to print + fs = self.__class__.__name__ + '(\n ' - # the scene in the subset - self.scenes = scenes + # dataset split + fs += '\n '.join( + '- {}: {:d} batches ({:.2f}%)' + .format(k, len(v), len(v) * 100 / len(self.ds)) + for k, v in self.subsets().items()) + fs += '\n)' + return fs - # the names of the scenes - self.ids = scene_ids +class DateSplit(Split): + def __init__(self, ds, date, dateformat): + super().__init__(ds) -class RandomSubset(Subset): + # the date to split the dataset + # before: training set + # after : validation set + self.date = date - def __init__(self, ds, indices, name, scenes): - super().__init__(dataset=ds, indices=indices) + # the format of the date + self.dateformat = dateformat - # the name of the subset - self.name = name + def subsets(self): + return date_scene_split(self.ds, self.date, self.dateformat) - # the scene in the subset - self.scenes = scenes + def subset_type(self): + return SceneSubset + + +class RandomSplit(Split): + + def __init__(self, ds, ttratio, tvratio, seed): + super().__init__(ds) + + # the training, validation and test set ratios + self.ttratio = ttratio + self.tvratio = tvratio + + # the random seed: useful for reproducibility + self.seed = seed + + +class RandomTileSplit(RandomSplit): + + def __init__(self, ds, ttratio, tvratio, seed): + super().__init__(ds, ttratio, tvratio, seed) + + def subsets(self): + return random_tile_split(self.ds, self.tvratio, self.ttratio, + self.seed) + + def subset_type(self): + return RandomSubset + + +class RandomSceneSplit(RandomSplit): + + def __init__(self, ds, ttratio, tvratio, seed): + super().__init__(ds, ttratio, tvratio, seed) + + def subsets(self): + return random_scene_split(self.ds, self.tvratio, self.ttratio, + self.seed) + + def subset_type(self): + return SceneSubset