diff --git a/pysegcnn/core/split.py b/pysegcnn/core/split.py
index 412e7cd24a6362876cfcb7887528d8a2051055e6..78f653a3b41c2d0f552c2485e026d8f604487060 100644
--- a/pysegcnn/core/split.py
+++ b/pysegcnn/core/split.py
@@ -11,6 +11,9 @@ import datetime
 import numpy as np
 from torch.utils.data.dataset import Subset
 
+# the names of the subsets
+SUBSET_NAMES = ['train', 'valid', 'test']
+
 
 # function calculating number of samples in a dataset given a ratio
 def _ds_len(ds, ratio):
@@ -37,29 +40,26 @@ def random_tile_split(ds, tvratio, ttratio=1, seed=0):
     # length of the training dataset
     # number of samples: (ttratio * tvratio * len(ds))
     train_len = _ds_len(trav_indices, tvratio)
-    train_indices = trav_indices[:train_len]
+    train_ind = trav_indices[:train_len]
 
     # length of the validation dataset
     # number of samples: (ttratio * (1- tvratio) * len(ds))
-    valid_indices = trav_indices[train_len:]
+    valid_ind = trav_indices[train_len:]
 
     # length of the test dataset
     # number of samples: ((1 - ttratio) * len(ds))
-    test_indices = indices[trav_len:]
+    test_ind = indices[trav_len:]
 
     # get the tiles of the scenes of each dataset
-    subsets = []
-    for dataset in [train_indices, valid_indices, test_indices]:
-
-        # build the subset: store the scenes
-        sbst = Subset(dataset=ds, indices=list(dataset))
-        sbst.scenes = [ds.scenes[i] for i in dataset]
+    subsets = {}
+    for name, dataset in enumerate([train_ind, valid_ind, test_ind]):
 
-        # add to list of subsets
-        subsets.append(sbst)
+        # store the indices and corresponding tiles of the current subset to
+        # dictionary
+        subsets[SUBSET_NAMES[name]] = {k: ds.scenes[k] for k in dataset}
 
     # check if the splits are disjoint
-    assert pairwise_disjoint([s.indices for s in subsets])
+    assert pairwise_disjoint([s.keys() for s in subsets.values()])
 
     return subsets
 
@@ -95,28 +95,16 @@ def random_scene_split(ds, tvratio, ttratio=1, seed=0):
     test_scenes = scene_ids[trav_len:]
 
     # get the tiles of the scenes of each dataset
-    subsets = []
-    for dataset in [train_scenes, valid_scenes, test_scenes]:
-        # the indices of the scenes in the dataset
-        indices = []
-        tiles = []
-
-        # iterate over the scenes of the whole dataset
-        for i, scene in enumerate(ds.scenes):
-            if scene['id'] in dataset:
-                indices.append(i)
-                tiles.append(scene)
-
-        # build the subset: store scene ids
-        sbst = Subset(dataset=ds, indices=indices)
-        sbst.scenes = tiles
-        sbst.ids = dataset
-
-        # add to list of subsets
-        subsets.append(sbst)
+    subsets = {}
+    for name, dataset in enumerate([train_scenes, valid_scenes, test_scenes]):
+
+        # store the indices and corresponding tiles of the current subset to
+        # dictionary
+        subsets[SUBSET_NAMES[name]] = {k: v for k, v in enumerate(ds.scenes)
+                                       if v['id'] in dataset}
 
     # check if the splits are disjoint
-    assert pairwise_disjoint([s.indices for s in subsets])
+    assert pairwise_disjoint([s.keys() for s in subsets.values()])
 
     return subsets
 
@@ -135,18 +123,16 @@ def date_scene_split(ds, date, dateformat='%Y%m%d'):
     test_scenes = {}
 
     # build the training and test datasets
-    subsets = []
-    for scenes in [train_scenes, valid_scenes, test_scenes]:
-        # build the subset: store the scenes
-        sbst = Subset(dataset=ds, indices=list(scenes.keys()))
-        sbst.scenes = list(scenes.values())
-        sbst.ids = np.unique([s['id'] for s in scenes.values()])
+    subsets = {}
+    for name, scenes in enumerate([train_scenes, valid_scenes, test_scenes]):
 
-        # add to list of subsets
-        subsets.append(sbst)
+        # store the indices and corresponding tiles of the current subset to
+        # dictionary
+        subsets[SUBSET_NAMES[name]] = scenes
+        # sbst.ids = np.unique([s['id'] for s in scenes.values()])
 
     # check if the splits are disjoint
-    assert pairwise_disjoint([s.indices for s in subsets])
+    assert pairwise_disjoint([s.keys() for s in subsets.values()])
 
     return subsets
 
@@ -155,3 +141,103 @@ def pairwise_disjoint(sets):
     union = set().union(*sets)
     n = sum(len(u) for u in sets)
     return n == len(union)
+
+
+class Split(object):
+
+    # the valid modes
+    valid_modes = ['random', 'scene', 'date']
+
+    def __init__(self, ds, mode, **kwargs):
+
+        # check which mode is provided
+        if mode not in self.valid_modes:
+            raise ValueError('{} is not supported. Valid modes are {}, see '
+                             'pysegcnn.main.config.py for a description of '
+                             'each mode.'.format(mode, self.valid_modes))
+        self.mode = mode
+
+        # the dataset to split
+        self.ds = ds
+
+        # the keyword arguments
+        self.kwargs = kwargs
+
+        # initialize split
+        self._init_split()
+
+    def _init_split(self):
+
+        if self.mode == 'random':
+            self.subset = RandomSubset
+            self.split_function = random_tile_split
+            self.allowed_kwargs = ['tvratio', 'ttratio', 'seed']
+
+        if self.mode == 'scene':
+            self.subset = SceneSubset
+            self.split_function = random_scene_split
+            self.allowed_kwargs = ['tvratio', 'ttratio', 'seed']
+
+        if self.mode == 'date':
+            self.subset = SceneSubset
+            self.split_function = date_scene_split
+            self.allowed_kwargs = ['date', 'dateformat']
+
+        self._check_kwargs()
+
+    def _check_kwargs(self):
+
+        # check if the correct keyword arguments are provided
+        if not set(self.allowed_kwargs).issubset(self.kwargs.keys()):
+            raise TypeError('__init__() expecting keyword arguments: {}.'
+                            .format(', '.join(kwa for kwa in
+                                              self.allowed_kwargs)))
+        # select the correct keyword arguments
+        self.kwargs = {k: self.kwargs[k] for k in self.allowed_kwargs}
+
+    # function apply the split
+    def split(self):
+
+        # create the subsets
+        subsets = self.split_function(self.ds, **self.kwargs)
+
+        # build the subsets
+        ds_split = []
+        for name, sub in subsets.items():
+
+            # the scene identifiers of the current subset
+            ids = np.unique([s['id'] for s in sub.values()])
+
+            # build the subset
+            subset = self.subset(self.ds, list(sub.keys()), name,
+                                 list(sub.values()), ids)
+            ds_split.append(subset)
+
+        return ds_split
+
+
+class SceneSubset(Subset):
+
+    def __init__(self, ds, indices, name, scenes, scene_ids):
+        super().__init__(dataset=ds, indices=indices)
+
+        # the name of the subset
+        self.name = name
+
+        # the scene in the subset
+        self.scenes = scenes
+
+        # the names of the scenes
+        self.ids = scene_ids
+
+
+class RandomSubset(Subset):
+
+    def __init__(self, ds, indices, name, scenes, scene_ids=None):
+        super().__init__(dataset=ds, indices=indices)
+
+        # the name of the subset
+        self.name = name
+
+        # the scene in the subset
+        self.scenes = scenes
diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index eba222cbe9541f2381bc56651a0528e3f032d847..b2535ff67626c601a3d133ced7c7700825d9f011 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -17,9 +17,8 @@ from torch.utils.data import DataLoader
 # locals
 from pysegcnn.core.dataset import SupportedDatasets
 from pysegcnn.core.layers import Conv2dSame
-from pysegcnn.core.utils import img2np
-from pysegcnn.core.split import (random_tile_split, random_scene_split,
-                                 date_scene_split)
+from pysegcnn.core.utils import img2np, accuracy_function
+from pysegcnn.core.split import Split
 
 
 class NetworkTrainer(object):
@@ -244,7 +243,7 @@ class NetworkTrainer(object):
 
         return training_state
 
-    def predict(self, pretrained=False, confusion=False):
+    def predict(self):
 
         print('------------------------ Predicting --------------------------')
 
@@ -341,18 +340,16 @@ class NetworkTrainer(object):
                              '\n'.join(name for name, _ in
                                        SupportedDatasets.__members__.items()))
 
-        # the training, validation and dataset
-        if self.split_mode == 'random':
-            self.train_ds, self.valid_ds, self.test_ds = random_tile_split(
-                self.dataset, self.tvratio, self.ttratio, self.seed)
-
-        if self.split_mode == 'scene':
-            self.train_ds, self.valid_ds, self.test_ds = random_scene_split(
-                self.dataset, self.tvratio, self.ttratio, self.seed)
+        # instantiate the Split class handling the dataset split
+        self.subset = Split(self.dataset, self.split_mode,
+                            tvratio=self.tvratio,
+                            ttratio=self.ttratio,
+                            seed=self.seed,
+                            date=self.date,
+                            dateformat=self.dateformat)
 
-        if self.split_mode == 'date':
-            self.train_ds, self.valid_ds, self.test_ds = date_scene_split(
-                self.dataset, self.date)
+        # the training, validation and test datasets
+        self.train_ds, self.valid_ds, self.test_ds = self.subset.split()
 
         # whether to drop training samples with a fraction of pixels equal to
         # the constant padding value self.cval >= self.drop
@@ -409,7 +406,7 @@ class NetworkTrainer(object):
         for pos, i in enumerate(ds.indices):
 
             # the current scene
-            s = self.dataset.scenes[i]
+            s = ds.dataset.scenes[i]
 
             # the current tile in the ground truth
             tile_gt = img2np(s['gt'], self.tile_size, s['tile'],
@@ -537,8 +534,3 @@ class EarlyStopping(object):
 
     def increased(self, metric, best, min_delta):
         return metric > best + min_delta
-
-
-# function calculating prediction accuracy
-def accuracy_function(outputs, labels):
-    return (np.asarray(outputs) == np.asarray(labels)).mean()