From ad3d88eb49f7ac02eea10e60509b8382669732f3 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Fri, 22 Jan 2021 17:41:30 +0100 Subject: [PATCH] Implemented cross-validation subsampling. --- pysegcnn/core/split.py | 746 ++++++++++------------------------------- 1 file changed, 174 insertions(+), 572 deletions(-) diff --git a/pysegcnn/core/split.py b/pysegcnn/core/split.py index e70ffe4..49609c8 100644 --- a/pysegcnn/core/split.py +++ b/pysegcnn/core/split.py @@ -15,11 +15,11 @@ License # -*- coding: utf-8 -*- # builtins -import datetime import enum # externals import numpy as np +from sklearn.model_selection import KFold from torch.utils.data.dataset import Subset # the names of the subsets @@ -45,93 +45,50 @@ def _ds_len(ds, ratio): return int(np.round(len(ds) * ratio)) -def random_tile_split(ds, tvratio, ttratio=1, seed=0): - """Randomly split the tiles of a dataset. - - For each scene, the tiles of the scene can be distributed among the - training, validation and test set. - - The parameters ``ttratio`` and ``tvratio`` control the size of the - training, validation and test datasets. +def pairwise_disjoint(sets): + """Check if ``sets`` are pairwise disjoint. - Test dataset size : ``(1 - ttratio) * len(ds)`` - Train dataset size : ``ttratio * tvratio * len(ds)`` - Validation dataset size: ``ttratio * (1 - tvratio) * len(ds)`` + Sets are pairwise disjoint if the length of their union equals the sum of + their lengths. Parameters ---------- - ds : :py:class:`pysegcnn.core.dataset.ImageDataset` - An instance of :py:class:`pysegcnn.core.dataset.ImageDataset`. - tvratio : `float` - The ratio of training data to validation data, e.g. ``tvratio=0.8`` - means 80% training, 20% validation. - ttratio : `float`, optional - The ratio of training and validation data to test data, e.g. - ``ttratio=0.6`` means 60% for training and validation, 40% for - testing. The default is `1`. - seed : `int`, optional - The random seed for reproducibility. The default is `0`. - - Raises - ------ - AssertionError - Raised if the splits are not pairwise disjoint. + sets : `list` [:py:class:`collections.Sized`] + A list of sized objects. Returns ------- - subsets : `dict` - Subset dictionary with keys: - ``'train'`` - The training scenes (`dict`). - ``'valid'`` - The validation scenes (`dict`). - ``'test'`` - The test scenes (`dict`). + disjoint : `bool` + Whether the sets are pairwise disjoint. """ - # set the random seed for reproducibility - np.random.seed(seed) - - # randomly permute indices to access dataset - indices = np.random.permutation(len(ds)) - - # length of the training and validation dataset - # number of samples: (ttratio * len(ds)) - trav_len = _ds_len(indices, ttratio) - trav_indices = indices[:trav_len] - - # length of the training dataset - # number of samples: (ttratio * tvratio * len(ds)) - train_len = _ds_len(trav_indices, tvratio) - train_ind = trav_indices[:train_len] - - # length of the validation dataset - # number of samples: (ttratio * (1- tvratio) * len(ds)) - valid_ind = trav_indices[train_len:] - - # length of the test dataset - # number of samples: ((1 - ttratio) * len(ds)) - test_ind = indices[trav_len:] + union = set().union(*sets) + n = sum(len(u) for u in sets) + return n == len(union) - # get the tiles of the scenes of each dataset - subsets = {} - for name, dataset in enumerate([train_ind, valid_ind, test_ind]): - # store the indices and corresponding tiles of the current subset to - # dictionary - subsets[SUBSET_NAMES[name]] = {k: ds.scenes[k] for k in dataset} +def index_dict(indices): + """Generate the training, validation and test set index dictionary. - # check if the splits are disjoint - assert pairwise_disjoint([s.keys() for s in subsets.values()]) + Parameters + ---------- + indices : `list` [:py:class:`numpy.ndarray`] + An ordered list composed of three :py:class:`numpy.ndarray` containing + the indices to the training, validation and test set. - return subsets + Returns + ------- + index_dict : `dict` + The index dictionary, where the keys are equal to ``SUBSET_NAMES`` and + the values are py:class:`numpy.ndarray` containing the indices to the + training, validation and test set. + """ + return {k: v for k, v in zip(SUBSET_NAMES, indices)} -def random_scene_split(ds, tvratio, ttratio=1, seed=0): - """Semi-randomly split the tiles of a dataset. - For each scene, all the tiles of the scene are included in either the - training, validation or test set, respectively. +def random_split(ds, tvratio=0.8, ttratio=1, seed=0, shuffle=True): + """Randomly split an iterable into training, validation and test set. The parameters ``ttratio`` and ``tvratio`` control the size of the training, validation and test datasets. @@ -142,17 +99,20 @@ def random_scene_split(ds, tvratio, ttratio=1, seed=0): Parameters ---------- - ds : :py:class:`pysegcnn.core.dataset.ImageDataset` - An instance of :py:class:`pysegcnn.core.dataset.ImageDataset`. - tvratio : `float` + ds : :py:class:`collections.Sized` + An object with a :py:meth:`__len__` method. + tvratio : `float`, optional The ratio of training data to validation data, e.g. ``tvratio=0.8`` - means 80% training, 20% validation. + means 80% training, 20% validation. The default is `0.8`. ttratio : `float`, optional The ratio of training and validation data to test data, e.g. ``ttratio=0.6`` means 60% for training and validation, 40% for testing. The default is `1`. seed : `int`, optional The random seed for reproducibility. The default is `0`. + shuffle : `bool`, optional + Whether to shuffle the data before splitting into batches. The default + is `True`. Raises ------ @@ -161,568 +121,210 @@ def random_scene_split(ds, tvratio, ttratio=1, seed=0): Returns ------- - subsets : `dict` - Subset dictionary with keys: - ``'train'`` - The training scenes (`dict`). - ``'valid'`` - The validation scenes (`dict`). - ``'test'`` - The test scenes (`dict`). + indices : `list` [`dict`] + List of index dictionaries as composed by + :py:func:`pysegcnn.core.split.index_dict`. """ # set the random seed for reproducibility np.random.seed(seed) - # get the names of the scenes and generate random permutation - scene_ids = np.random.permutation(np.unique([s['id'] for s in ds.scenes])) + # whether to shuffle the data before splitting + indices = np.arange(len(ds)) + if shuffle: + # randomly permute indices to access the iterable + indices = np.random.permutation(indices) # the training and validation scenes - # number of samples: (ttratio * nscenes) - trav_len = _ds_len(scene_ids, ttratio) - trav_scenes = scene_ids[:trav_len] - - # the training scenes - # number of samples: (ttratio * tvratio * nscenes) - train_len = _ds_len(trav_scenes, tvratio) - train_scenes = trav_scenes[:train_len] - - # the validation scenes - # number of samples: (ttratio * (1- tvratio) * nscenes) - valid_scenes = trav_scenes[train_len:] - - # the test scenes - # number of samples:((1 - ttratio) * nscenes) - test_scenes = scene_ids[trav_len:] + # number of samples: (ttratio * len(ds)) + trav_len = _ds_len(ds, ttratio) + trav_ids = indices[:trav_len] - # get the tiles of the scenes of each dataset - subsets = {} - for name, dataset in enumerate([train_scenes, valid_scenes, test_scenes]): + # the training dataset indices + # number of samples: (ttratio * tvratio * len(ds)) + train_len = _ds_len(trav_ids, tvratio) + train_ids = trav_ids[:train_len] - # store the indices and corresponding tiles of the current subset to - # dictionary - subsets[SUBSET_NAMES[name]] = {k: v for k, v in enumerate(ds.scenes) - if v['id'] in dataset} + # the validation dataset indices + # number of samples: (ttratio * (1- tvratio) * len(ds)) + valid_ids = trav_ids[train_len:] - # check if the splits are disjoint - assert pairwise_disjoint([s.keys() for s in subsets.values()]) + # the test dataset indices + # number of samples:((1 - ttratio) * len(ds)) + test_ids = trav_ids[trav_len:] - return subsets + # check whether the different datasets or pairwise disjoint + indices = index_dict([train_ids, valid_ids, test_ids]) + assert pairwise_disjoint(indices.values()) + return [indices] -def date_scene_split(ds, date, dateformat='%Y%m%d'): - """Split the dataset based on a date. - Scenes before ``date`` build the training dataset, scenes after ``date`` - the validation dataset. The test set is empty. +def kfold_split(ds, k_folds=5, seed=0, shuffle=True): + """Randomly split an iterable into ``k_folds`` folds. - Useful for time series data. + This function uses the cross validation index generator + :py:class:`sklearn.model_selection.KFold`. Parameters ---------- - ds : :py:class:`pysegcnn.core.dataset.ImageDataset` - An instance of :py:class:`pysegcnn.core.dataset.ImageDataset`. - date : `str` - A date in the format ``dateformat``. - dateformat : `str`, optional - The format of ``date``. ``dateformat`` is used by - :py:func:`datetime.datetime.strptime' to parse ``date`` to a - :py:class:`datetime.datetime` object. The default is `'%Y%m%d'`. + ds : :py:class:`collections.Sized` + An object with a :py:meth:`__len__` method. + k_folds: `int`, optional + The number of folds. Must be a least 2. The default is `5`. + seed : `int`, optional + The random seed for reproducibility. The default is `0`. + shuffle : `bool`, optional + Whether to shuffle the data before splitting into batches. The default + is `True`. Raises ------ AssertionError - Raised if the splits are not pairwise disjoint. - - Returns - ------- - subsets : `dict` - Subset dictionary with keys: - ``'train'`` - The training scenes (`dict`). - ``'valid'`` - The validation scenes (`dict`). - ``'test'`` - The test scenes (`dict`). + Raised if the (training, validation) folds are not pairwise disjoint. """ - # convert date to datetime object - date = datetime.datetime.strptime(date, dateformat) - - # the training, validation and test scenes - train_scenes = {i: s for i, s in enumerate(ds.scenes) if s['date'] <= date} - valid_scenes = {i: s for i, s in enumerate(ds.scenes) if s['date'] > date} - test_scenes = {} - - # build the training and test datasets - subsets = {} - for name, scenes in enumerate([train_scenes, valid_scenes, test_scenes]): - - # store the indices and corresponding tiles of the current subset to - # dictionary - subsets[SUBSET_NAMES[name]] = scenes - - # check if the splits are disjoint - assert pairwise_disjoint([s.keys() for s in subsets.values()]) - - return subsets - - -def pairwise_disjoint(sets): - """Check if ``sets`` are pairwise disjoint. - - Sets are pairwise disjoint if the length of their union equals the sum of - their lengths. - - Parameters - ---------- - sets : `list` [:py:class:`collections.Sized`] - A list of sized objects. - - Returns - ------- - disjoint : `bool` - Whether the sets are pairwise disjoint. - - """ - union = set().union(*sets) - n = sum(len(u) for u in sets) - return n == len(union) - - -class CustomSubset(Subset): - """Generic custom subset inheriting :py:class:`torch.utils.data.Subset`. - - .. important:: + # set the random seed for reproducibility + np.random.seed(seed) - The training, validation and test datasets should be subclasses of - :py:class:`pysegcnn.core.split.CustomSubset`. + # cross validation index generator from scikit-learn + kf = KFold(k_folds, random_state=seed, shuffle=shuffle) - See :py:class:`pysegcnn.core.split.RandomTileSplit` for an example - implementing the :py:class:`pysegcnn.core.split.RandomSubset` subset - class. + # generate the indices of the different folds + folds = [] + for i, (train, valid) in enumerate(kf.split(ds)): + folds.append(index_dict([train, valid, np.array([])])) + assert pairwise_disjoint(folds[i].values()) + return folds - Attributes - ---------- - dataset : :py:class:`pysegcnn.core.dataset.ImageDataset` - The dataset to split into subsets. - split_mode : `str` - The mode to split the dataset. - indices : `list` [`int`] - List of indices to access the dataset. - name : `str` - Name of the subset. - scenes : `list` [`dict`] - List of the subset tiles. - ids : `list` or :py:class:`numpy.ndarray` - Container of the scene identifiers. - """ +class RandomSplit(object): + """Base class for random splits of a `torch.utils.data.Dataset`.""" - def __init__(self, ds, split_mode, indices, name, scenes, scene_ids): - """Initialize. + def __init__(self, ds, k_folds, seed=0, shuffle=True, tvratio=0.8, + ttratio=1): + """Randomly split a dataset into training, validation and test set. Parameters ---------- - ds : :py:class:`pysegcnn.core.dataset.ImageDataset` - An instance of :py:class:`pysegcnn.core.dataset.ImageDataset`. - split_mode : `str` - The mode to split the dataset. - indices : `list` [`int`] - List of indices to access ``ds``. ``indices`` must be pairwise - disjoint for each subset derived from the same dataset ``ds``. - name : `str` - Name of the subset. - scenes : `list` [`dict`] - List of the subset tiles. - scene_ids : `list` or :py:class:`numpy.ndarray` - Container of the scene identifiers. - - """ - super().__init__(dataset=ds, indices=indices) - - # the mode to split the dataset - self.split_mode = split_mode - - # the name of the subset - self.name = name - - # the scene in the subset - self.scenes = scenes - - # the names of the scenes - self.ids = scene_ids - - def __repr__(self): - """Representation string. - - Returns - ------- - fs : `str` - The representation string. + ds : :py:class:`collections.Sized` + An object with a :py:meth:`__len__` method. + k_folds: `int` + The number of folds. + seed : `int`, optional + The random seed for reproducibility. The default is `0`. + shuffle : `bool`, optional + Whether to shuffle the data before splitting into batches. The + default is `True`. + tvratio : `float`, optional + The ratio of training data to validation data, e.g. ``tvratio=0.8`` + means 80% training, 20% validation. The default is `0.8`. Used if + ``k_folds=1``. + ttratio : `float`, optional + The ratio of training and validation data to test data, e.g. + ``ttratio=0.6`` means 60% for training and validation, 40% for + testing. The default is `1`. Used if ``k_folds=1``. """ - fs = '- {}: {:d} tiles ({:.2f}%), mode = {}'.format( - self.name, len(self.scenes), 100 * len(self.scenes) / - len(self.dataset), self.split_mode) - - return fs - - -class SceneSubset(CustomSubset): - """A custom subset for dataset splits where the scenes are preserved.""" - - def __init__(self, ds, split_mode, indices, name, scenes, scene_ids): - super().__init__(ds, split_mode, indices, name, scenes, scene_ids) - - -class RandomSubset(CustomSubset): - """A custom subset for random dataset splits.""" - - def __init__(self, ds, split_mode, indices, name, scenes, scene_ids): - super().__init__(ds, split_mode, indices, name, scenes, scene_ids) - - -class Split(object): - """Generic class handling how ``ds`` is split. - - Each dataset should be split by a subclass of - :py:class:`pysegcnn.core.split.Split`, by calling the - :py:meth:`pysegcnn.core.split.Split.split` method. - - .. important:: - The :py:meth:`~pysegcnn.core.split.Split.subsets` and - :py:meth:`~pysegcnn.core.split.Split.subset_type` methods have to be - implemented when inheriting :py:class:`pysegcnn.core.split.Split`. - Furthermore, a class attribute ``split_mode`` (`str`) has to be - defined and added to :py:class:`pysegcnn.core.split.SupportedSplits`. - - See :py:class:`pysegcnn.core.split.RandomTileSplit` for an example. - - Attributes - ---------- - ds : :py:class:`pysegcnn.core.dataset.ImageDataset` - The dataset to split into training, validation and test set. - - """ - - def __init__(self, ds): - """Initialize. - - Parameters - ---------- - ds : :py:class:`pysegcnn.core.dataset.ImageDataset` - An instance of :py:class:`pysegcnn.core.dataset.ImageDataset`. - - """ - # the dataset to split + # instance attributes self.ds = ds + self.k_folds = k_folds + self.seed = seed + self.shuffle = shuffle - def split(self): - """Split dataset into training, validation and test set. - - :py:meth:`~pysegcnn.core.split.Split.split` works only if - :py:meth:`~pysegcnn.core.split.Split.subsets` and - :py:meth:`~pysegcnn.core.split.Split.subset_type` are implemented. - - """ - # build the subsets - ds_split = [] - for name, sub in self.subsets().items(): - - # the scene identifiers of the current subset: preserve the order - # of the scene identifiers - ids, idx = np.unique([s['id'] for s in sub.values()], - return_index=True) - ids = ids[np.argsort(idx)] - - # build the subset - sbst = self.subset_type()(self.ds, self.split_mode, - list(sub.keys()), name, - list(sub.values()), ids) - ds_split.append(sbst) - - return ds_split - - def subsets(self): - """Define training, validation and test sets. + # instance attributes: training/validation/test split ratios + # used if kfolds=1 + self.tvratio = tvratio + self.ttratio = ttratio - Wrapper method for - :py:func:`pysegcnn.core.split.Split.random_tile_split`, - :py:func:`pysegcnn.core.split.Split.random_scene_split` or - :py:func:`pysegcnn.core.split.Split.date_scene_split`. + def generate_splits(self): - Raises - ------ - NotImplementedError - Raised if :py:class:`pysegcnn.core.split.Split` is not inherited. + # check whether to generate a single or multiple folds + if self.k_folds > 1: + # k-fold split + indices = kfold_split( + self.indices_to_split, self.k_folds, self.seed, self.shuffle) + else: + # single-fold split + indices = random_split( + self.indices_to_split, self.tvratio, self.ttratio, self.seed, + self.shuffle) - Returns - ------- - None. + return indices - """ + @property + def indices_to_split(self): raise NotImplementedError - def subset_type(self): - """Define the type of each subset. - - Wrapper method for :py:class:`pysegcnn.core.split.RandomSubset` or - :py:class:`pysegcnn.core.split.SceneSubset`. - - Raises - ------ - NotImplementedError - Raised if :py:class:`pysegcnn.core.split.Split` is not inherited. - - Returns - ------- - None. - - """ + @property + def indices(self): raise NotImplementedError + def split(self): -class DateSplit(Split): - """Split a dataset based on a date. - - .. important:: - - Scenes before ``date`` build the training dataset, scenes after - ``date`` the validation dataset. The test set is empty. - - Useful for time series data. - - Class wrapper for :py:func:`pysegcnn.core.split.date_scene_split`. - - Attributes - ---------- - split_mode : `str` - The mode to split the dataset, i.e. `'date'`. - ds : :py:class:`pysegcnn.core.dataset.ImageDataset` - The dataset to split into training, validation and test set. - date : `str` - The date used to split the dataset. - dateformat : `str` - The format of ``date``. - - """ - - # the split mode - split_mode = 'date' - - def __init__(self, ds, date, dateformat): - """Initialize. - - Parameters - ---------- - ds : :py:class:`pysegcnn.core.dataset.ImageDataset` - An instance of :py:class:`pysegcnn.core.dataset.ImageDataset`. - date : `str` - A date in the format ``dateformat``. - dateformat : `str` - The format of ``date``. ``dateformat`` is used by - :py:func:`datetime.datetime.strptime' to parse ``date`` to a - :py:class:`datetime.datetime` object. - - """ - super().__init__(ds) - - # the date to split the dataset - # before: training set - # after : validation set - self.date = date - - # the format of the date - self.dateformat = dateformat - - def subsets(self): - """Wrap :py:func:`pysegcnn.core.split.Split.date_scene_split`. - - Returns - ------- - subsets : `dict` - Subset dictionary with keys: - ``'train'`` - The training scenes (`dict`). - ``'valid'`` - The validation scenes (`dict`). - ``'test'`` - The test scenes, empty (`dict`). - - """ - return date_scene_split(self.ds, self.date, self.dateformat) - - def subset_type(self): - """Wrap :py:class:`pysegcnn.core.split.SceneSubset`. - - Returns - ------- - SceneSubset : :py:class:`pysegcnn.core.split.SceneSubset` - The subset type. - - """ - return SceneSubset - - -class RandomSplit(Split): - """Generic class for random dataset splits.""" - - def __init__(self, ds, ttratio, tvratio, seed): - """Initialize. - - Parameters - ---------- - ds : :py:class:`pysegcnn.core.dataset.ImageDataset` - An instance of :py:class:`pysegcnn.core.dataset.ImageDataset`. - tvratio : `float` - The ratio of training data to validation data, e.g. - ``tvratio=0.8`` means 80% training, 20% validation. - ttratio : `float` - The ratio of training and validation data to test data, e.g. - ``ttratio=0.6`` means 60% for training and validation, 40% for - testing. - seed : `int` - The random seed used to generate the split. Useful for - reproducibility. - - """ - super().__init__(ds) + # initialize training, validation and test subsets + subsets = [] - # the training, validation and test set ratios - self.ttratio = ttratio - self.tvratio = tvratio + # the training, validation and test indices + for folds in self.indices: + subsets.append( + index_dict([Subset(self.ds, ids) for ids in folds.values()])) - # the random seed: useful for reproducibility - self.seed = seed + return subsets class RandomTileSplit(RandomSplit): - """Randomly split the dataset. + """Split a :py:class:`pysegcnn.core.dataset.ImageDataset` into tiles.""" - .. important:: + def __init__(self, ds, k_folds, seed=0, shuffle=True, tvratio=0.8, + ttratio=1): + # initialize super class + super().__init__(ds, k_folds, seed, shuffle, tvratio, ttratio) - For each scene, the tiles of the scene can be distributed among the - training, validation and test set. + @property + def indices_to_split(self): + return np.arange(len(self.ds)) - Class wrapper for :py:func:`pysegcnn.core.split.random_tile_split`. - - Attributes - ---------- - split_mode : `str` - The mode to split the dataset, i.e. `'random'`. - ds : :py:class:`pysegcnn.core.dataset.ImageDataset` - The dataset to split into training, validation and test set. - tvratio : `float` - The ratio of training data to validation data. - ttratio : `float` - The ratio of training and validation data to test data. - seed : `int` - The random seed used to generate the split. - - """ - - # the split mode - split_mode = 'random' - - def __init__(self, ds, ttratio, tvratio, seed): - super().__init__(ds, ttratio, tvratio, seed) - - def subsets(self): - """Wrap :py:func:`pysegcnn.core.split.Split.random_tile_split`. - - Returns - ------- - subsets : `dict` - Subset dictionary with keys: - ``'train'`` - The training scenes (`dict`). - ``'valid'`` - The validation scenes (`dict`). - ``'test'`` - The test scenes (`dict`). - - """ - return random_tile_split(self.ds, self.tvratio, self.ttratio, - self.seed) - - def subset_type(self): - """Wrap :py:class:`pysegcnn.core.split.RandomSubset`. - - Returns - ------- - SceneSubset : :py:class:`pysegcnn.core.split.RandomSubset` - The subset type. - - """ - return RandomSubset + @property + def indices(self): + return self.generate_splits() class RandomSceneSplit(RandomSplit): - """Semi-randomly split the dataset. - - .. important:: - - For each scene, all the tiles of the scene are included in either the - training, validation or test set, respectively. - - Class wrapper for :py:func:`pysegcnn.core.split.random_scene_split`. - - Attributes - ---------- - split_mode : `str` - The mode to split the dataset, i.e. `'scene'`. - ds : :py:class:`pysegcnn.core.dataset.ImageDataset` - The dataset to split into training, validation and test set. - tvratio : `float` - The ratio of training data to validation data. - ttratio : `float` - The ratio of training and validation data to test data. - seed : `int` - The random seed used to generate the split. + """Split a :py:class:`pysegcnn.core.dataset.ImageDataset` into scenes.""" - """ - - # the split mode - split_mode = 'scene' - - def __init__(self, ds, ttratio, tvratio, seed): - super().__init__(ds, ttratio, tvratio, seed) + def __init__(self, ds, k_folds, seed=0, shuffle=True, tvratio=0.8, + ttratio=1): + # initialize super class + super().__init__(ds, k_folds, seed, shuffle, tvratio, ttratio) - def subsets(self): - """Wrap :py:func:`pysegcnn.core.split.Split.random_scene_split`. + # the number of the scenes in the dataset + self.scenes = np.array([v['scene'] for v in self.ds.scenes]) - Returns - ------- - subsets : `dict` - Subset dictionary with keys: - ``'train'`` - The training scenes (`dict`). - ``'valid'`` - The validation scenes (`dict`). - ``'test'`` - The test scenes (`dict`). + @property + def indices_to_split(self): + return np.unique(self.scenes) - """ - return random_scene_split(self.ds, self.tvratio, self.ttratio, - self.seed) - - def subset_type(self): - """Wrap :py:class:`pysegcnn.core.split.SceneSubset`. + @property + def indices(self): + # indices of the different scene identifiers + indices = self.generate_splits() - Returns - ------- - SceneSubset : :py:class:`pysegcnn.core.split.SceneSubset` - The subset type. + # iterate over the different folds + scene_indices = [] + for folds in indices: + # iterate over the training, validation and test set + subset = {} + for name, ids in folds.items(): + subset[name] = np.where(np.isin(self.scenes, ids))[0] + scene_indices.append(subset) - """ - return SceneSubset + return scene_indices class SupportedSplits(enum.Enum): """Names and corresponding classes of the implemented split modes.""" - random = RandomTileSplit + tile = RandomTileSplit scene = RandomSceneSplit - date = DateSplit -- GitLab