From ad3d88eb49f7ac02eea10e60509b8382669732f3 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Fri, 22 Jan 2021 17:41:30 +0100
Subject: [PATCH] Implemented cross-validation subsampling.

---
 pysegcnn/core/split.py | 746 ++++++++++-------------------------------
 1 file changed, 174 insertions(+), 572 deletions(-)

diff --git a/pysegcnn/core/split.py b/pysegcnn/core/split.py
index e70ffe4..49609c8 100644
--- a/pysegcnn/core/split.py
+++ b/pysegcnn/core/split.py
@@ -15,11 +15,11 @@ License
 # -*- coding: utf-8 -*-
 
 # builtins
-import datetime
 import enum
 
 # externals
 import numpy as np
+from sklearn.model_selection import KFold
 from torch.utils.data.dataset import Subset
 
 # the names of the subsets
@@ -45,93 +45,50 @@ def _ds_len(ds, ratio):
     return int(np.round(len(ds) * ratio))
 
 
-def random_tile_split(ds, tvratio, ttratio=1, seed=0):
-    """Randomly split the tiles of a dataset.
-
-    For each scene, the tiles of the scene can be distributed among the
-    training, validation and test set.
-
-    The parameters ``ttratio`` and ``tvratio`` control the size of the
-    training, validation and test datasets.
+def pairwise_disjoint(sets):
+    """Check if ``sets`` are pairwise disjoint.
 
-    Test dataset size      : ``(1 - ttratio) * len(ds)``
-    Train dataset size     : ``ttratio * tvratio * len(ds)``
-    Validation dataset size: ``ttratio * (1 - tvratio) * len(ds)``
+    Sets are pairwise disjoint if the length of their union equals the sum of
+    their lengths.
 
     Parameters
     ----------
-    ds : :py:class:`pysegcnn.core.dataset.ImageDataset`
-        An instance of :py:class:`pysegcnn.core.dataset.ImageDataset`.
-    tvratio : `float`
-        The ratio of training data to validation data, e.g. ``tvratio=0.8``
-        means 80% training, 20% validation.
-    ttratio : `float`, optional
-        The ratio of training and validation data to test data, e.g.
-        ``ttratio=0.6`` means 60% for training and validation, 40% for
-        testing. The default is `1`.
-    seed : `int`, optional
-        The random seed for reproducibility. The default is `0`.
-
-    Raises
-    ------
-    AssertionError
-        Raised if the splits are not pairwise disjoint.
+    sets : `list` [:py:class:`collections.Sized`]
+        A list of sized objects.
 
     Returns
     -------
-    subsets : `dict`
-        Subset dictionary with keys:
-            ``'train'``
-                The training scenes (`dict`).
-            ``'valid'``
-                The validation scenes (`dict`).
-            ``'test'``
-                The test scenes (`dict`).
+    disjoint : `bool`
+        Whether the sets are pairwise disjoint.
 
     """
-    # set the random seed for reproducibility
-    np.random.seed(seed)
-
-    # randomly permute indices to access dataset
-    indices = np.random.permutation(len(ds))
-
-    # length of the training and validation dataset
-    # number of samples: (ttratio * len(ds))
-    trav_len = _ds_len(indices, ttratio)
-    trav_indices = indices[:trav_len]
-
-    # length of the training dataset
-    # number of samples: (ttratio * tvratio * len(ds))
-    train_len = _ds_len(trav_indices, tvratio)
-    train_ind = trav_indices[:train_len]
-
-    # length of the validation dataset
-    # number of samples: (ttratio * (1- tvratio) * len(ds))
-    valid_ind = trav_indices[train_len:]
-
-    # length of the test dataset
-    # number of samples: ((1 - ttratio) * len(ds))
-    test_ind = indices[trav_len:]
+    union = set().union(*sets)
+    n = sum(len(u) for u in sets)
+    return n == len(union)
 
-    # get the tiles of the scenes of each dataset
-    subsets = {}
-    for name, dataset in enumerate([train_ind, valid_ind, test_ind]):
 
-        # store the indices and corresponding tiles of the current subset to
-        # dictionary
-        subsets[SUBSET_NAMES[name]] = {k: ds.scenes[k] for k in dataset}
+def index_dict(indices):
+    """Generate the training, validation and test set index dictionary.
 
-    # check if the splits are disjoint
-    assert pairwise_disjoint([s.keys() for s in subsets.values()])
+    Parameters
+    ----------
+    indices : `list` [:py:class:`numpy.ndarray`]
+        An ordered list composed of three :py:class:`numpy.ndarray` containing
+        the indices to the training, validation and test set.
 
-    return subsets
+    Returns
+    -------
+    index_dict : `dict`
+        The index dictionary, where the keys are equal to ``SUBSET_NAMES`` and
+        the values are py:class:`numpy.ndarray` containing the indices to the
+        training, validation and test set.
 
+    """
+    return {k: v for k, v in zip(SUBSET_NAMES, indices)}
 
-def random_scene_split(ds, tvratio, ttratio=1, seed=0):
-    """Semi-randomly split the tiles of a dataset.
 
-    For each scene, all the tiles of the scene are included in either the
-    training, validation or test set, respectively.
+def random_split(ds, tvratio=0.8, ttratio=1, seed=0, shuffle=True):
+    """Randomly split an iterable into training, validation and test set.
 
     The parameters ``ttratio`` and ``tvratio`` control the size of the
     training, validation and test datasets.
@@ -142,17 +99,20 @@ def random_scene_split(ds, tvratio, ttratio=1, seed=0):
 
     Parameters
     ----------
-    ds : :py:class:`pysegcnn.core.dataset.ImageDataset`
-        An instance of :py:class:`pysegcnn.core.dataset.ImageDataset`.
-    tvratio : `float`
+    ds : :py:class:`collections.Sized`
+        An object with a :py:meth:`__len__` method.
+    tvratio : `float`, optional
         The ratio of training data to validation data, e.g. ``tvratio=0.8``
-        means 80% training, 20% validation.
+        means 80% training, 20% validation. The default is `0.8`.
     ttratio : `float`, optional
         The ratio of training and validation data to test data, e.g.
         ``ttratio=0.6`` means 60% for training and validation, 40% for
         testing. The default is `1`.
     seed : `int`, optional
         The random seed for reproducibility. The default is `0`.
+    shuffle : `bool`, optional
+        Whether to shuffle the data before splitting into batches. The default
+        is `True`.
 
     Raises
     ------
@@ -161,568 +121,210 @@ def random_scene_split(ds, tvratio, ttratio=1, seed=0):
 
     Returns
     -------
-    subsets : `dict`
-        Subset dictionary with keys:
-            ``'train'``
-                The training scenes (`dict`).
-            ``'valid'``
-                The validation scenes (`dict`).
-            ``'test'``
-                The test scenes (`dict`).
+    indices : `list` [`dict`]
+        List of index dictionaries as composed by
+        :py:func:`pysegcnn.core.split.index_dict`.
 
     """
     # set the random seed for reproducibility
     np.random.seed(seed)
 
-    # get the names of the scenes and generate random permutation
-    scene_ids = np.random.permutation(np.unique([s['id'] for s in ds.scenes]))
+    # whether to shuffle the data before splitting
+    indices = np.arange(len(ds))
+    if shuffle:
+        # randomly permute indices to access the iterable
+        indices = np.random.permutation(indices)
 
     # the training and validation scenes
-    # number of samples: (ttratio * nscenes)
-    trav_len = _ds_len(scene_ids, ttratio)
-    trav_scenes = scene_ids[:trav_len]
-
-    # the training scenes
-    # number of samples: (ttratio * tvratio * nscenes)
-    train_len = _ds_len(trav_scenes, tvratio)
-    train_scenes = trav_scenes[:train_len]
-
-    # the validation scenes
-    # number of samples: (ttratio * (1- tvratio) * nscenes)
-    valid_scenes = trav_scenes[train_len:]
-
-    # the test scenes
-    # number of samples:((1 - ttratio) * nscenes)
-    test_scenes = scene_ids[trav_len:]
+    # number of samples: (ttratio * len(ds))
+    trav_len = _ds_len(ds, ttratio)
+    trav_ids = indices[:trav_len]
 
-    # get the tiles of the scenes of each dataset
-    subsets = {}
-    for name, dataset in enumerate([train_scenes, valid_scenes, test_scenes]):
+    # the training dataset indices
+    # number of samples: (ttratio * tvratio * len(ds))
+    train_len = _ds_len(trav_ids, tvratio)
+    train_ids = trav_ids[:train_len]
 
-        # store the indices and corresponding tiles of the current subset to
-        # dictionary
-        subsets[SUBSET_NAMES[name]] = {k: v for k, v in enumerate(ds.scenes)
-                                       if v['id'] in dataset}
+    # the validation dataset indices
+    # number of samples: (ttratio * (1- tvratio) * len(ds))
+    valid_ids = trav_ids[train_len:]
 
-    # check if the splits are disjoint
-    assert pairwise_disjoint([s.keys() for s in subsets.values()])
+    # the test dataset indices
+    # number of samples:((1 - ttratio) * len(ds))
+    test_ids = trav_ids[trav_len:]
 
-    return subsets
+    # check whether the different datasets or pairwise disjoint
+    indices = index_dict([train_ids, valid_ids, test_ids])
+    assert pairwise_disjoint(indices.values())
 
+    return [indices]
 
-def date_scene_split(ds, date, dateformat='%Y%m%d'):
-    """Split the dataset based on a date.
 
-    Scenes before ``date`` build the training dataset, scenes after ``date``
-    the validation dataset. The test set is empty.
+def kfold_split(ds, k_folds=5, seed=0, shuffle=True):
+    """Randomly split an iterable into ``k_folds`` folds.
 
-    Useful for time series data.
+    This function uses the cross validation index generator
+    :py:class:`sklearn.model_selection.KFold`.
 
     Parameters
     ----------
-    ds : :py:class:`pysegcnn.core.dataset.ImageDataset`
-        An instance of :py:class:`pysegcnn.core.dataset.ImageDataset`.
-    date : `str`
-        A date in the format ``dateformat``.
-    dateformat : `str`, optional
-        The format of ``date``. ``dateformat`` is used by
-        :py:func:`datetime.datetime.strptime' to parse ``date`` to a
-        :py:class:`datetime.datetime` object. The default is `'%Y%m%d'`.
+    ds : :py:class:`collections.Sized`
+        An object with a :py:meth:`__len__` method.
+    k_folds: `int`, optional
+        The number of folds. Must be a least 2. The default is `5`.
+    seed : `int`, optional
+        The random seed for reproducibility. The default is `0`.
+    shuffle : `bool`, optional
+        Whether to shuffle the data before splitting into batches. The default
+        is `True`.
 
     Raises
     ------
     AssertionError
-        Raised if the splits are not pairwise disjoint.
-
-    Returns
-    -------
-    subsets : `dict`
-        Subset dictionary with keys:
-            ``'train'``
-                The training scenes (`dict`).
-            ``'valid'``
-                The validation scenes (`dict`).
-            ``'test'``
-                The test scenes (`dict`).
+        Raised if the (training, validation) folds are not pairwise disjoint.
 
     """
-    # convert date to datetime object
-    date = datetime.datetime.strptime(date, dateformat)
-
-    # the training, validation and test scenes
-    train_scenes = {i: s for i, s in enumerate(ds.scenes) if s['date'] <= date}
-    valid_scenes = {i: s for i, s in enumerate(ds.scenes) if s['date'] > date}
-    test_scenes = {}
-
-    # build the training and test datasets
-    subsets = {}
-    for name, scenes in enumerate([train_scenes, valid_scenes, test_scenes]):
-
-        # store the indices and corresponding tiles of the current subset to
-        # dictionary
-        subsets[SUBSET_NAMES[name]] = scenes
-
-    # check if the splits are disjoint
-    assert pairwise_disjoint([s.keys() for s in subsets.values()])
-
-    return subsets
-
-
-def pairwise_disjoint(sets):
-    """Check if ``sets`` are pairwise disjoint.
-
-    Sets are pairwise disjoint if the length of their union equals the sum of
-    their lengths.
-
-    Parameters
-    ----------
-    sets : `list` [:py:class:`collections.Sized`]
-        A list of sized objects.
-
-    Returns
-    -------
-    disjoint : `bool`
-        Whether the sets are pairwise disjoint.
-
-    """
-    union = set().union(*sets)
-    n = sum(len(u) for u in sets)
-    return n == len(union)
-
-
-class CustomSubset(Subset):
-    """Generic custom subset inheriting :py:class:`torch.utils.data.Subset`.
-
-    .. important::
+    # set the random seed for reproducibility
+    np.random.seed(seed)
 
-        The training, validation and test datasets should be subclasses of
-        :py:class:`pysegcnn.core.split.CustomSubset`.
+    # cross validation index generator from scikit-learn
+    kf = KFold(k_folds, random_state=seed, shuffle=shuffle)
 
-    See :py:class:`pysegcnn.core.split.RandomTileSplit` for an example
-    implementing the :py:class:`pysegcnn.core.split.RandomSubset` subset
-    class.
+    # generate the indices of the different folds
+    folds = []
+    for i, (train, valid) in enumerate(kf.split(ds)):
+        folds.append(index_dict([train, valid, np.array([])]))
+        assert pairwise_disjoint(folds[i].values())
+    return folds
 
-    Attributes
-    ----------
-    dataset : :py:class:`pysegcnn.core.dataset.ImageDataset`
-        The dataset to split into subsets.
-    split_mode : `str`
-        The mode to split the dataset.
-    indices : `list` [`int`]
-        List of indices to access the dataset.
-    name : `str`
-        Name of the subset.
-    scenes : `list` [`dict`]
-        List of the subset tiles.
-    ids : `list` or :py:class:`numpy.ndarray`
-        Container of the scene identifiers.
 
-    """
+class RandomSplit(object):
+    """Base class for random splits of a `torch.utils.data.Dataset`."""
 
-    def __init__(self, ds, split_mode, indices, name, scenes, scene_ids):
-        """Initialize.
+    def __init__(self, ds, k_folds, seed=0, shuffle=True, tvratio=0.8,
+                 ttratio=1):
+        """Randomly split a dataset into training, validation and test set.
 
         Parameters
         ----------
-        ds : :py:class:`pysegcnn.core.dataset.ImageDataset`
-            An instance of :py:class:`pysegcnn.core.dataset.ImageDataset`.
-        split_mode : `str`
-            The mode to split the dataset.
-        indices : `list` [`int`]
-            List of indices to access ``ds``. ``indices`` must be pairwise
-            disjoint for each subset derived from the same dataset ``ds``.
-        name : `str`
-            Name of the subset.
-        scenes : `list` [`dict`]
-            List of the subset tiles.
-        scene_ids : `list` or :py:class:`numpy.ndarray`
-            Container of the scene identifiers.
-
-        """
-        super().__init__(dataset=ds, indices=indices)
-
-        # the mode to split the dataset
-        self.split_mode = split_mode
-
-        # the name of the subset
-        self.name = name
-
-        # the scene in the subset
-        self.scenes = scenes
-
-        # the names of the scenes
-        self.ids = scene_ids
-
-    def __repr__(self):
-        """Representation string.
-
-        Returns
-        -------
-        fs : `str`
-            The representation string.
+        ds : :py:class:`collections.Sized`
+            An object with a :py:meth:`__len__` method.
+        k_folds: `int`
+            The number of folds.
+        seed : `int`, optional
+            The random seed for reproducibility. The default is `0`.
+        shuffle : `bool`, optional
+            Whether to shuffle the data before splitting into batches. The
+            default is `True`.
+        tvratio : `float`, optional
+            The ratio of training data to validation data, e.g. ``tvratio=0.8``
+            means 80% training, 20% validation. The default is `0.8`. Used if
+            ``k_folds=1``.
+        ttratio : `float`, optional
+            The ratio of training and validation data to test data, e.g.
+            ``ttratio=0.6`` means 60% for training and validation, 40% for
+            testing. The default is `1`. Used if ``k_folds=1``.
 
         """
-        fs = '- {}: {:d} tiles ({:.2f}%), mode = {}'.format(
-            self.name, len(self.scenes), 100 * len(self.scenes) /
-            len(self.dataset), self.split_mode)
-
-        return fs
-
-
-class SceneSubset(CustomSubset):
-    """A custom subset for dataset splits where the scenes are preserved."""
-
-    def __init__(self, ds, split_mode, indices, name, scenes, scene_ids):
-        super().__init__(ds, split_mode, indices, name, scenes, scene_ids)
-
-
-class RandomSubset(CustomSubset):
-    """A custom subset for random dataset splits."""
-
-    def __init__(self, ds, split_mode, indices, name, scenes, scene_ids):
-        super().__init__(ds, split_mode, indices, name, scenes, scene_ids)
-
-
-class Split(object):
-    """Generic class handling how ``ds`` is split.
-
-    Each dataset should be split by a subclass of
-    :py:class:`pysegcnn.core.split.Split`, by calling the
-    :py:meth:`pysegcnn.core.split.Split.split` method.
-
-    .. important::
 
-        The :py:meth:`~pysegcnn.core.split.Split.subsets` and
-        :py:meth:`~pysegcnn.core.split.Split.subset_type` methods have to be
-        implemented when inheriting :py:class:`pysegcnn.core.split.Split`.
-        Furthermore, a class attribute ``split_mode`` (`str`) has to be
-        defined and added to :py:class:`pysegcnn.core.split.SupportedSplits`.
-
-    See :py:class:`pysegcnn.core.split.RandomTileSplit` for an example.
-
-    Attributes
-    ----------
-    ds : :py:class:`pysegcnn.core.dataset.ImageDataset`
-        The dataset to split into training, validation and test set.
-
-    """
-
-    def __init__(self, ds):
-        """Initialize.
-
-        Parameters
-        ----------
-        ds : :py:class:`pysegcnn.core.dataset.ImageDataset`
-            An instance of :py:class:`pysegcnn.core.dataset.ImageDataset`.
-
-        """
-        # the dataset to split
+        # instance attributes
         self.ds = ds
+        self.k_folds = k_folds
+        self.seed = seed
+        self.shuffle = shuffle
 
-    def split(self):
-        """Split dataset into training, validation and test set.
-
-        :py:meth:`~pysegcnn.core.split.Split.split` works only if
-        :py:meth:`~pysegcnn.core.split.Split.subsets` and
-        :py:meth:`~pysegcnn.core.split.Split.subset_type` are implemented.
-
-        """
-        # build the subsets
-        ds_split = []
-        for name, sub in self.subsets().items():
-
-            # the scene identifiers of the current subset: preserve the order
-            # of the scene identifiers
-            ids, idx = np.unique([s['id'] for s in sub.values()],
-                                 return_index=True)
-            ids = ids[np.argsort(idx)]
-
-            # build the subset
-            sbst = self.subset_type()(self.ds, self.split_mode,
-                                      list(sub.keys()), name,
-                                      list(sub.values()), ids)
-            ds_split.append(sbst)
-
-        return ds_split
-
-    def subsets(self):
-        """Define training, validation and test sets.
+        # instance attributes: training/validation/test split ratios
+        # used if kfolds=1
+        self.tvratio = tvratio
+        self.ttratio = ttratio
 
-        Wrapper method for
-        :py:func:`pysegcnn.core.split.Split.random_tile_split`,
-        :py:func:`pysegcnn.core.split.Split.random_scene_split` or
-        :py:func:`pysegcnn.core.split.Split.date_scene_split`.
+    def generate_splits(self):
 
-        Raises
-        ------
-        NotImplementedError
-            Raised if :py:class:`pysegcnn.core.split.Split` is not inherited.
+        # check whether to generate a single or multiple folds
+        if self.k_folds > 1:
+            # k-fold split
+            indices = kfold_split(
+                self.indices_to_split, self.k_folds, self.seed, self.shuffle)
+        else:
+            # single-fold split
+            indices = random_split(
+                self.indices_to_split, self.tvratio, self.ttratio, self.seed,
+                self.shuffle)
 
-        Returns
-        -------
-        None.
+        return indices
 
-        """
+    @property
+    def indices_to_split(self):
         raise NotImplementedError
 
-    def subset_type(self):
-        """Define the type of each subset.
-
-        Wrapper method for :py:class:`pysegcnn.core.split.RandomSubset` or
-        :py:class:`pysegcnn.core.split.SceneSubset`.
-
-        Raises
-        ------
-        NotImplementedError
-            Raised if :py:class:`pysegcnn.core.split.Split` is not inherited.
-
-        Returns
-        -------
-        None.
-
-        """
+    @property
+    def indices(self):
         raise NotImplementedError
 
+    def split(self):
 
-class DateSplit(Split):
-    """Split a dataset based on a date.
-
-    .. important::
-
-        Scenes before ``date`` build the training dataset, scenes after
-        ``date`` the validation dataset. The test set is empty.
-
-    Useful for time series data.
-
-    Class wrapper for :py:func:`pysegcnn.core.split.date_scene_split`.
-
-    Attributes
-    ----------
-    split_mode : `str`
-        The mode to split the dataset, i.e. `'date'`.
-    ds : :py:class:`pysegcnn.core.dataset.ImageDataset`
-        The dataset to split into training, validation and test set.
-    date : `str`
-        The date used to split the dataset.
-    dateformat : `str`
-        The format of ``date``.
-
-    """
-
-    # the split mode
-    split_mode = 'date'
-
-    def __init__(self, ds, date, dateformat):
-        """Initialize.
-
-        Parameters
-        ----------
-        ds : :py:class:`pysegcnn.core.dataset.ImageDataset`
-            An instance of :py:class:`pysegcnn.core.dataset.ImageDataset`.
-        date : `str`
-            A date in the format ``dateformat``.
-        dateformat : `str`
-            The format of ``date``. ``dateformat`` is used by
-            :py:func:`datetime.datetime.strptime' to parse ``date`` to a
-            :py:class:`datetime.datetime` object.
-
-        """
-        super().__init__(ds)
-
-        # the date to split the dataset
-        # before: training set
-        # after : validation set
-        self.date = date
-
-        # the format of the date
-        self.dateformat = dateformat
-
-    def subsets(self):
-        """Wrap :py:func:`pysegcnn.core.split.Split.date_scene_split`.
-
-        Returns
-        -------
-        subsets : `dict`
-            Subset dictionary with keys:
-                ``'train'``
-                    The training scenes (`dict`).
-                ``'valid'``
-                    The validation scenes (`dict`).
-                ``'test'``
-                    The test scenes, empty (`dict`).
-
-        """
-        return date_scene_split(self.ds, self.date, self.dateformat)
-
-    def subset_type(self):
-        """Wrap :py:class:`pysegcnn.core.split.SceneSubset`.
-
-        Returns
-        -------
-        SceneSubset : :py:class:`pysegcnn.core.split.SceneSubset`
-            The subset type.
-
-        """
-        return SceneSubset
-
-
-class RandomSplit(Split):
-    """Generic class for random dataset splits."""
-
-    def __init__(self, ds, ttratio, tvratio, seed):
-        """Initialize.
-
-        Parameters
-        ----------
-        ds : :py:class:`pysegcnn.core.dataset.ImageDataset`
-            An instance of :py:class:`pysegcnn.core.dataset.ImageDataset`.
-        tvratio : `float`
-            The ratio of training data to validation data, e.g.
-            ``tvratio=0.8`` means 80% training, 20% validation.
-        ttratio : `float`
-            The ratio of training and validation data to test data, e.g.
-            ``ttratio=0.6`` means 60% for training and validation, 40% for
-            testing.
-        seed : `int`
-            The random seed used to generate the split. Useful for
-            reproducibility.
-
-        """
-        super().__init__(ds)
+        # initialize training, validation and test subsets
+        subsets = []
 
-        # the training, validation and test set ratios
-        self.ttratio = ttratio
-        self.tvratio = tvratio
+        # the training, validation and test indices
+        for folds in self.indices:
+            subsets.append(
+                index_dict([Subset(self.ds, ids) for ids in folds.values()]))
 
-        # the random seed: useful for reproducibility
-        self.seed = seed
+        return subsets
 
 
 class RandomTileSplit(RandomSplit):
-    """Randomly split the dataset.
+    """Split a :py:class:`pysegcnn.core.dataset.ImageDataset` into tiles."""
 
-    .. important::
+    def __init__(self, ds, k_folds, seed=0, shuffle=True, tvratio=0.8,
+                 ttratio=1):
+        # initialize super class
+        super().__init__(ds, k_folds, seed, shuffle, tvratio, ttratio)
 
-        For each scene, the tiles of the scene can be distributed among the
-        training, validation and test set.
+    @property
+    def indices_to_split(self):
+        return np.arange(len(self.ds))
 
-    Class wrapper for :py:func:`pysegcnn.core.split.random_tile_split`.
-
-    Attributes
-    ----------
-    split_mode : `str`
-        The mode to split the dataset, i.e. `'random'`.
-    ds : :py:class:`pysegcnn.core.dataset.ImageDataset`
-        The dataset to split into training, validation and test set.
-    tvratio : `float`
-        The ratio of training data to validation data.
-    ttratio : `float`
-        The ratio of training and validation data to test data.
-    seed : `int`
-        The random seed used to generate the split.
-
-    """
-
-    # the split mode
-    split_mode = 'random'
-
-    def __init__(self, ds, ttratio, tvratio, seed):
-        super().__init__(ds, ttratio, tvratio, seed)
-
-    def subsets(self):
-        """Wrap :py:func:`pysegcnn.core.split.Split.random_tile_split`.
-
-        Returns
-        -------
-        subsets : `dict`
-            Subset dictionary with keys:
-                ``'train'``
-                    The training scenes (`dict`).
-                ``'valid'``
-                    The validation scenes (`dict`).
-                ``'test'``
-                    The test scenes (`dict`).
-
-        """
-        return random_tile_split(self.ds, self.tvratio, self.ttratio,
-                                 self.seed)
-
-    def subset_type(self):
-        """Wrap :py:class:`pysegcnn.core.split.RandomSubset`.
-
-        Returns
-        -------
-        SceneSubset : :py:class:`pysegcnn.core.split.RandomSubset`
-            The subset type.
-
-        """
-        return RandomSubset
+    @property
+    def indices(self):
+        return self.generate_splits()
 
 
 class RandomSceneSplit(RandomSplit):
-    """Semi-randomly split the dataset.
-
-    .. important::
-
-        For each scene, all the tiles of the scene are included in either the
-        training, validation or test set, respectively.
-
-    Class wrapper for :py:func:`pysegcnn.core.split.random_scene_split`.
-
-    Attributes
-    ----------
-    split_mode : `str`
-        The mode to split the dataset, i.e. `'scene'`.
-    ds : :py:class:`pysegcnn.core.dataset.ImageDataset`
-        The dataset to split into training, validation and test set.
-    tvratio : `float`
-        The ratio of training data to validation data.
-    ttratio : `float`
-        The ratio of training and validation data to test data.
-    seed : `int`
-        The random seed used to generate the split.
+    """Split a :py:class:`pysegcnn.core.dataset.ImageDataset` into scenes."""
 
-    """
-
-    # the split mode
-    split_mode = 'scene'
-
-    def __init__(self, ds, ttratio, tvratio, seed):
-        super().__init__(ds, ttratio, tvratio, seed)
+    def __init__(self, ds, k_folds, seed=0, shuffle=True, tvratio=0.8,
+                 ttratio=1):
+        # initialize super class
+        super().__init__(ds, k_folds, seed, shuffle, tvratio, ttratio)
 
-    def subsets(self):
-        """Wrap :py:func:`pysegcnn.core.split.Split.random_scene_split`.
+        # the number of the scenes in the dataset
+        self.scenes = np.array([v['scene'] for v in self.ds.scenes])
 
-        Returns
-        -------
-        subsets : `dict`
-            Subset dictionary with keys:
-                ``'train'``
-                    The training scenes (`dict`).
-                ``'valid'``
-                    The validation scenes (`dict`).
-                ``'test'``
-                    The test scenes (`dict`).
+    @property
+    def indices_to_split(self):
+        return np.unique(self.scenes)
 
-        """
-        return random_scene_split(self.ds, self.tvratio, self.ttratio,
-                                  self.seed)
-
-    def subset_type(self):
-        """Wrap :py:class:`pysegcnn.core.split.SceneSubset`.
+    @property
+    def indices(self):
+        # indices of the different scene identifiers
+        indices = self.generate_splits()
 
-        Returns
-        -------
-        SceneSubset : :py:class:`pysegcnn.core.split.SceneSubset`
-            The subset type.
+        # iterate over the different folds
+        scene_indices = []
+        for folds in indices:
+            # iterate over the training, validation and test set
+            subset = {}
+            for name, ids in folds.items():
+                subset[name] = np.where(np.isin(self.scenes, ids))[0]
+            scene_indices.append(subset)
 
-        """
-        return SceneSubset
+        return scene_indices
 
 
 class SupportedSplits(enum.Enum):
     """Names and corresponding classes of the implemented split modes."""
 
-    random = RandomTileSplit
+    tile = RandomTileSplit
     scene = RandomSceneSplit
-    date = DateSplit
-- 
GitLab