diff --git a/pysegcnn/core/dataset.py b/pysegcnn/core/dataset.py
index 9b2704c9f98476f575052448bcbe4797537e720f..a3371c0f222c6547699c50817cbd6d3d054dd522 100644
--- a/pysegcnn/core/dataset.py
+++ b/pysegcnn/core/dataset.py
@@ -19,14 +19,13 @@ import glob
 import enum
 import itertools
 
-
 # externals
 import numpy as np
 import torch
 from torch.utils.data import Dataset
 
 # locals
-from pysegcnn.core.constants import (Landsat8, Sentinel2, SparcsLabels,
+from pysegcnn.core.constants import (Landsat8, Sentinel2, Label, SparcsLabels,
                                      Cloud95Labels, ProSnowLabels)
 from pysegcnn.core.utils import (img2np, is_divisible, tile_topleft_corner,
                                  parse_landsat_scene, parse_sentinel2_scene)
@@ -35,6 +34,35 @@ from pysegcnn.core.utils import (img2np, is_divisible, tile_topleft_corner,
 # generic image dataset class
 class ImageDataset(Dataset):
 
+    # allowed keyword arguments and default values
+    default_kwargs = {
+
+        # which bands to use, if use_bands=[], use all available bands
+        'use_bands': [],
+
+        # each scene is divided into (tile_size x tile_size) blocks
+        # each of these blocks is treated as a single sample
+        'tile_size': None,
+
+        # a pattern to match the ground truth file naming convention
+        'gt_pattern': '*gt.tif',
+
+        # whether to chronologically sort the samples
+        'sort': False,
+
+        # the transformations to apply to the original image
+        # artificially increases the training data size
+        'transforms': [],
+
+        # whether to pad the image so it is evenly divisible into square tiles
+        # of size (tile_size x tile_size)
+        'pad': False,
+
+        # the value used to pad the samples
+        'cval': 0,
+
+        }
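+
+    # a minimal usage sketch; the subclass, path, and band names below are
+    # illustrative assumptions:
+    #
+    #     dataset = SparcsDataset('path/to/sparcs',
+    #                             use_bands=['red', 'green', 'blue'],
+    #                             tile_size=128, pad=True, cval=0)
+    #
+    # keyword arguments omitted in the call fall back to the defaults above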
+
     def __init__(self, root_dir, **kwargs):
         super().__init__()
 
@@ -64,48 +92,15 @@ class ImageDataset(Dataset):
 
     def _init_kwargs(self, **kwargs):
 
-        # define allowed keyword arguments
-        self.default_kwargs = {
-            # which bands to use, if use_bands=[], use all available bands
-            'use_bands': [],
-
-            # each scene is divided into (tile_size x tile_size) blocks
-            # each of these blocks is treated as a single sample
-            'tile_size': None,
-
-            # a pattern to match the ground truth file naming convention
-            'gt_pattern': '*gt.tif',
+        # check if the keyword arguments are correctly specified
+        if not set(kwargs.keys()).issubset(self.default_kwargs.keys()):
+            raise TypeError('Valid keyword arguments are: \n' +
+                            '\n'.join('- {}'.format(k) for k in
+                                      self.default_kwargs.keys()))
 
-            # whether to chronologically sort the samples
-            'sort': False,
-
-            # the transformations to apply to the original image
-            # artificially increases the training data size
-            'transforms': [],
-
-            # whether to pad the image to be evenly divisible in square tiles
-            # of size (tile_size x tile_size)
-            'pad': False,
-
-            # the value to pad the samples
-            'cval': 0,
-
-            }
-
-        # set default kwargs
+        # update the default values with the specified keyword arguments;
+        # merge into a copy to avoid mutating the class-level default_kwargs
+        self.default_kwargs = {**self.default_kwargs, **kwargs}
         for k, v in self.default_kwargs.items():
-            # store default keyword arguments as instance attributes
-            setattr(self, k, v)
-
-        # check whether the keyword arguments are correctly specified
-        for k, v in kwargs.items():
-            if k not in self.default_kwargs.keys():
-                raise TypeError('"{}" is not a valid keyword argument. '
-                                'Valid keyword arguments are: \n'.format(k) +
-                                '\n'.join('- {}'.format(k) for k in
-                                          self.default_kwargs.keys()))
-
-            # store keyword argument as instance attribute
             setattr(self, k, v)
 
         # check which bands to use
@@ -152,8 +147,8 @@ class ImageDataset(Dataset):
                 self.labels[self.cval] = {'label': 'No data', 'color': 'black'}
 
     def _build_labels(self):
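+        # builds the mapping {class id: {'label': name, 'color': color}} from
+        # the dataset's Label enumeration, e.g. (illustrative values only):
+        # {0: {'label': 'Cloud', 'color': 'white'}, ...}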
-        return {band.value[0]: {'label': band.name.replace('_', ' '),
-                                'color': band.value[1]}
+        return {band.id: {'label': band.name.replace('_', ' '),
+                          'color': band.color}
                 for band in self._label_class}
 
     def _assert_compose_scenes(self):
@@ -186,17 +181,18 @@ class ImageDataset(Dataset):
                             'enum.Enum, containing an enumeration of the '
                             'spectral bands of the sensor the dataset is '
                             'derived from. Examples can be found in '
-                            'pytorch.constants.py.'
+                            'pysegcnn.core.constants.py.'
                             .format(self.__class__.__name__))
 
     def _assert_get_labels(self):
-        if not isinstance(self._label_class, enum.EnumMeta):
+        if not issubclass(self._label_class, Label):
-            raise TypeError('{}.get_labels() should return an instance of '
+            raise TypeError('{}.get_labels() should return a subclass of '
-                            'enum.Enum, containing an enumeration of the '
+                            'pysegcnn.core.constants.Label, '
+                            'containing an enumeration of the '
-                            'class labels, together with the corresponing id '
+                            'class labels, together with the corresponding id '
                             'in the ground truth mask and a color for '
                             'visualization. Examples can be found in '
-                            'pytorch.constants.py.'
+                            'pysegcnn.core.constants.py.'
                             .format(self.__class__.__name__))
 
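+    # a minimal sketch of a Label subclass passing the check above
+    # (illustrative only; member names, ids and colors are assumptions, the
+    # actual label enumerations live in pysegcnn.core.constants):
+    #
+    #     class MyLabels(Label):
+    #         Cloud = 1, 'white'
+    #         Snow = 2, 'lightblue'
+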
     # the __len__() method returns the number of samples in the dataset
@@ -212,12 +208,12 @@ class ImageDataset(Dataset):
         # select a scene
         scene = self.read_scene(idx)
 
-        # get samples: (tiles x channels x height x width)
+        # get samples
+        # data: (tiles, bands, height, width)
+        # gt: (height, width)
         data, gt = self.build_samples(scene)
 
-        # preprocess input and return torch tensors of shape:
-        # x : (bands, height, width)
-        # y : (height, width)
+        # preprocess samples
         x, y = self.preprocess(data, gt)
 
         # optional transformation
@@ -253,14 +249,13 @@ class ImageDataset(Dataset):
         raise NotImplementedError('Inherit the ImageDataset class and '
                                   'implement the method.')
 
-    # the get_bands() method has to be implemented by the class inheriting
+    # the get_sensor() method has to be implemented by the class inheriting
     # the ImageDataset class
-    # get_bands() should return a dictionary with the following
-    # (key: int, value: str) pairs:
-    #    - (1, band_1_name)
-    #    - (2, band_2_name)
+    # get_sensor() should return an enum.Enum whose members are the
+    # (name: str, value: int) pairs of the sensor's spectral bands, e.g.:
+    #    - (red, 2)
+    #    - (green, 3)
     #    - ...
-    #    - (n, band_n_name)
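+    #
+    # a minimal sketch of such an enumeration (illustrative only; band names
+    # and numbers are assumptions, the actual sensor enumerations, e.g.
+    # Landsat8 and Sentinel2, live in pysegcnn.core.constants):
+    #
+    #     class MySensor(enum.Enum):
+    #         blue = 1
+    #         green = 2
+    #         red = 3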
     def get_sensor(self, *args, **kwargs):
         raise NotImplementedError('Inherit the ImageDataset class and '
                                   'implement the method.')
@@ -335,6 +330,40 @@ class ImageDataset(Dataset):
         return (torch.tensor(x.copy(), dtype=torch.float32),
                 torch.tensor(y.copy(), dtype=torch.uint8))
 
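+    # __repr__() summarizes the dataset configuration; the output looks
+    # roughly like this (values are illustrative only):
+    #
+    #     SparcsDataset(
+    #         (sensor):
+    #             - Landsat8
+    #         (bands):
+    #             - Band 0: red
+    #         (scene):
+    #             - size (h, w): (1000, 1000)
+    #         ...
+    #     )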
+    def __repr__(self):
+
+        # representation string to print
+        fs = self.__class__.__name__ + '(\n'
+
+        # sensor
+        fs += '    (sensor):\n        - ' + self.sensor.__name__
+
+        # bands used for the segmentation
+        fs += '\n    (bands):\n        '
+        fs += '\n        '.join('- Band {}: {}'.format(i, b) for i, b in
+                                enumerate(self.use_bands))
+
+        # scenes
+        fs += '\n    (scene):\n        '
+        fs += '- size (h, w): {}\n        '.format((self.height, self.width))
+        fs += '- number of scenes: {}\n        '.format(
+            len(np.unique([f['id'] for f in self.scenes])))
+        fs += '- padding (bottom, left, top, right): {}'.format(self.padding)
+
+        # tiles
+        fs += '\n    (tiles):\n        '
+        fs += '- number of tiles per scene: {}\n        '.format(self.tiles)
+        fs += '- tile size: {}\n        '.format((self.tile_size,
+                                                  self.tile_size))
+        fs += '- number of tiles: {}'.format(len(self.scenes))
+
+        # classes of interest
+        fs += '\n    (classes):\n        '
+        fs += '\n        '.join('- Class {}: {}'.format(k, v['label']) for
+                                k, v in self.labels.items())
+        fs += '\n)'
+        return fs
 
 
 class StandardEoDataset(ImageDataset):