diff --git a/pysegcnn/core/split.py b/pysegcnn/core/split.py index 588c2b488e8239e485f7c3f0f1512aed01921996..f658346880d4b20a47a2d2d7983a8653f0f23ab7 100644 --- a/pysegcnn/core/split.py +++ b/pysegcnn/core/split.py @@ -6,6 +6,7 @@ Created on Wed Jul 29 12:02:32 2020 """ # builtins import datetime +import enum # externals import numpy as np @@ -14,10 +15,6 @@ from torch.utils.data.dataset import Subset # the names of the subsets SUBSET_NAMES = ['train', 'valid', 'test'] -# valid split modes -VALID_SPLIT_MODES = ['random', 'scene', 'date'] - - # function calculating number of samples in a dataset given a ratio def _ds_len(ds, ratio): return int(np.round(len(ds) * ratio)) @@ -145,7 +142,19 @@ def pairwise_disjoint(sets): return n == len(union) -class SceneSubset(Subset): +class CustomSubset(Subset): + + def __repr__(self): + + # representation string + fs = '- {}: {:d} tiles ({:.2f}%)'.format( + self.name, len(self.scenes), 100 * len(self.scenes) / + len(self.dataset)) + + return fs + + +class SceneSubset(CustomSubset): def __init__(self, ds, indices, name, scenes, scene_ids): super().__init__(dataset=ds, indices=indices) @@ -160,7 +169,7 @@ class SceneSubset(Subset): self.ids = scene_ids -class RandomSubset(Subset): +class RandomSubset(CustomSubset): def __init__(self, ds, indices, name, scenes, scene_ids): super().__init__(dataset=ds, indices=indices) @@ -202,19 +211,6 @@ class Split(object): def subset_type(self): raise NotImplementedError - def __repr__(self): - - # representation string to print - fs = self.__class__.__name__ + '(\n ' - - # dataset split - fs += '\n '.join( - '- {}: {:d} batches ({:.2f}%)' - .format(k, len(v), len(v) * 100 / len(self.ds)) - for k, v in self.subsets().items()) - fs += '\n)' - return fs - class DateSplit(Split): def __init__(self, ds, date, dateformat): @@ -272,3 +268,9 @@ class RandomSceneSplit(RandomSplit): def subset_type(self): return SceneSubset + + +class SupportedSplits(enum.Enum): + random = RandomTileSplit + scene = RandomSceneSplit + date = DateSplit