diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index b2535ff67626c601a3d133ced7c7700825d9f011..96bb55e7ebabae189aaf06ee9215b84fdf079162 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -18,7 +18,9 @@ from torch.utils.data import DataLoader
 from pysegcnn.core.dataset import SupportedDatasets
 from pysegcnn.core.layers import Conv2dSame
 from pysegcnn.core.utils import img2np, accuracy_function
-from pysegcnn.core.split import Split
+from pysegcnn.core.split import (RandomTileSplit, RandomSceneSplit, DateSplit,
+                                 VALID_SPLIT_MODES)
+
 
 
 class NetworkTrainer(object):
@@ -340,13 +342,27 @@ class NetworkTrainer(object):
                              '\n'.join(name for name, _ in
                                        SupportedDatasets.__members__.items()))
 
-        # instanciate the Split class handling the dataset split
-        self.subset = Split(self.dataset, self.split_mode,
-                            tvratio=self.tvratio,
-                            ttratio=self.ttratio,
-                            seed=self.seed,
-                            date=self.date,
-                            dateformat=self.dateformat)
+
+        # the mode to split
+        if self.split_mode not in VALID_SPLIT_MODES:
+            raise ValueError('{} is not supported. Valid modes are {}, see '
+                              'pysegcnn.main.config.py for a description of '
+                              'each mode.'.format(self.split_mode,
+                                                  VALID_SPLIT_MODES))
+        if self.split_mode == 'random':
+            self.subset = RandomTileSplit(self.dataset,
+                                          self.ttratio,
+                                          self.tvratio,
+                                          self.seed)
+        if self.split_mode == 'scene':
+            self.subset = RandomSceneSplit(self.dataset,
+                                           self.ttratio,
+                                           self.tvratio,
+                                           self.seed)
+        if self.split_mode == 'date':
+            self.subset = DateSplit(self.dataset,
+                                    self.date,
+                                    self.dateformat)
 
         # the training, validation and dataset
         self.train_ds, self.valid_ds, self.test_ds = self.subset.split()
@@ -442,16 +458,10 @@ class NetworkTrainer(object):
 
         # representation string to print
         fs = self.__class__.__name__ + '(\n'
-        fs += '    (bands):\n        '
-
-        # bands used for the segmentation
-        fs += '\n        '.join('- Band {}: {}'.format(i, b) for i, b in
-                                enumerate(self.dataset.use_bands))
 
-        # classes of interest
-        fs += '\n    (classes):\n        '
-        fs += '\n        '.join('- Class {}: {}'.format(k, v['label']) for
-                                k, v in self.dataset.labels.items())
+        # dataset
+        fs += '    (dataset):\n        '
+        fs += ''.join(self.dataset.__repr__()).replace('\n', '\n        ')
 
         # batch size
         fs += '\n    (batch):\n        '
@@ -459,19 +469,18 @@ class NetworkTrainer(object):
         fs += '- batch shape (b, h, w): {}'.format(self.batch_shape)
 
         # dataset split
-        fs += '\n    (dataset):\n        '
-        fs += '\n        '.join(
-            '- {}: {:d} batches ({:.2f}%)'
-            .format(k, v[0], v[1] * 100) for k, v in
-            {'Training': (len(self.train_ds), self.ttratio * self.tvratio),
-             'Validation': (len(self.valid_ds),
-                            self.ttratio * (1 - self.tvratio)),
-             'Test': (len(self.test_ds), 1 - self.ttratio)}.items())
+        fs += '\n    (split):\n        '
+        fs += ''.join(self.subset.__repr__()).replace('\n', '\n        ')
 
         # model
         fs += '\n    (model):\n        '
         fs += ''.join(self.model.__repr__()).replace('\n', '\n        ')
+
+        # optimizer
+        fs += '\n    (optimizer):\n        '
+        fs += ''.join(self.optimizer.__repr__()).replace('\n', '\n        ')
         fs += '\n)'
+
         return fs