diff --git a/pysegcnn/core/dataset.py b/pysegcnn/core/dataset.py
index 98a24283adfbbfc3d6292f028f2bdf7a4b348471..d09b713630f5520ad3dd9b4c73cf796fe51af08c 100644
--- a/pysegcnn/core/dataset.py
+++ b/pysegcnn/core/dataset.py
@@ -35,45 +35,76 @@ from pysegcnn.core.utils import (img2np, is_divisible, tile_topleft_corner,
 LOGGER = logging.getLogger(__name__)
 
 
-# generic image dataset class
 class ImageDataset(Dataset):
-
-    # allowed keyword arguments and default values
-    default_kwargs = {
-
-        # which bands to use, if use_bands=[], use all available bands
-        'use_bands': [],
-
-        # each scene is divided into (tile_size x tile_size) blocks
-        # each of these blocks is treated as a single sample
-        'tile_size': None,
-
-        # a pattern to match the ground truth file naming convention
-        'gt_pattern': '*gt.tif',
-
-        # whether to chronologically sort the samples
-        'sort': False,
-
-        # the random seed used to split dataset into training, validation and
-        # test data
-        'seed': 0,
-
-        # the transformations to apply to the original image
-        # artificially increases the training data size
-        'transforms': [],
-
-        # whether to pad the image to be evenly divisible in square tiles
-        # of size (tile_size x tile_size)
-        'pad': False
-
-        }
-
-    def __init__(self, root_dir, **kwargs):
+    """Base class for multispectral image data.
+
+    Inheriting from `torch.utils.data.Dataset` to be compliant to the PyTorch
+    standard. Furthermore, using instances of `torch.utils.data.Dataset`
+    enables the use of the handy `torch.utils.data.DataLoader` class during
+    model training.
+
+    Parameters
+    ----------
+    root_dir : `str`
+        The root directory, path to the dataset.
+    use_bands : `list` [`str`], optional
+        A list of the spectral bands to use. The default is [].
+    tile_size : `int` or `None`, optional
+        The size of the tiles. If not `None`, each scene is divided into square
+        tiles of shape (tile_size, tile_size). The default is None.
+    pad : `bool`, optional
+        Whether to center pad the input image. Set ``pad`` = True, if the
+        images are not evenly divisible by the ``tile_size``. The image data is
+        padded with a constant padding value of zero. For each image, the
+        corresponding ground truth image is padded with a "no data" label.
+        The default is False.
+    gt_pattern : `str`, optional
+        A pattern to match the ground truth naming convention. All directories
+        and subdirectories in ``root_dir`` are searched for files matching
+        ``gt_pattern``. The default is '*gt.tif'.
+    sort : `bool`, optional
+        Whether to chronologically sort the samples. Useful for time series
+        data. The default is False.
+    seed : `int`, optional
+        The random seed. Used to split the dataset into training, validation
+        and test set. Useful for reproducibility. The default is 0.
+    transforms : `list` [`pysegcnn.core.split.Augment`], optional
+        List of `pysegcnn.core.split.Augment` instances. Each item in
+        ``transforms`` generates a distinct transformed version of the dataset.
+        The total dataset is composed of the original untransformed dataset
+        together with each transformed version of it.
+        If ``transforms`` = [], only the original dataset is used.
+        The default is [].
+
+    Returns
+    -------
+    None.
+
+    """
+
+    def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
+                 gt_pattern='*gt.tif', sort=False, seed=0, transforms=[]):
         super().__init__()
 
-        # the root directory: path to the image dataset
+        # dataset configuration
         self.root = root_dir
+        self.use_bands = use_bands
+        self.tile_size = tile_size
+        self.pad = pad
+        self.gt_pattern = gt_pattern
+        self.sort = sort
+        self.seed = seed
+        self.transforms = transforms
+
+        # initialize instance attributes
+        self._init_attributes()
 
+        # the samples of the dataset
+        self.scenes = self.compose_scenes()
+        self._assert_compose_scenes()
+
+    def _init_attributes(self):
+        """Initialize the class instance attributes."""
         # the size of a scene/patch in the dataset
         self.size = self.get_size()
         self._assert_get_size()
@@ -88,26 +119,6 @@ class ImageDataset(Dataset):
         self._assert_get_labels()
         self.labels = self._build_labels()
 
-        # initialize keyword arguments
-        self._init_kwargs(**kwargs)
-
-        # the samples of the dataset
-        self.scenes = self.compose_scenes()
-        self._assert_compose_scenes()
-
-    def _init_kwargs(self, **kwargs):
-
-        # check if the keyword arguments are correctly specified
-        if not set(kwargs.keys()).issubset(self.default_kwargs.keys()):
-            raise TypeError('Valid keyword arguments are: \n' +
-                            '\n'.join('- {}'.format(k) for k in
-                                      self.default_kwargs.keys()))
-
-        # update default arguments with specified keyword argument values
-        self.default_kwargs.update(kwargs)
-        for k, v in self.default_kwargs.items():
-            setattr(self, k, v)
-
         # check which bands to use
         self.use_bands = (self.use_bands if self.use_bands else
                           [*self.bands.values()])
@@ -140,12 +151,25 @@ class ImageDataset(Dataset):
                         .format(self.cval))
 
     def _build_labels(self):
+        """Build the label dictionary.
+
+        Returns
+        -------
+        labels : `dict` [`int`, `dict`]
+            The label dictionary. The keys are the values of the class labels
+            in the ground truth ``y``. Each nested `dict` should have keys:
+                ``'color'``
+                    A named color (`str`).
+                ``'label'``
+                    The name of the class label (`str`).
+
+        """
         return {band.id: {'label': band.name.replace('_', ' '),
                           'color': band.color}
                 for band in self._label_class}
 
     def _assert_compose_scenes(self):
-
+        """Check whether compose_scenes() is correctly implemented."""
         # list of required keys
         self.keys = self.use_bands + ['gt', 'date', 'tile', 'transform', 'id']
 
@@ -163,12 +187,14 @@ class ImageDataset(Dataset):
                                .format(self.keys))
 
     def _assert_get_size(self):
+        """Check whether get_size() is correctly implemented."""
         if not isinstance(self.size, tuple) and len(self.size) == 2:
             raise TypeError('{}.get_size() should return the spatial size of '
                             'an image sample as tuple, i.e. (height, width).'
                             .format(self.__class__.__name__))
 
     def _assert_get_sensor(self):
+        """Check whether get_sensor() is correctly implemented."""
         if not isinstance(self.sensor, enum.EnumMeta):
             raise TypeError('{}.get_sensor() should return an instance of '
                             'enum.Enum, containing an enumeration of the '
@@ -178,6 +204,7 @@ class ImageDataset(Dataset):
                             .format(self.__class__.__name__))
 
     def _assert_get_labels(self):
+        """Check whether get_labels() is correctly implemented."""
         if not issubclass(self._label_class, Label):
             raise TypeError('{}.get_labels() should return an instance of '
                             'pysegcnn.core.constants.Label, '
@@ -188,16 +215,32 @@ class ImageDataset(Dataset):
                             'pysegcnn.core.constants.py.'
                             .format(self.__class__.__name__))
 
-    # the __len__() method returns the number of samples in the dataset
     def __len__(self):
-        # number of (tiles x channels x height x width) patches after each
-        # scene is decomposed to tiles blocks
+        """Return the number of samples in the dataset.
+
+        Returns
+        -------
+        nsamples : `int`
+            The number of samples in the dataset.
+        """
         return len(self.scenes)
 
-    # the __getitem__() method returns a single sample of the dataset given an
-    # index, i.e. an array/tensor of shape (channels x height x width)
     def __getitem__(self, idx):
+        """Return the data of a sample of the dataset given an index ``idx``.
+
+        Parameters
+        ----------
+        idx : `int`
+            The index of the sample.
+
+        Returns
+        -------
+        x : `torch.Tensor`
+            The sample input data.
+        y : `torch.Tensor`
+            The sample ground truth.
 
+        """
         # select a scene
         scene = self.read_scene(idx)
 
@@ -209,88 +252,191 @@ class ImageDataset(Dataset):
         # preprocess samples
         x, y = self.preprocess(data, gt)
 
-        # optional transformation
+        # apply transformation
         if scene['transform'] is not None:
             x, y = scene['transform'](x, y)
 
         # convert to torch tensors
-        x, y = self.to_tensor(x, y)
+        x = self.to_tensor(x, dtype=torch.float32)
+        y = self.to_tensor(y, dtype=torch.uint8)
 
         return x, y
 
-    # the compose_scenes() method has to be implemented by the class inheriting
-    # the ImageDataset class
-    # compose_scenes() should return a list of dictionaries, where each
-    # dictionary represent one sample of the dataset, a scene or a tile
-    # of a scene, etc.
-    # the dictionaries should have the following (key, value) pairs:
-    #   - (band_1, path_to_band_1.tif)
-    #   - (band_2, path_to_band_2.tif)
-    #   - ...
-    #   - (band_n, path_to_band_n.tif)
-    #   - (gt, path_to_ground_truth.tif)
-    #   - (tile, None or int)
-    #   - ...
-    def compose_scenes(self, *args, **kwargs):
+    def compose_scenes(self):
+        """Build the list of samples of the dataset.
+
+        Each sample is represented by a dictionary.
+
+        Raises
+        ------
+        NotImplementedError
+            Raised if the `pysegcnn.core.dataset.ImageDataset` class is not
+            inherited.
+
+        Returns
+        -------
+        samples : `list` [`dict`]
+            Each dictionary representing a sample should have keys:
+                ``'band_name_1'``
+                    Path to the file of band_1.
+                ``'band_name_2'``
+                    Path to the file of band_2.
+                ``'band_name_n'``
+                    Path to the file of band_n.
+                ``'gt'``
+                    Path to the ground truth file.
+                ``'date'``
+                    The date of the sample.
+                ``'tile'``
+                    The tile id of the sample.
+                ``'transform'``
+                    The transformation to apply.
+                ``'id'``
+                    The scene identifier.
+
+        """
         raise NotImplementedError('Inherit the ImageDataset class and '
                                   'implement the method.')
 
-    # the get_size() method has to be implemented by the class inheriting
-    # the ImageDataset class
-    # get_size() method should return the image size as tuple, (height, width)
-    def get_size(self, *args, **kwargs):
+    def get_size(self):
+        """Return the size of the images in the dataset.
+
+        Raises
+        ------
+        NotImplementedError
+            Raised if the `pysegcnn.core.dataset.ImageDataset` class is not
+            inherited.
+
+        Returns
+        -------
+        size : `tuple`
+            The image size (height, width).
+
+        """
         raise NotImplementedError('Inherit the ImageDataset class and '
                                   'implement the method.')
 
-    # the get_sensor() method has to be implemented by the class inheriting
-    # the ImageDataset class
-    # get_sensor() should return an enum.Enum with the following
-    # (name: str, value: int) tuples:
-    #    - (red, 2)
-    #    - (green, 3)
-    #    - ...
-    def get_sensor(self, *args, **kwargs):
+    def get_sensor(self):
+        """Return an enumeration of the bands of the sensor of the dataset.
+
+        Examples can be found in `pysegcnn.core.constants`.
+
+        Raises
+        ------
+        NotImplementedError
+            Raised if the `pysegcnn.core.dataset.ImageDataset` class is not
+            inherited.
+
+        Returns
+        -------
+        sensor : `enum.Enum`
+            An enumeration of the bands of the sensor.
+
+        """
         raise NotImplementedError('Inherit the ImageDataset class and '
                                   'implement the method.')
 
-    # the get_labels() method has to be implemented by the class inheriting
-    # the ImageDataset class
-    # get_labels() should return a dictionary with the following
-    # (key: int, value: str) pairs:
-    #    - (0, label_1_name)
-    #    - (1, label_2_name)
-    #    - ...
-    #    - (n, label_n_name)
-    # where the keys should be the values representing the values of the
-    # corresponding label in the ground truth mask
-    # the labels in the dictionary determine the classes to be segmented
-    def get_labels(self, *args, **kwargs):
+    def get_labels(self):
+        """Return an enumeration of the class labels of the dataset.
+
+        Examples can be found in `pysegcnn.core.constants`.
+
+        Raises
+        ------
+        NotImplementedError
+            Raised if the `pysegcnn.core.dataset.ImageDataset` class is not
+            inherited.
+
+        Returns
+        -------
+        labels : `enum.Enum`
+            The class labels.
+
+        """
         raise NotImplementedError('Inherit the ImageDataset class and '
                                   'implement the method.')
 
-    # the preprocess() method has to be implemented by the class inheriting
-    # the ImageDataset class
-    # preprocess() should return two torch.tensors:
-    #    - input data: tensor of shape (bands, height, width)
-    #    - ground truth: tensor of shape (height, width)
     def preprocess(self, data, gt):
+        """Preprocess a sample before feeding it to a model.
+
+        Parameters
+        ----------
+        data : `numpy.ndarray`
+            The sample input data.
+        gt : `numpy.ndarray`
+            The sample ground truth.
+
+        Raises
+        ------
+        NotImplementedError
+            Raised if the `pysegcnn.core.dataset.ImageDataset` class is not
+            inherited.
+
+        Returns
+        -------
+        data : `numpy.ndarray`
+            The preprocessed input data.
+        gt : `numpy.ndarray`
+            The preprocessed ground truth data.
+
+        """
         raise NotImplementedError('Inherit the ImageDataset class and '
                                   'implement the method.')
 
-    # the parse_scene_id() method has to be implemented by the class inheriting
-    # the ImageDataset class
-    # the input to the parse_scene_id() method is a string describing a scene
-    # id, e.g. an id of a Landsat or a Sentinel scene
-    # parse_scene_id() should return a dictionary containing the scene metadata
-    def parse_scene_id(self, scene):
+    def parse_scene_id(self, scene_id):
+        """Parse the scene identifier.
+
+        Parameters
+        ----------
+        scene_id : `str`
+            A scene identifier.
+
+        Raises
+        ------
+        NotImplementedError
+            Raised if the `pysegcnn.core.dataset.ImageDataset` class is not
+            inherited.
+
+        Returns
+        -------
+        scene : `dict` or `None`
+            A dictionary containing scene metadata. If `None`, ``scene_id`` is
+            not a valid scene identifier.
+
+        """
         raise NotImplementedError('Inherit the ImageDataset class and '
                                   'implement the method.')
 
-    # _read_scene() reads all the bands and the ground truth mask in a
-    # scene/tile to a numpy array and returns a dictionary with
-    # (key, value) = ('band_name', np.ndarray(band_data))
     def read_scene(self, idx):
-
+        """Read the data of the sample with index ``idx``.
+
+        Parameters
+        ----------
+        idx : `int`
+            The index of the sample.
+
+        Returns
+        -------
+        scene_data : `dict`
+            The sample data dictionary with keys:
+                ``'band_name_1'``
+                    data of band_1 (`numpy.ndarray`).
+                ``'band_name_2'``
+                    data of band_2 (`numpy.ndarray`).
+                ``'band_name_n'``
+                    data of band_n (`numpy.ndarray`).
+                ``'gt'``
+                    data of the ground truth (`numpy.ndarray`).
+                ``'date'``
+                    The date of the sample.
+                ``'tile'``
+                    The tile id of the sample.
+                ``'transform'``
+                    The transformation to apply.
+                ``'id'``
+                    The scene identifier.
+
+        """
         # select a scene from the root directory
         scene = self.scenes[idx]
 
@@ -309,22 +455,71 @@ class ImageDataset(Dataset):
 
         return scene_data
 
-    # _build_samples() stacks all bands of a scene/tile into a
-    # numpy array of shape (bands x height x width)
     def build_samples(self, scene):
-
+        """Stack the bands of a sample in a single array.
+
+        Parameters
+        ----------
+        scene : `dict`
+            The sample data dictionary with keys:
+                ``'band_name_1'``
+                    data of band_1 (`numpy.ndarray`).
+                ``'band_name_2'``
+                    data of band_2 (`numpy.ndarray`).
+                ``'band_name_n'``
+                    data of band_n (`numpy.ndarray`).
+                ``'gt'``
+                    data of the ground truth (`numpy.ndarray`).
+                ``'date'``
+                    The date of the sample.
+                ``'tile'``
+                    The tile id of the sample.
+                ``'transform'``
+                    The transformation to apply.
+                ``'id'``
+                    The scene identifier.
+
+        Returns
+        -------
+        stack : `numpy.ndarray`
+            The input data of the sample.
+        gt : TYPE
+            The ground truth of the sample.
+
+        """
         # iterate over the channels to stack
         stack = np.stack([scene[band] for band in self.use_bands], axis=0)
         gt = scene['gt']
 
         return stack, gt
 
-    def to_tensor(self, x, y):
-        return (torch.tensor(x.copy(), dtype=torch.float32),
-                torch.tensor(y.copy(), dtype=torch.uint8))
+    def to_tensor(self, x, dtype):
+        """Convert ``x`` to `torch.Tensor`.
+
+        Parameters
+        ----------
+        x : array_like
+            The input data.
+        dtype : `torch.dtype`
+            The data type used to convert ``x``.
+
+        Returns
+        -------
+        x : `torch.Tensor`
+            The input data tensor.
+
+        """
+        return torch.tensor(np.asarray(x).copy(), dtype=dtype)
 
     def __repr__(self):
+        """Dataset representation.
+
+        Returns
+        -------
+        fs : `str`
+            Representation string.
 
+        """
         # representation string to print
         fs = self.__class__.__name__ + '(\n'
 
@@ -358,17 +553,106 @@ class ImageDataset(Dataset):
         return fs
 
 
-
 class StandardEoDataset(ImageDataset):
-
-    def __init__(self, root_dir, **kwargs):
+    """Base class for standard Earth Observation style datasets.
+
+    `pysegcnn.core.dataset.StandardEoDataset` implements the
+    `~pysegcnn.core.dataset.StandardEoDataset.compose_scenes` method for
+    datasets with the following directory structure:
+
+    root_dir/
+        scene_id_1/
+            scene_id_1_B1.tif
+            scene_id_1_B2.tif
+            .
+            .
+            .
+            scene_id_1_BN.tif
+        scene_id_2/
+            scene_id_2_B1.tif
+            scene_id_2_B2.tif
+            .
+            .
+            .
+            scene_id_2_BN.tif
+        .
+        .
+        .
+        scene_id_N/
+            .
+            .
+            .
+
+    If your dataset shares this directory structure, you can directly inherit
+    `pysegcnn.core.dataset.StandardEoDataset` and implement the remaining
+    methods.
+
+    See `pysegcnn.core.dataset.SparcsDataset` for an example.
+
+    Parameters
+    ----------
+    root_dir : `str`
+        The root directory, path to the dataset.
+    use_bands : `list` [`str`], optional
+        A list of the spectral bands to use. The default is [].
+    tile_size : `int` or `None`, optional
+        The size of the tiles. If not `None`, each scene is divided into square
+        tiles of shape (tile_size, tile_size). The default is None.
+    pad : `bool`, optional
+        Whether to center pad the input image. Set ``pad`` = True, if the
+        images are not evenly divisible by the ``tile_size``. The image data is
+        padded with a constant padding value of zero. For each image, the
+        corresponding ground truth image is padded with a "no data" label.
+        The default is False.
+    gt_pattern : `str`, optional
+        A pattern to match the ground truth naming convention. All directories
+        and subdirectories in ``root_dir`` are searched for files matching
+        ``gt_pattern``. The default is '*gt.tif'.
+    sort : `bool`, optional
+        Whether to chronologically sort the samples. Useful for time series
+        data. The default is False.
+    seed : `int`, optional
+        The random seed. Used to split the dataset into training, validation
+        and test set. Useful for reproducibility. The default is 0.
+    transforms : `list` [`pysegcnn.core.split.Augment`], optional
+        List of `pysegcnn.core.split.Augment` instances. Each item in
+        ``transforms`` generates a distinct transformed version of the dataset.
+        The total dataset is composed of the original untransformed dataset
+        together with each transformed version of it.
+        If ``transforms`` = [], only the original dataset is used.
+        The default is [].
+
+    Returns
+    -------
+    None.
+
+    """
+
+    def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
+                 gt_pattern='*gt.tif', sort=False, seed=0, transforms=[]):
         # initialize super class ImageDataset
-        super().__init__(root_dir, **kwargs)
+        super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
+                         sort, seed, transforms)
+
+    def _get_band_number(self, path):
+        """Return the band number of a scene .tif file.
+
+        Parameters
+        ----------
+        path : `str`
+            The path to the .tif file.
+
+        Raises
+        ------
+        ValueError
+            Raised if ``path`` is not a .tif file.
 
-    # returns the band number of a Landsat8 or Sentinel2 tif file
-    # path: path to a tif file
-    def get_band_number(self, path):
+        Returns
+        -------
+        band : `int` or `str`
+            The band number.
 
+        """
         # check whether the path leads to a tif file
         if not path.endswith(('tif', 'TIF')):
             raise ValueError('Expected a path to a tif file.')
@@ -377,7 +661,7 @@ class StandardEoDataset(ImageDataset):
         fname = os.path.basename(path)
 
         # search for numbers following a "B" in the filename
-        band = re.search('B\dA|B\d{1,2}', fname)[0].replace('B', '')
+        band = re.search('B\\dA|B\\d{1,2}', fname)[0].replace('B', '')
 
         # try converting to an integer:
         # raises a ValueError for Sentinel2 8A band
@@ -388,14 +672,34 @@ class StandardEoDataset(ImageDataset):
 
         return band
 
-    # store_bands() writes the paths to the data of each scene to a dictionary
-    # only the bands of interest are stored
-    def store_bands(self, bands, gt):
-
+    def _store_bands(self, bands, gt):
+        """Write the bands of interest to a dictionary.
+
+        Parameters
+        ----------
+        bands : `list` [`str`]
+            Paths to the .tif files of the bands of the scene.
+        gt : `str`
+            Path to the ground truth of the scene.
+
+        Returns
+        -------
+        scene_data : `dict`
+            The scene data dictionary with keys:
+                ``'band_name_1'``
+                    Path to the .tif file of band_1.
+                ``'band_name_2'``
+                    Path to the .tif file of band_2.
+                ``'band_name_n'``
+                    Path to the .tif file of band_n.
+                ``'gt'``
+                    Path to the ground truth file.
+
+        """
         # store the bands of interest in a dictionary
         scene_data = {}
         for i, b in enumerate(bands):
-            band = self.bands[self.get_band_number(b)]
+            band = self.bands[self._get_band_number(b)]
             if band in self.use_bands:
                 scene_data[band] = b
 
@@ -404,12 +708,33 @@ class StandardEoDataset(ImageDataset):
 
         return scene_data
 
-    # compose_scenes() creates a list of dictionaries containing the paths
-    # to the tif files of each scene
-    # if the scenes are divided into tiles, each tile has its own entry
-    # with corresponding tile id
     def compose_scenes(self):
-
+        """Build the list of samples of the dataset.
+
+        Each sample is represented by a dictionary.
+
+        Returns
+        -------
+        scenes : `list` [`dict`]
+            Each item in ``scenes`` is a `dict` with keys:
+                ``'band_name_1'``
+                    Path to the file of band_1.
+                ``'band_name_2'``
+                    Path to the file of band_2.
+                ``'band_name_n'``
+                    Path to the file of band_n.
+                ``'gt'``
+                    Path to the ground truth file.
+                ``'date'``
+                    The date of the sample.
+                ``'tile'``
+                    The tile id of the sample.
+                ``'transform'``
+                    The transformation to apply.
+                ``'id'``
+                    The scene identifier.
+
+        """
         # search the root directory
         scenes = []
         self.gt = []
@@ -447,7 +772,7 @@ class StandardEoDataset(ImageDataset):
                     for transf in self.transforms:
 
                         # store the bands and the ground truth mask of the tile
-                        data = self.store_bands(bands, gt)
+                        data = self._store_bands(bands, gt)
 
                         # the name of the scene
                         data['id'] = scene['id']
@@ -472,122 +797,518 @@ class StandardEoDataset(ImageDataset):
         return scenes
 
 
-# SparcsDataset class: inherits from the generic ImageDataset class
 class SparcsDataset(StandardEoDataset):
-
-    def __init__(self, root_dir, **kwargs):
+    """Dataset class for the `Sparcs`_ dataset.
+
+    .. _Sparcs:
+        https://www.usgs.gov/land-resources/nli/landsat/spatial-procedures-automated-removal-cloud-and-shadow-sparcs-validation
+
+    Parameters
+    ----------
+    root_dir : `str`
+        The root directory, path to the dataset.
+    use_bands : `list` [`str`], optional
+        A list of the spectral bands to use. The default is [].
+    tile_size : `int` or `None`, optional
+        The size of the tiles. If not `None`, each scene is divided into square
+        tiles of shape (tile_size, tile_size). The default is None.
+    pad : `bool`, optional
+        Whether to center pad the input image. Set ``pad`` = True, if the
+        images are not evenly divisible by the ``tile_size``. The image data is
+        padded with a constant padding value of zero. For each image, the
+        corresponding ground truth image is padded with a "no data" label.
+        The default is False.
+    gt_pattern : `str`, optional
+        A pattern to match the ground truth naming convention. All directories
+        and subdirectories in ``root_dir`` are searched for files matching
+        ``gt_pattern``. The default is '*gt.tif'.
+    sort : `bool`, optional
+        Whether to chronologically sort the samples. Useful for time series
+        data. The default is False.
+    seed : `int`, optional
+        The random seed. Used to split the dataset into training, validation
+        and test set. Useful for reproducibility. The default is 0.
+    transforms : `list` [`pysegcnn.core.split.Augment`], optional
+        List of `pysegcnn.core.split.Augment` instances. Each item in
+        ``transforms`` generates a distinct transformed version of the dataset.
+        The total dataset is composed of the original untransformed dataset
+        together with each transformed version of it.
+        If ``transforms`` = [], only the original dataset is used.
+        The default is [].
+
+    Returns
+    -------
+    None.
+
+    """
+
+    def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
+                 gt_pattern='*gt.tif', sort=False, seed=0, transforms=[]):
         # initialize super class StandardEoDataset
-        super().__init__(root_dir, **kwargs)
+        super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
+                         sort, seed, transforms)
 
-    # image size of the Sparcs dataset: (height, width)
     def get_size(self):
+        """Image size of the Sparcs dataset.
+
+        Returns
+        -------
+        size : `tuple`
+            The image size (height, width).
+
+        """
         return (1000, 1000)
 
-    # Landsat 8 bands of the Sparcs dataset
     def get_sensor(self):
+        """Landsat 8 bands of the Sparcs dataset.
+
+        Returns
+        -------
+        sensor : `enum.Enum`
+            An enumeration of the bands of the sensor.
+
+        """
         return Landsat8
 
-    # class labels of the Sparcs dataset
     def get_labels(self):
+        """Class labels of the Sparcs dataset.
+
+        Returns
+        -------
+        labels : `enum.Enum`
+            The class labels.
+
+        """
         return SparcsLabels
 
-    # preprocessing of the Sparcs dataset
     def preprocess(self, data, gt):
-
+        """Preprocess Sparcs dataset images.
+
+        Parameters
+        ----------
+        data : `numpy.ndarray`
+            The sample input data.
+        gt : `numpy.ndarray`
+            The sample ground truth.
+
+        Returns
+        -------
+        data : `numpy.ndarray`
+            The preprocessed input data.
+        gt : `numpy.ndarray`
+            The preprocessed ground truth data.
+
+        """
         # if the preprocessing is not done externally, implement it here
         return data, gt
 
-    # function that parses the date from a Landsat 8 scene id
-    def parse_scene_id(self, scene):
-        return parse_landsat_scene(scene)
+    def parse_scene_id(self, scene_id):
+        """Parse Sparcs scene identifiers (Landsat 8).
 
+        Parameters
+        ----------
+        scene_id : `str`
+            A scene identifier.
 
-class ProSnowDataset(StandardEoDataset):
+        Returns
+        -------
+        scene : `dict` or `None`
+            A dictionary containing scene metadata. If `None`, ``scene_id`` is
+            not a valid Landsat scene identifier.
+
+        """
+        return parse_landsat_scene(scene_id)
 
-    def __init__(self, root_dir, **kwargs):
+
+class ProSnowDataset(StandardEoDataset):
+    """Dataset class for the ProSnow datasets.
+
+    Parameters
+    ----------
+    root_dir : `str`
+        The root directory, path to the dataset.
+    use_bands : `list` [`str`], optional
+        A list of the spectral bands to use. The default is [].
+    tile_size : `int` or `None`, optional
+        The size of the tiles. If not `None`, each scene is divided into square
+        tiles of shape (tile_size, tile_size). The default is None.
+    pad : `bool`, optional
+        Whether to center pad the input image. Set ``pad`` = True, if the
+        images are not evenly divisible by the ``tile_size``. The image data is
+        padded with a constant padding value of zero. For each image, the
+        corresponding ground truth image is padded with a "no data" label.
+        The default is False.
+    gt_pattern : `str`, optional
+        A pattern to match the ground truth naming convention. All directories
+        and subdirectories in ``root_dir`` are searched for files matching
+        ``gt_pattern``. The default is '*gt.tif'.
+    sort : `bool`, optional
+        Whether to chronologically sort the samples. Useful for time series
+        data. The default is False.
+    seed : `int`, optional
+        The random seed. Used to split the dataset into training, validation
+        and test set. Useful for reproducibility. The default is 0.
+    transforms : `list` [`pysegcnn.core.split.Augment`], optional
+        List of `pysegcnn.core.split.Augment` instances. Each item in
+        ``transforms`` generates a distinct transformed version of the dataset.
+        The total dataset is composed of the original untransformed dataset
+        together with each transformed version of it.
+        If ``transforms`` = [], only the original dataset is used.
+        The default is [].
+
+    Returns
+    -------
+    None.
+
+    """
+
+    def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
+                 gt_pattern='*gt.tif', sort=False, seed=0, transforms=[]):
         # initialize super class StandardEoDataset
-        super().__init__(root_dir, **kwargs)
+        super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
+                         sort, seed, transforms)
 
-    # Sentinel 2 bands
     def get_sensor(self):
+        """Sentinel 2 bands of the ProSnow datasets.
+
+        Returns
+        -------
+        sensor : `enum.Enum`
+            An enumeration of the bands of the sensor.
+
+        """
         return Sentinel2
 
-    # class labels of the ProSnow dataset
     def get_labels(self):
+        """Class labels of the ProSnow datasets.
+
+        Returns
+        -------
+        labels : `enum.Enum`
+            The class labels.
+
+        """
         return ProSnowLabels
 
-    # preprocessing of the ProSnow dataset
     def preprocess(self, data, gt):
-
+        """Preprocess ProSnow dataset images.
+
+        Parameters
+        ----------
+        data : `numpy.ndarray`
+            The sample input data.
+        gt : `numpy.ndarray`
+            The sample ground truth.
+
+        Returns
+        -------
+        data : `numpy.ndarray`
+            The preprocessed input data.
+        gt : `numpy.ndarray`
+            The preprocessed ground truth data.
+
+        """
         # if the preprocessing is not done externally, implement it here
         return data, gt
 
-    # function that parses the date from a Sentinel 2 scene id
-    def parse_scene_id(self, scene):
-        return parse_sentinel2_scene(scene)
+    def parse_scene_id(self, scene_id):
+        """Parse ProSnow scene identifiers (Sentinel 2).
 
+        Parameters
+        ----------
+        scene_id : `str`
+            A scene identifier.
 
-class ProSnowGarmisch(ProSnowDataset):
+        Returns
+        -------
+        scene : `dict` or `None`
+            A dictionary containing scene metadata. If `None`, ``scene_id`` is
+            not a valid Sentinel-2 scene identifier.
 
-    def __init__(self, root_dir, **kwargs):
-        # initialize super class StandardEoDatasets
-        super().__init__(root_dir, **kwargs)
+        """
+        return parse_sentinel2_scene(scene_id)
+
+
+class ProSnowGarmisch(ProSnowDataset):
+    """Dataset class for the ProSnow Garmisch dataset.
+
+    Parameters
+    ----------
+    root_dir : `str`
+        The root directory, path to the dataset.
+    use_bands : `list` [`str`], optional
+        A list of the spectral bands to use. The default is [].
+    tile_size : `int` or `None`, optional
+        The size of the tiles. If not `None`, each scene is divided into square
+        tiles of shape (tile_size, tile_size). The default is None.
+    pad : `bool`, optional
+        Whether to center pad the input image. Set ``pad`` = True, if the
+        images are not evenly divisible by the ``tile_size``. The image data is
+        padded with a constant padding value of zero. For each image, the
+        corresponding ground truth image is padded with a "no data" label.
+        The default is False.
+    gt_pattern : `str`, optional
+        A pattern to match the ground truth naming convention. All directories
+        and subdirectories in ``root_dir`` are searched for files matching
+        ``gt_pattern``. The default is '*gt.tif'.
+    sort : `bool`, optional
+        Whether to chronologically sort the samples. Useful for time series
+        data. The default is False.
+    seed : `int`, optional
+        The random seed. Used to split the dataset into training, validation
+        and test set. Useful for reproducibility. The default is 0.
+    transforms : `list` [`pysegcnn.core.split.Augment`], optional
+        List of `pysegcnn.core.split.Augment` instances. Each item in
+        ``transforms`` generates a distinct transformed version of the dataset.
+        The total dataset is composed of the original untransformed dataset
+        together with each transformed version of it.
+        If ``transforms`` = [], only the original dataset is used.
+        The default is [].
+
+    Returns
+    -------
+    None.
+
+    """
+
+    def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
+                 gt_pattern='*gt.tif', sort=False, seed=0, transforms=[]):
+        # initialize super class StandardEoDataset
+        super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
+                         sort, seed, transforms)
 
     def get_size(self):
+        """Image size of the ProSnow Garmisch dataset.
+
+        Returns
+        -------
+        size : `tuple`
+            The image size (height, width).
+
+        """
         return (615, 543)
 
 
 class ProSnowObergurgl(ProSnowDataset):
-
-    def __init__(self, root_dir, **kwargs):
+    """Dataset class for the ProSnow Obergurgl dataset.
+
+    Parameters
+    ----------
+    root_dir : `str`
+        The root directory, path to the dataset.
+    use_bands : `list` [`str`], optional
+        A list of the spectral bands to use. The default is [].
+    tile_size : `int` or `None`, optional
+        The size of the tiles. If not `None`, each scene is divided into square
+        tiles of shape (tile_size, tile_size). The default is None.
+    pad : `bool`, optional
+        Whether to center pad the input image. Set ``pad`` = True, if the
+        images are not evenly divisible by the ``tile_size``. The image data is
+        padded with a constant padding value of zero. For each image, the
+        corresponding ground truth image is padded with a "no data" label.
+        The default is False.
+    gt_pattern : `str`, optional
+        A pattern to match the ground truth naming convention. All directories
+        and subdirectories in ``root_dir`` are searched for files matching
+        ``gt_pattern``. The default is '*gt.tif'.
+    sort : `bool`, optional
+        Whether to chronologically sort the samples. Useful for time series
+        data. The default is False.
+    seed : `int`, optional
+        The random seed. Used to split the dataset into training, validation
+        and test set. Useful for reproducibility. The default is 0.
+    transforms : `list` [`pysegcnn.core.split.Augment`], optional
+        List of `pysegcnn.core.split.Augment` instances. Each item in
+        ``transforms`` generates a distinct transformed version of the dataset.
+        The total dataset is composed of the original untransformed dataset
+        together with each transformed version of it.
+        If ``transforms`` = [], only the original dataset is used.
+        The default is [].
+
+    Returns
+    -------
+    None.
+
+    """
+
+    def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
+                 gt_pattern='*gt.tif', sort=False, seed=0, transforms=[]):
         # initialize super class StandardEoDataset
-        super().__init__(root_dir, **kwargs)
+        super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
+                         sort, seed, transforms)
 
     def get_size(self):
+        """Image size of the ProSnow Obergurgl dataset.
+
+        Returns
+        -------
+        size : `tuple`
+            The image size (height, width).
+
+        """
         return (310, 270)
 
 
 class Cloud95Dataset(ImageDataset):
-
-    def __init__(self, root_dir, **kwargs):
+    """Dataset class for the `Cloud95`_ dataset by `Mohajerani et al. (2020)`_.
+
+    .. _Cloud95:
+        https://github.com/SorourMo/95-Cloud-An-Extension-to-38-Cloud-Dataset
+    .. _`Mohajerani et al. (2020)`:
+        https://arxiv.org/abs/2001.08768
+
+    Parameters
+    ----------
+    root_dir : `str`
+        The root directory, path to the dataset.
+    use_bands : `list` [`str`], optional
+        A list of the spectral bands to use. The default is [].
+    tile_size : `int` or `None`, optional
+        The size of the tiles. If not `None`, each scene is divided into square
+        tiles of shape (tile_size, tile_size). The default is None.
+    pad : `bool`, optional
+        Whether to center pad the input image. Set ``pad`` = True, if the
+        images are not evenly divisible by the ``tile_size``. The image data is
+        padded with a constant padding value of zero. For each image, the
+        corresponding ground truth image is padded with a "no data" label.
+        The default is False.
+    gt_pattern : `str`, optional
+        A pattern to match the ground truth naming convention. All directories
+        and subdirectories in ``root_dir`` are searched for files matching
+        ``gt_pattern``. The default is '*gt.tif'.
+    sort : `bool`, optional
+        Whether to chronologically sort the samples. Useful for time series
+        data. The default is False.
+    seed : `int`, optional
+        The random seed. Used to split the dataset into training, validation
+        and test set. Useful for reproducibility. The default is 0.
+    transforms : `list` [`pysegcnn.core.split.Augment`], optional
+        List of `pysegcnn.core.split.Augment` instances. Each item in
+        ``transforms`` generates a distinct transformed version of the dataset.
+        The total dataset is composed of the original untransformed dataset
+        together with each transformed version of it.
+        If ``transforms`` = [], only the original dataset is used.
+        The default is [].
+
+    Returns
+    -------
+    None.
+
+    """
+
+    def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False,
+                 gt_pattern='*gt.tif', sort=False, seed=0, transforms=[]):
+        # initialize super class StandardEoDataset
+        super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern,
+                         sort, seed, transforms)
 
         # the csv file containing the names of the informative patches
         # patches with more than 80% black pixels, i.e. patches resulting from
         # the black margins around a Landsat 8 scene are excluded
         self.exclude = 'training_patches_95-cloud_nonempty.csv'
 
-        # initialize super class ImageDataset
-        super().__init__(root_dir, **kwargs)
-
-    # image size of the Cloud-95 dataset: (height, width)
     def get_size(self):
+        """Image size of the Cloud-95 dataset.
+
+        Returns
+        -------
+        size : `tuple`
+            The image size (height, width).
+
+        """
         return (384, 384)
 
-    # Landsat 8 bands in the Cloud-95 dataset
     def get_sensor(self):
+        """Landsat 8 bands of the Cloud-95 dataset.
+
+        Returns
+        -------
+        sensor : `enum.Enum`
+            An enumeration of the bands of the sensor.
+
+        """
         return Landsat8
 
-    # class labels of the Cloud-95 dataset
     def get_labels(self):
+        """Class labels of the Cloud-95 dataset.
+
+        Returns
+        -------
+        labels : `enum.Enum`
+            The class labels.
+
+        """
         return Cloud95Labels
 
-    # preprocess Cloud-95 dataset
     def preprocess(self, data, gt):
-
+        """Preprocess Cloud-95 dataset images.
+
+        Parameters
+        ----------
+        data : `numpy.ndarray`
+            The sample input data.
+        gt : `numpy.ndarray`
+            The sample ground truth.
+
+        Returns
+        -------
+        data : `numpy.ndarray`
+            The preprocessed input data.
+        gt : `numpy.ndarray`
+            The preprocessed ground truth data.
+
+        """
         # normalize the data
         # here, we use the normalization of the authors of Cloud-95, i.e.
         # Mohajerani and Saeedi (2019, 2020)
         data /= 65535
         gt[gt != self.cval] /= 255
-
         return data, gt
 
-    # function that parses the date from a Landsat 8 scene id
-    def parse_scene_id(self, scene):
-        return parse_landsat_scene(scene)
+    def parse_scene_id(self, scene_id):
+        """Parse Sparcs scene identifiers (Landsat 8).
 
-    def compose_scenes(self):
+        Parameters
+        ----------
+        scene_id : `str`
+            A scene identifier.
+
+        Returns
+        -------
+        scene : `dict` or `None`
+            A dictionary containing scene metadata. If `None`, ``scene_id`` is
+            not a valid Landsat scene identifier.
 
+        """
+        return parse_landsat_scene(scene_id)
+
+    def compose_scenes(self):
+        """Build the list of samples of the dataset.
+
+        Each sample is represented by a dictionary.
+
+        Returns
+        -------
+        scenes : `list` [`dict`]
+            Each item in ``scenes`` is a `dict` with keys:
+                ``'band_name_1'``
+                    Path to the file of band_1.
+                ``'band_name_2'``
+                    Path to the file of band_2.
+                ``'band_name_n'``
+                    Path to the file of band_n.
+                ``'gt'``
+                    Path to the ground truth file.
+                ``'date'``
+                    The date of the sample.
+                ``'tile'``
+                    The tile id of the sample.
+                ``'transform'``
+                    The transformation to apply.
+                ``'id'``
+                    The scene identifier.
+
+        """
         # whether to exclude patches with more than 80% black pixels
         ipatches = []
         if self.exclude is not None:
@@ -620,7 +1341,7 @@ class Cloud95Dataset(ImageDataset):
             patchname = file.split('.')[0].replace(biter + '_', '')
 
             # get the date of the current scene
-            date = self.parse_scene_id(patchname)['date']
+            scene_meta = self.parse_scene_id(patchname)
 
             # check whether the current file is an informative patch
             if ipatches and patchname not in ipatches:
@@ -641,11 +1362,14 @@ class Cloud95Dataset(ImageDataset):
                         scene[band] = os.path.join(band_dirs[band],
                                                    file.replace(biter, band))
 
+                    # the name of the scene the patch was extracted from
+                    scene['id'] = scene_meta['id']
+
                     # store tile number
                     scene['tile'] = tile
 
                     # store date
-                    scene['date'] = date
+                    scene['date'] = scene_meta['date']
 
                     # store optional transformation
                     scene['transform'] = transf
@@ -661,49 +1385,9 @@ class Cloud95Dataset(ImageDataset):
 
 
 class SupportedDatasets(enum.Enum):
+    """Names and corresponding classes of the implemented datasets."""
+
     Sparcs = SparcsDataset
     Cloud95 = Cloud95Dataset
     Garmisch = ProSnowGarmisch
     Obergurgl = ProSnowObergurgl
-
-
-if __name__ == '__main__':
-
-    # define path to working directory
-    # wd = '//projectdata.eurac.edu/projects/cci_snow/dfrisinghelli/'
-    # wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli'
-    wd = 'C:/Eurac/2020/'
-
-    # path to the preprocessed sparcs dataset
-    sparcs_path = os.path.join(wd, '_Datasets/Sparcs')
-
-    # path to the Cloud-95 dataset
-    # cloud_path = os.path.join(wd, '_Datasets/Cloud95/Training')
-
-    # path to the ProSnow dataset
-    # prosnow_path = os.path.join(wd, '_Datasets/ProSnow/')
-
-    # instanciate the Cloud-95 dataset
-    # cloud_dataset = Cloud95Dataset(cloud_path,
-    #                                tile_size=192,
-    #                                use_bands=[],
-    #                                sort=False)
-
-    # instanciate the SparcsDataset class
-    sparcs_dataset = SparcsDataset(sparcs_path,
-                                   tile_size=None,
-                                   use_bands=['red', 'green', 'blue', 'nir'],
-                                   sort=False,
-                                   transforms=[],
-                                   gt_pattern='*mask.png',
-                                   pad=True,
-                                   cval=99,
-                                   )
-
-    # instanciate the ProSnow datasets
-    # garmisch = ProSnowGarmisch(os.path.join(prosnow_path, 'Garmisch'),
-    #                            tile_size=None,
-    #                            use_bands=['nir', 'red', 'green', 'blue'],
-    #                            sort=True,
-    #                            transforms=[],
-    #                            gt_pattern='*class.img')
diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py
index 5d765cdfe2d08ec2b3aeb6048765ddd87d1e35ad..e760d9dda5c3b995bd406db3e0bf5424476a703f 100644
--- a/pysegcnn/core/graphics.py
+++ b/pysegcnn/core/graphics.py
@@ -83,15 +83,15 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
 
     Parameters
     ----------
-    x : `numpy.ndarray` or `torch.tensor`, (b, h, w)
+    x : `numpy.ndarray` or `torch.Tensor`, (b, h, w)
         Array containing the raw data of the tile, shape=(bands, height, width)
-    y : `numpy.ndarray` or `torch.tensor`, (h, w)
+    y : `numpy.ndarray` or `torch.Tensor`, (h, w)
         Array containing the ground truth of tile ``x``, shape=(height, width)
     use_bands : `list` of `str`
         List describing the order of the bands in ``x``.
     labels : `dict` [`int`, `dict`]
-        The keys are the values of the class labels in the ground truth ``y``.
-        Each nested `dict` should have keys:
+        The label dictionary. The keys are the values of the class labels
+        in the ground truth ``y``. Each nested `dict` should have keys:
             ``'color'``
                 A named color (`str`).
             ``'label'``
@@ -185,8 +185,8 @@ def plot_confusion_matrix(cm, labels, normalize=True,
     cm : `numpy.ndarray`
         The confusion matrix.
     labels : `dict` [`int`, `dict`]
-        The keys are the values of the class labels in the ground truth ``y``.
-        Each nested `dict` should have keys:
+        The label dictionary. The keys are the values of the class labels
+        in the ground truth ``y``. Each nested `dict` should have keys:
             ``'color'``
                 A named color (`str`).
             ``'label'``
diff --git a/pysegcnn/core/layers.py b/pysegcnn/core/layers.py
index 367bd7259537e08bf19cad04ebeb9a42a36c164d..ab6c59a1b85b5c1f9065a54c807b315fb2b60633 100644
--- a/pysegcnn/core/layers.py
+++ b/pysegcnn/core/layers.py
@@ -134,16 +134,16 @@ class Conv2dPool(nn.Module):
 
         Parameters
         ----------
-        x : `torch.tensor`
+        x : `torch.Tensor`
             Output of previous layer.
 
         Returns
         -------
-        y : `torch.tensor`
+        y : `torch.Tensor`
             Output of this block.
-        x : `torch.tensor`
+        x : `torch.Tensor`
             Output before max pooling. Stored for skip connections.
-        i : `torch.tensor`
+        i : `torch.Tensor`
             Indices of the max pooling operation. Used in unpooling operation.
 
         """
@@ -190,18 +190,18 @@ class Conv2dUnpool(nn.Module):
 
         Parameters
         ----------
-        x : `torch.tensor`
+        x : `torch.Tensor`
             Output of previous layer.
-        feature : `torch.tensor`
+        feature : `torch.Tensor`
             Encoder feature used for the skip connection.
-        indices : `torch.tensor`
+        indices : `torch.Tensor`
             Indices of the max pooling operation. Used in unpooling operation.
         skip : `bool`
             Whether to apply skip connetion.
 
         Returns
         -------
-        x : `torch.tensor`
+        x : `torch.Tensor`
             Output of this block.
 
         """
@@ -254,11 +254,11 @@ class Conv2dUpsample(nn.Module):
 
         Parameters
         ----------
-        x : `torch.tensor`
+        x : `torch.Tensor`
             Output of previous layer.
-        feature : `torch.tensor`
+        feature : `torch.Tensor`
             Encoder feature used for the skip connection.
-        indices : `torch.tensor`
+        indices : `torch.Tensor`
             Indices of the max pooling operation. Used in unpooling operation.
             Not used here, but passed to preserve generic interface. Useful in
             `pysegcnn.core.layers.Decoder`.
@@ -267,7 +267,7 @@ class Conv2dUpsample(nn.Module):
 
         Returns
         -------
-        x : `torch.tensor`
+        x : `torch.Tensor`
             Output of this block.
 
         """
@@ -331,12 +331,12 @@ class Encoder(nn.Module):
 
         Parameters
         ----------
-        x : `torch.tensor`
+        x : `torch.Tensor`
             Input image.
 
         Returns
         -------
-        x : `torch.tensor`
+        x : `torch.Tensor`
             Output of the encoder.
 
         """
@@ -411,7 +411,7 @@ class Decoder(nn.Module):
 
         Parameters
         ----------
-        x : `torch.tensor`
+        x : `torch.Tensor`
             Output of the encoder.
         enc_cache : `dict`
             Cache dictionary with keys:
@@ -422,7 +422,7 @@ class Decoder(nn.Module):
 
         Returns
         -------
-        x : `torch.tensor`
+        x : `torch.Tensor`
             Output of the decoder.
 
         """
diff --git a/pysegcnn/core/models.py b/pysegcnn/core/models.py
index 52400d7fb361cc2d01127f4d08d781c889e9fc92..0cd85f4db17a8bf3938d4865280d0674be1ebaeb 100644
--- a/pysegcnn/core/models.py
+++ b/pysegcnn/core/models.py
@@ -275,7 +275,7 @@ class UNet(Network):
 
         Parameters
         ----------
-        x : `torch.tensor`
+        x : `torch.Tensor`
             The input image, shape=(batch_size, channels, height, width).
 
         Returns
diff --git a/pysegcnn/core/split.py b/pysegcnn/core/split.py
index ea6bef9e24f31c65a3c04765f705a9837edb7ce9..2be0aacbc7e9d56cdb5674b6651de16054a0ad71 100644
--- a/pysegcnn/core/split.py
+++ b/pysegcnn/core/split.py
@@ -214,9 +214,9 @@ def date_scene_split(ds, date, dateformat='%Y%m%d'):
     ----------
     ds : `pysegcnn.core.dataset.ImageDataset`
         An instance of `~pysegcnn.core.dataset.ImageDataset`.
-    date : 'str'
+    date : `str`
         A date.
-    dateformat : 'str', optional
+    dateformat : `str`, optional
         The format of ``date``. ``dateformat`` is used by
         `datetime.datetime.strptime' to parse ``date`` to a `datetime.datetime`
         object. The default is '%Y%m%d'.
@@ -283,7 +283,7 @@ def pairwise_disjoint(sets):
 
 
 class CustomSubset(Subset):
-    """Custom subset inheriting `torch.utils.data.dataset.Subset`."""
+    """Custom subset inheriting `torch.utils.data.Subset`."""
 
     def __repr__(self):
         """Representation of ``~pysegcnn.core.split.CustomSubset``."""
diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 31c615f3156d7832917389c53059ce5f8e8bc315..1e0916147652bae3c52e4ec59f33c35a81826366 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -1,9 +1,8 @@
+"""Model configuration and training."""
+
+# !/usr/bin/env python
 # -*- coding: utf-8 -*-
-"""
-Created on Wed Aug 12 10:24:34 2020
 
-@author: Daniel
-"""
 # builtins
 import dataclasses
 import pathlib
@@ -34,8 +33,22 @@ LOGGER = logging.getLogger(__name__)
 
 @dataclasses.dataclass
 class BaseConfig:
+    """Base `dataclasses.dataclass` for each configuration."""
 
     def __post_init__(self):
+        """Check the type of each argument.
+
+        Raises
+        ------
+        TypeError
+            Raised if the conversion to the specified type of the argument
+            fails.
+
+        Returns
+        -------
+        None.
+
+        """
         # check input types
         for field in dataclasses.fields(self):
             # the value of the current field
@@ -55,6 +68,45 @@ class BaseConfig:
 
 @dataclasses.dataclass
 class DatasetConfig(BaseConfig):
+    """Dataset configuration class.
+
+    Parameters
+    ----------
+    dataset_name : `str`
+        The name of the dataset.
+    root_dir : `pathlib.Path`
+        The root directory, path to the dataset.
+    bands : `list` [`str`]
+        A list of the spectral bands to use.
+    tile_size : `int`
+        The size of the tiles. Each scene is divided into square tiles of shape
+        (tile_size, tile_size).
+    gt_pattern : `str`
+        A pattern to match the ground truth naming convention. All directories
+        and subdirectories in ``root_dir`` are searched for files matching
+        ``gt_pattern``.
+    seed : `int`
+        The random seed. Used to split the dataset into training, validation
+        and test set. Useful for reproducibility. The default is 0.
+    sort : `bool`, optional
+        Whether to chronologically sort the samples. Useful for time series
+        data. The default is False.
+    transforms : `list` [`pysegcnn.core.split.Augment`], optional
+        List of `pysegcnn.core.split.Augment` instances. Each item in
+        ``transforms`` generates a distinct transformed version of the dataset.
+        The total dataset is composed of the original untransformed dataset
+        together with each transformed version of it.
+        If ``transforms`` = [], only the original dataset is used.
+        The default is [].
+    pad : `bool`, optional
+        Whether to center pad the input image. Set ``pad`` = True, if the
+        images are not evenly divisible by the ``tile_size``. The image data is
+        padded with a constant padding value of zero. For each image, the
+        corresponding ground truth image is padded with a "no data" label.
+        The default is False.
+
+    """
+
     dataset_name: str
     root_dir: pathlib.Path
     bands: list
@@ -66,6 +118,21 @@ class DatasetConfig(BaseConfig):
     pad: bool = False
 
     def __post_init__(self):
+        """Check the type of each argument.
+
+        Raises
+        ------
+        FileNotFoundError
+            Raised if ``root_dir`` does not exist.
+        TypeError
+            Raised if not each item in ``transforms`` is an instance of
+            `pysegcnn.core.split.Augment` in case ``transforms`` is not empty.
+
+        Returns
+        -------
+        None.
+
+        """
         # check input types
         super().__post_init__()
 
@@ -84,7 +151,14 @@ class DatasetConfig(BaseConfig):
                                                        Augment.__name__])))
 
     def init_dataset(self):
+        """Instanciate the dataset.
+
+        Returns
+        -------
+        dataset : `pysegcnn.core.dataset.ImageDataset`
+            An instance of `pysegcnn.core.dataset.ImageDataset`.
 
+        """
         # instanciate the dataset
         dataset = self.dataset_class(
                     root_dir=str(self.root_dir),
@@ -102,6 +176,32 @@ class DatasetConfig(BaseConfig):
 
 @dataclasses.dataclass
 class SplitConfig(BaseConfig):
+    """Dataset split configuration class.
+
+    Parameters
+    ----------
+    split_mode : `str`
+        The mode to split the dataset.
+    ttratio : `float`
+        The ratio of training and validation data to test data, e.g.
+        ``ttratio`` = 0.6 means 60% for training and validation, 40% for
+        testing.
+    tvratio : `float`
+        The ratio of training data to validation data, e.g. ``tvratio`` = 0.8
+        means 80% training, 20% validation.
+    date : `str`, optional
+        A date. Used if ``split_mode`` = 'date'. The default is 'yyyymmdd'.
+    dateformat : `str`, optional
+        The format of ``date``. ``dateformat`` is used by
+        `datetime.datetime.strptime' to parse ``date`` to a `datetime.datetime`
+        object. The default is '%Y%m%d'.
+    drop : `float`, optional
+        Whether to drop samples (during training only) with a fraction of
+        pixels equal to the constant padding value >= ``drop``. ``drop`` = 0
+        means, do not drop any samples. The default is 0.
+
+    """
+
     split_mode: str
     ttratio: float
     tvratio: float
@@ -110,17 +210,46 @@ class SplitConfig(BaseConfig):
     drop: float = 0
 
     def __post_init__(self):
+        """Check the type of each argument.
+
+        Raises
+        ------
+        ValueError
+            Raised if ``split_mode`` is not supported.
+
+        Returns
+        -------
+        None.
+
+        """
         # check input types
         super().__post_init__()
 
         # check if the split mode is valid
         self.split_class = item_in_enum(self.split_mode, SupportedSplits)
 
-    # function to drop samples with a fraction of pixels equal to the constant
-    # padding value self.cval >= self.drop
     @staticmethod
     def _drop_samples(ds, drop_threshold=1):
-
+        """Drop samples with a fraction of pixels equal to the padding value.
+
+        Parameters
+        ----------
+        ds : `pysegcnn.core.split.RandomSubset` or
+        `pysegcnn.core.split.SceneSubset`.
+            An instance of `pysegcnn.core.split.RandomSubset` or
+            `pysegcnn.core.split.SceneSubset`.
+        drop_threshold : `float`, optional
+            The threshold above which samples are dropped. ``drop_threshold`` =
+            1 means a sample is dropped, if all pixels are equal to the padding
+            value. ``drop_threshold`` = 0.8 means, drop a sample if 80% of the
+            pixels are equal to the padding value, etc. The default is 1.
+
+        Returns
+        -------
+        dropped : `list` [`dict`]
+            List of the dropped samples.
+
+        """
         # iterate over the scenes returned by self.compose_scenes()
         dropped = []
         for pos, i in enumerate(ds.indices):
@@ -145,7 +274,32 @@ class SplitConfig(BaseConfig):
         return dropped
 
     def train_val_test_split(self, ds):
-
+        """Split ``ds`` into training, validation and test set.
+
+        Parameters
+        ----------
+        ds : `pysegcnn.core.dataset.ImageDataset`
+            An instance of `pysegcnn.core.dataset.ImageDataset`.
+
+        Raises
+        ------
+        TypeError
+            Raised if ``ds`` is not an instance of
+            `pysegcnn.core.dataset.ImageDataset`.
+
+        Returns
+        -------
+        train_ds : `pysegcnn.core.split.RandomSubset` or
+        `pysegcnn.core.split.SceneSubset`.
+            The training set.
+        valid_ds : `pysegcnn.core.split.RandomSubset` or
+        `pysegcnn.core.split.SceneSubset`.
+            The validation set.
+        test_ds : `pysegcnn.core.split.RandomSubset` or
+        `pysegcnn.core.split.SceneSubset`.
+            The test set.
+
+        """
         if not isinstance(ds, ImageDataset):
             raise TypeError('Expected "ds" to be {}.'
                             .format('.'.join([ImageDataset.__module__,
@@ -172,6 +326,31 @@ class SplitConfig(BaseConfig):
 
     @staticmethod
     def dataloaders(*args, **kwargs):
+        """Build `torch.utils.data.DataLoader` instances.
+
+        Parameters
+        ----------
+        *args : `list` [`torch.utils.data.Dataset`]
+            List of instances of `torch.utils.data.Dataset`.
+        **kwargs
+            Additional keyword arguments passed to
+            `torch.utils.data.DataLoader`.
+
+        Raises
+        ------
+        TypeError
+            Raised if not each item in ``args`` is an instance of
+            `torch.utils.data.Dataset`.
+
+        Returns
+        -------
+        loaders : `list` [`torch.utils.data.DataLoader`]
+            List of instances of `torch.utils.data.DataLoader`. If an instance
+            of `torch.utils.data.Dataset` in ``args`` is empty, `None` is
+            appended to ``loaders`` instead of an instance of
+            `torch.utils.data.DataLoader`.
+
+        """
         # check whether each dataset in args has the correct type
         loaders = []
         for ds in args:
@@ -192,6 +371,72 @@ class SplitConfig(BaseConfig):
 
 @dataclasses.dataclass
 class ModelConfig(BaseConfig):
+    """Model configuration class.
+
+    Parameters
+    ----------
+    model_name : `str`
+        The name of the model.
+    filters : `list` [`int`]
+        List of input channels to the convolutional layers.
+    torch_seed : `int`
+        The random seed to initialize the model weights.
+        Useful for reproducibility.
+    optim_name : `str`
+        The name of the optimizer to update the model weights.
+    loss_name : `str`
+        The name of the loss function measuring the model error.
+    skip_connection : `bool`, optional
+        Whether to apply skip connections. The defaul is True.
+    kwargs: `dict`, optional
+        The configuration for each convolution in the model. The default is
+        {'kernel_size': 3, 'stride': 1, 'dilation': 1}.
+    batch_size : `int`, optional
+        The model batch size. Determines the number of samples to process
+        before updating the model weights. The default is 64.
+    checkpoint : `bool`, optional
+        Whether to resume training from an existing model checkpoint. The
+        default is False.
+    transfer : `bool`, optional
+        Whether to use a model for transfer learning on a new dataset. If True,
+        the model architecture of ``pretrained_model`` is adjusted to a new
+        dataset. The default is False.
+    pretrained_model : `str`, optional
+        The name of the pretrained model to use for transfer learning.
+        The default is ''.
+    lr : `float`, optional
+        The learning rate used by the gradient descent algorithm.
+        The default is 0.001.
+    early_stop : `bool`, optional
+        Whether to apply `early stopping`_. The default is False.
+    mode : `str`, optional
+        The mode of the early stopping. Depends on the metric measuring
+        performance. When using model loss as metric, use ``mode`` = 'min',
+        however, when using accuracy as metric, use ``mode`` = 'max'. For now,
+        only ``mode`` = 'max' is supported. Only used if ``early_stop`` = True.
+        The default is 'max'.
+    delta : `float`, optional
+        Minimum change in early stopping metric to be considered as an
+        improvement. Only used if ``early_stop`` = True. The default is 0.
+    patience : `int`, optional
+        The number of epochs to wait for an improvement in the early stopping
+        metric. If the model does not improve over more than ``patience``
+        epochs, quit training. Only used if ``early_stop`` = True.
+        The default is 10.
+    epochs : `int`, optional
+        The maximum number of epochs to train. The default is 50.
+    nthreads : `int`, optional
+        The number of cpu threads to use during training. The default is
+        torch.get_num_threads().
+    save : `bool`, optional
+        Whether to save the model state to disk. Model states are saved in
+        pysegcnn/main/_models. The default is True.
+
+    .. _early stopping:
+        https://en.wikipedia.org/wiki/Early_stopping
+
+    """
+
     model_name: str
     filters: list
     torch_seed: int
@@ -214,6 +459,19 @@ class ModelConfig(BaseConfig):
     save: bool = True
 
     def __post_init__(self):
+        """Check the type of each argument.
+
+        Raises
+        ------
+        ValueError
+            Raised if the model ``model_name``, the optimizer ``optim_name`` or
+            the loss function ``loss_name`` is not supported.
+
+        Returns
+        -------
+        None.
+
+        """
         # check input types
         super().__post_init__()
 
@@ -233,7 +491,19 @@ class ModelConfig(BaseConfig):
         self.pretrained_path = self.state_path.joinpath(self.pretrained_model)
 
     def init_optimizer(self, model):
+        """Instanciate the optimizer.
 
+        Parameters
+        ----------
+        model : `torch.nn.Module`
+            An instance of `torch.nn.Module`.
+
+        Returns
+        -------
+        optimizer : `torch.optim.Optimizer`
+            An instance of `torch.optim.Optimizer`.
+
+        """
         LOGGER.info('Optimizer: {}.'.format(repr(self.optim_class)))
 
         # initialize the optimizer for the specified model
@@ -242,7 +512,14 @@ class ModelConfig(BaseConfig):
         return optimizer
 
     def init_loss_function(self):
+        """Instanciate the loss function.
+
+        Returns
+        -------
+        loss_function : `torch.nn.Module`
+            An instance of `torch.nn.Module`.
 
+        """
         LOGGER.info('Loss function: {}.'.format(repr(self.loss_class)))
 
         # instanciate the loss function
@@ -251,7 +528,38 @@ class ModelConfig(BaseConfig):
         return loss_function
 
     def init_model(self, ds, state_file):
-
+        """Instanciate the model and the optimizer.
+
+        If the model checkpoint ``state_file`` exists, the pretrained model and
+        optimizer states are loaded, otherwise the model and the optimizer are
+        initialized from scratch.
+
+        Parameters
+        ----------
+        ds : `pysegcnn.core.dataset.ImageDataset`
+            An instance of `pysegcnn.core.dataset.ImageDataset`.
+        state_file : `pathlib.Path`
+            Path to a model checkpoint.
+
+        Returns
+        -------
+        model : `pysegcnn.core.models.Network`
+            An instance of `pysegcnn.core.models.Network`.
+        optimizer : `torch.optim.Optimizer`
+            An instance of `torch.optim.Optimizer`.
+        checkpoint_state : `dict`
+            If the model checkpoint ``state_file`` exists, ``checkpoint_state``
+            has keys:
+                ``'ta'``
+                    The accuracy on the training set (`numpy.ndarray`).
+                ``'tl'``
+                    The loss on the training set (`numpy.ndarray`).
+                ``'va'``
+                    The accuracy on the validation set (`numpy.ndarray`).
+                ``'vl'``
+                    The loss on the validation set (`numpy.ndarray`).
+
+        """
         # write an initialization string to the log file
         LogConfig.init_log('{}: Initializing model run. ')
 
@@ -290,7 +598,39 @@ class ModelConfig(BaseConfig):
 
     @staticmethod
     def load_checkpoint(model, optimizer, state_file):
-
+        """Load an existing model checkpoint.
+
+        If the model checkpoint ``state_file`` exists, the pretrained model and
+        optimizer states are loaded.
+
+        Parameters
+        ----------
+        model : `pysegcnn.core.models.Network`
+            An instance of `pysegcnn.core.models.Network`.
+        optimizer : `torch.optim.Optimizer`
+            An instance of `torch.optim.Optimizer`.
+        state_file : `pathlib.Path`
+            Path to the model checkpoint.
+
+        Returns
+        -------
+        model : `pysegcnn.core.models.Network`
+            An instance of `pysegcnn.core.models.Network`.
+        optimizer : `torch.optim.Optimizer`
+            An instance of `torch.optim.Optimizer`.
+        checkpoint_state : `dict`
+            If the model checkpoint ``state_file`` exists, ``checkpoint_state``
+            has keys:
+                ``'ta'``
+                    The accuracy on the training set (`numpy.ndarray`).
+                ``'tl'``
+                    The loss on the training set (`numpy.ndarray`).
+                ``'va'``
+                    The accuracy on the validation set (`numpy.ndarray`).
+                ``'vl'``
+                    The loss on the validation set (`numpy.ndarray`).
+
+        """
         # whether to resume training from an existing model checkpoint
         checkpoint_state = {}
 
@@ -315,7 +655,36 @@ class ModelConfig(BaseConfig):
 
     @staticmethod
     def transfer_model(state_file, ds):
-
+        """Adjust a pretrained model to a new dataset.
+
+        The classification layer of the pretrained model in ``state_file`` is
+        initilialized from scratch with the classes of the new dataset ``ds``.
+
+        The remaining model weights are preserved.
+
+        Parameters
+        ----------
+        state_file : `pathlib.Path`
+            Path to a pretrained model.
+        ds : `pysegcnn.core.dataset.ImageDataset`
+            An instance of `pysegcnn.core.dataset.ImageDataset`.
+
+        Raises
+        ------
+        TypeError
+            Raised if ``ds`` is not an instance of
+            `pysegcnn.core.dataset.ImageDataset`.
+        ValueError
+            Raised if the bands of ``ds`` do not match the bands of the dataset
+            the pretrained model was trained with.
+
+        Returns
+        -------
+        model : `pysegcnn.core.models.Network`
+            An instance of `pysegcnn.core.models.Network`. The pretrained model
+            adjusted to the new dataset.
+
+        """
         # check input type
         if not isinstance(ds, ImageDataset):
             raise TypeError('Expected "ds" to be {}.'
@@ -346,8 +715,7 @@ class ModelConfig(BaseConfig):
                     .format(', '.join('({}, {})'.format(k, v['label'])
                                       for k, v in ds.labels.items())))
 
-        # adjust the classification layer to the number of classes of the
-        # current dataset
+        # adjust the classification layer to the classes of the new dataset
         model.classifier = Conv2dSame(in_channels=filters[0],
                                       out_channels=model.nclasses,
                                       kernel_size=1)
@@ -757,7 +1125,7 @@ class EarlyStopping(object):
         # whether to check for an increase or a decrease in a given metric
         self.is_better = self.decreased if mode == 'min' else self.increased
 
-        # minimum change in metric to be classified as an improvement
+        # minimum change in metric to be considered as an improvement
         self.min_delta = min_delta
 
         # number of epochs to wait for improvement