diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 641fbec4ddad63a8aca4dfe2caa383928c7ca8a8..a6948c7e809dab3e57983074762262c7bbfe9078 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -44,7 +44,7 @@ from pysegcnn.core.utils import (img2np, item_in_enum, accuracy_function,
                                  reconstruct_scene, check_filename_length)
 from pysegcnn.core.split import SupportedSplits, CustomSubset, SceneSubset
 from pysegcnn.core.models import (SupportedModels, SupportedOptimizers,
-                                  SupportedLossFunctions, Network)
+                                  Network)
 from pysegcnn.core.uda import SupportedUdaMethods, CoralLoss, UDA_POSITIONS
 from pysegcnn.core.layers import Conv2dSame
 from pysegcnn.core.logging import log_conf
@@ -410,22 +410,14 @@ class ModelConfig(BaseConfig):
     ----------
     model_name : `str`
         The name of the model.
-    filters : `list` [`int`]
-        List of input channels to the convolutional layers.
     torch_seed : `int`
         The random seed to initialize the model weights.
         Useful for reproducibility.
     optim_name : `str`
         The name of the optimizer to update the model weights.
-    cla_loss : `str`
-        The name of the loss function measuring the model error.
     uda_loss : `str`
-        The name of the unsupervised domain adaptation loss.
-    skip_connection : `bool`
-        Whether to apply skip connections. The default is `True`.
-    kwargs: `dict`
-        The configuration for each convolution in the model. The default is
-        `{'kernel_size': 3, 'stride': 1, 'dilation': 1}`.
+        The name of the unsupervised domain adaptation loss. The default is
+        `''`, which is equivalent to not using unsupervised domain adaptation.
     batch_size : `int`
         The model batch size. Determines the number of samples to process
         before updating the model weights. The default is `64`.
@@ -448,11 +440,13 @@ class ModelConfig(BaseConfig):
         default is `False`, i.e. train from scratch.
     uda_lambda : `float`
         The weight of the domain adaptation, trading off adaptation with
-        classification accuracy on the source domain.
+        classification accuracy on the source domain. The default is `0.5`.
     uda_pos : `str`
-        The layer where to compute the domain adaptation loss.
+        The layer where to compute the domain adaptation loss. The default is
+        `'enc'`, i.e. the adaptation loss is computed on the output of the
+        encoder layers.
     freeze : `bool`
-        Whether to freeze the pretrained weights.
+        Whether to freeze the pretrained weights. The default is `False`.
     lr : `float`
         The learning rate used by the gradient descent algorithm.
         The default is `0.001`.
@@ -487,8 +481,6 @@ class ModelConfig(BaseConfig):
         A subclass of :py:class:`pysegcnn.core.models.Network`.
     optim_class : :py:class:`torch.optim.Optimizer`
         A subclass of :py:class:`torch.optim.Optimizer`.
-    cla_loss_class : :py:class:`torch.nn.Module`
-        A subclass of :py:class:`torch.nn.Module`
     uda_loss_class : :py:class:`torch.nn.Module`
         A subclass of :py:class:`torch.nn.Module`
     state_path : :py:class:`pathlib.Path`
@@ -502,14 +494,9 @@ class ModelConfig(BaseConfig):
     """
 
     model_name: str
-    filters: list
     torch_seed: int
     optim_name: str
-    cla_loss: str
     uda_loss: str = ''
-    skip_connection: bool = True
-    kwargs: dict = dataclasses.field(
-        default_factory=lambda: {'kernel_size': 3, 'stride': 1, 'dilation': 1})
     batch_size: int = 64
     checkpoint: bool = False
     transfer: bool = False
@@ -539,8 +526,7 @@ class ModelConfig(BaseConfig):
         ------
         ValueError
             Raised if the model ``model_name``, the optimizer ``optim_name``,
-            the loss function ``cla_loss`` or the domain adaptation loss
-            ``uda_loss`` is not supported.
+            or the domain adaptation loss ``uda_loss`` is not supported.
 
         """
         # check input types
@@ -552,10 +538,6 @@ class ModelConfig(BaseConfig):
         # check whether the optimizer is currently supported
         self.optim_class = item_in_enum(self.optim_name, SupportedOptimizers)
 
-        # check whether the loss function is currently supported
-        self.cla_loss_class = item_in_enum(self.cla_loss,
-                                           SupportedLossFunctions)
-
         # check whether the domain adaptation loss is currently supported
         if self.transfer and not self.supervised:
             self.uda_loss_class = item_in_enum(self.uda_loss,
@@ -575,15 +557,13 @@ class ModelConfig(BaseConfig):
         # path to pretrained model
         self.pretrained_path = self.state_path.joinpath(self.pretrained_model)
 
-    def init_optimizer(self, model, **kwargs):
+    def init_optimizer(self, model):
         """Instanciate the optimizer.
 
         Parameters
         ----------
         model : :py:class:`torch.nn.Module`
             An instance of :py:class:`torch.nn.Module`.
-        **kwargs:
-            Additional keyword arguments passed to ``self.optim_class``.
 
         Returns
         -------
@@ -594,35 +574,14 @@ class ModelConfig(BaseConfig):
         LOGGER.info('Optimizer: {}.'.format(repr(self.optim_class)))
 
         # initialize the optimizer for the specified model
-        optimizer = self.optim_class(model.parameters(), self.lr, **kwargs)
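+        # self.optim_kwargs: additional keyword arguments passed to the
+        # optimizer constructor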
+        optimizer = self.optim_class(model.parameters(), self.lr,
+                                     **self.optim_kwargs)
 
         return optimizer
 
-    def init_cla_loss_function(self):
-        """Instanciate the classification loss function.
-
-        Returns
-        -------
-        cla_loss_function : :py:class:`torch.nn.Module`
-            An instance of :py:class:`torch.nn.Module`.
-
-        """
-        LOGGER.info('Classification loss function: {}.'
-                    .format(repr(self.cla_loss_class)))
-
-        # instanciate the classification loss function
-        cla_loss_function = self.cla_loss_class()
-
-        return cla_loss_function
-
-    def init_uda_loss_function(self, uda_lambda):
+    def init_uda_loss_function(self):
         """Instanciate the domain adaptation loss function.
 
-        Parameters
-        ----------
-        uda_lambda : `float`
-            The weight of the domain adaptation.
-
         Returns
         -------
         uda_loss_function : :py:class:`torch.nn.Module`
@@ -633,7 +592,7 @@ class ModelConfig(BaseConfig):
                     .format(repr(self.uda_loss_class)))
 
         # instanciate the loss function
-        uda_loss_function = self.uda_loss_class(uda_lambda)
+        uda_loss_function = self.uda_loss_class(self.uda_lambda)
 
         return uda_loss_function
 
@@ -641,8 +600,8 @@ class ModelConfig(BaseConfig):
         """Instanciate the model and the optimizer.
 
         If ``self.checkpoint`` is set to True, the pretrained model in
-        ``state_file`` is loaded. Otherwise, the model is initiated
-        from scratch on the dataset ``ds``.
+        ``state_file`` is loaded, if it exists. Otherwise, the model is
+        initiated from scratch on the dataset ``ds``.
 
         If ``self.transfer`` is True, the pretrained model in
         ``self.pretrained_path`` is adjusted to the dataset ``ds``.
@@ -727,7 +686,8 @@ class ModelConfig(BaseConfig):
         Parameters
         ----------
         model_state : `dict`
-            A dictionary containing the model and optimizer state.
+            A dictionary containing the model and optimizer state, as
+            constructed by :py:meth:`~pysegcnn.core.models.Network.save`.
 
         Returns
         -------
@@ -735,8 +695,6 @@ class ModelConfig(BaseConfig):
             The model checkpoint loss and accuracy time series.
 
         """
-        # load model loss and accuracy
-
         # get all non-zero elements, i.e. get number of epochs trained
         # before the early stop
         checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
@@ -745,7 +703,7 @@ class ModelConfig(BaseConfig):
         return checkpoint_state
 
     @staticmethod
-    def transfer_model(model, bands, ds, freeze=False):
+    def transfer_model(model, ds, freeze=False):
         """Adjust a pretrained model to a new dataset.
 
         If the number of classes in the pretrained model ``model`` does not
@@ -757,8 +715,6 @@ class ModelConfig(BaseConfig):
         ----------
         model : :py:class:`pysegcnn.core.models.Network`
             An instance of the pretrained model to adjust to the dataset``ds``.
-        bands : `list` [`str`]
-            The spectral bands used to train ``model``.
         ds : :py:class:`pysegcnn.core.dataset.ImageDataset`
             The dataset to which the classification layer of ``model`` is
             adapted.
@@ -774,8 +730,9 @@ class ModelConfig(BaseConfig):
             Raised if ``ds`` is not an instance of
             :py:class:`pysegcnn.core.dataset.ImageDataset`.
         ValueError
-            Raised if the bands of ``ds`` do not match the bands ``bands`` of
-            the dataset the pretrained model ``model`` was trained with.
+            Raised if the number of input channels in ``ds`` does not match
+            the number of input channels of the dataset the pretrained model
+            ``model`` was trained with.
 
         Returns
         -------
@@ -789,10 +746,13 @@ class ModelConfig(BaseConfig):
                             .format('.'.join([ImageDataset.__module__,
                                               ImageDataset.__name__])))
 
-        # check whether the current dataset uses the correct spectral bands
-        if ds.use_bands != bands:
-            raise ValueError('The model was trained with bands {}, not with '
-                             'bands {}.'.format(bands, ds.use_bands))
+        # check whether the current dataset uses the same number of input
+        # channels as the pretrained model
+        if len(ds.use_bands) != model.in_channels:
+            raise ValueError('The model was trained with {} input channels, '
+                             'which does not match the {} input channels of '
+                             'the specified dataset.'.format(
+                                 model.in_channels, len(ds.use_bands)))
 
         # configure model for the specified dataset
         LOGGER.info('Configuring model for new dataset: {}.'.format(
@@ -830,9 +790,20 @@ class StateConfig(BaseConfig):
     """Model state configuration class.
 
     Generate the model state filename according to the following naming
-    convention:
+    conventions:
+        - For source domain without domain adaptation:
+            Model_Optim_SourceDataset_ModelParams.pt
+
+        - For supervised domain adaptation to a target domain:
+            NameOfPretrainedModel_sda_TargetDataset.pt
+
+        - For unsupervised domain adaptation to a target domain:
+            Model_Optim_SourceDataset_uda_TargetDataset_ModelParams.pt
 
-    `model_dataset_optimizer_splitmode_splitparams_tilesize_batchsize_bands.pt`
+        - For unsupervised domain adaptation to a target domain using a
+        pretrained model:
+            Model_Optim_SourceDataset_uda_TargetDataset_ModelParams_prt_
+            NameOfPretrainedModel.pt
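+
+    Here, ``ModelParams`` summarizes the tile size, the batch size and the
+    spectral bands used for training, e.g. ``t128_b64_all`` for a tile size
+    of 128, a batch size of 64 and all available spectral bands.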
 
     Attributes
     ----------
@@ -872,9 +843,14 @@ class StateConfig(BaseConfig):
         """
         super().__post_init__()
 
-        # base model state filename
-        # Model_Dataset_SplitMode_SplitParams_TileSize_BatchSize_Bands
-        self.state_file = '{}_{}_{}Split_{}_t{}_b{}_{}.pt'
+        # base dataset state filename: Dataset_SplitMode_SplitParams
+        self.ds_state_file = '{}_{}Split_{}'
+
+        # base model state filename: Model_Optim
+        self.ml_state_file = '{}_{}'
+
+        # base model state filename extension: TileSize_BatchSize_Bands
+        self.ml_state_ext = 't{}_b{}_{}.pt'
 
         # check that the spectral bands are the same for both source and target
         # domains
@@ -894,71 +870,68 @@ class StateConfig(BaseConfig):
             The path to the model state file.
 
         """
-        # state file name for model trained on the source domain only
-        state_src = self._format_state_file(
-                self.state_file, self.src_dc, self.src_sc, self.mc)
+        # source domain dataset state filename
+        src_ds_state = self._format_ds_state(
+            self.ds_state_file, self.src_dc, self.src_sc)
+
+        # model state file name and extension: common to both source and target
+        # domain
+        ml_state, ml_ext = self._format_ml_state(self.src_dc, self.mc)
+
+        # state file for models trained only on the source domain
+        state = '_'.join([ml_state, src_ds_state, ml_ext])
 
         # check whether the model is trained on the source domain only
-        if not self.mc.transfer:
-            # state file for models trained only on the source domain
-            state = state_src
-        else:
-            # state file for model trained on target domain
-            state_trg = self._format_state_file(
-                self.state_file, self.trg_dc, self.trg_sc, self.mc)
+        if self.mc.transfer:
+
+            # target domain dataset state filename
+            trg_ds_state = self._format_ds_state(
+                self.ds_state_file, self.trg_dc, self.trg_sc)
 
             # check whether a pretrained model is used to fine-tune to the
             # target domain
             if self.mc.supervised:
                 # state file for models fine-tuned to target domain
-                state = state_trg.replace('.pt', '_pretrained_{}'.format(
-                    self.mc.pretrained_model))
+                # NameOfPretrainedModel_sda_TargetDataset.pt
+                state = '_'.join([self.mc.pretrained_model,
+                                  'sda_{}'.format(trg_ds_state)])
             else:
                 # state file for models trained via unsupervised domain
                 # adaptation
-                state = state_src.replace('.pt', '_uda_{}{}'.format(
-                    self.mc.uda_pos, state_trg.replace(self.mc.model_name, ''))
-                    )
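+                # Model_Optim_SourceDataset_uda_TargetDataset_ModelParams.pt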
+                state = '_'.join([state.replace(
+                    ml_ext, 'uda_{}'.format(self.mc.uda_pos)),
+                    trg_ds_state, ml_ext])
 
                 # check whether unsupervised domain adaptation is initialized
                 # from a pretrained model state
                 if self.mc.uda_from_pretrained:
-                    state = state.replace('.pt', '_pretrained_{}'.format(
-                            self.mc.pretrained_model))
+                    state = '_'.join([state.replace('.pt', ''),
+                                      'prt_{}'.format(
+                                          self.mc.pretrained_model)])
 
         # path to model state
         state = self.mc.state_path.joinpath(state)
 
         return state
 
-    def _format_state_file(self, state_file, dc, sc, mc):
-        """Format base model state filename.
+    def _format_ds_state(self, state_file, dc, sc):
+        """Format base dataset state filename.
 
         Parameters
         ----------
         state_file : `str`
-            The base model state filename.
+            The base dataset state filename.
         dc : :py:class:`pysegcnn.core.trainer.DatasetConfig`
-            The domain dataset configuration.
+            The dataset configuration.
         sc : :py:class:`pysegcnn.core.trainer.SplitConfig`
-            The domain dataset split configuration.
-        mc : :py:class:`pysegcnn.core.trainer.ModelConfig`
-            The model configuration.
+            The dataset split configuration.
 
         Returns
         -------
         file : `str`
-            The formatted model state filename.
+            The formatted dataset state filename.
 
         """
-        # get the band numbers
-        if dc.bands:
-            bformat = ''.join(band[0] +
-                              str(dc.dataset_class.get_sensor().
-                                  __members__[band].value)
-                              for band in dc.bands)
-        else:
-            bformat = 'all'
 
         # check which split mode was used
         if sc.split_mode == 'date':
@@ -971,879 +944,870 @@ class StateConfig(BaseConfig):
                 str(sc.tvratio).replace('.', ''))
 
         # model state filename
-        file = state_file.format(mc.model_name,
-                                 dc.dataset_class.__name__ +
+        file = state_file.format(dc.dataset_class.__name__ +
                                  '_m{}'.format(len(dc.merge_labels)),
                                  sc.split_mode.capitalize(),
                                  split_params,
-                                 dc.tile_size,
-                                 mc.batch_size,
-                                 bformat)
+                                 )
 
         return file
 
+    def _format_ml_state(self, dc, mc):
+        """Format base model state filename.
+
+        Parameters
+        ----------
+        dc : :py:class:`pysegcnn.core.trainer.DatasetConfig`
+            The source domain dataset configuration.
+        mc : :py:class:`pysegcnn.core.trainer.ModelConfig`
+            The model configuration.
+
+        Returns
+        -------
+        state : `str`
+            The formatted model state filename (Model_Optim).
+        extension : `str`
+            The formatted model state filename extension
+            (TileSize_BatchSize_Bands).
+
+        """
+        # get the band numbers
+        if dc.bands:
+            bands = dc.dataset_class.get_sensor().band_dict()
+            bformat = ''.join([(v[0] + str(k)) for k, v in bands.items()])
+        else:
+            bformat = 'all'
+
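+        # filename: Model_Optim, extension: tTileSize_bBatchSize_Bands.pt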
+        return (self.ml_state_file.format(mc.model_name, mc.optim_name),
+                self.ml_state_ext.format(dc.tile_size, mc.batch_size, bformat))
+
 
 @dataclasses.dataclass
-class NetworkInference(BaseConfig):
-    """Model inference configuration.
+class LogConfig(BaseConfig):
+    """Logging configuration class.
 
-    Evaluate a model.
+    Generate the model log file.
 
     Attributes
     ----------
     state_file : :py:class:`pathlib.Path`
-        Path to the model to evaluate.
-    implicit : `bool`
-        Whether to evaluate the model on the datasets defined at training time.
-    domain : `str`
-        Whether to evaluate on the source domain (``domain='src'``), i.e. the
-        domain the model in ``state_file`` was trained on, or a target domain
-        (``domain='trg'``).
-    test : `bool` or `None`
-        Whether to evaluate the model on the training(``test=None``), the
-        validation (``test=False``) or the test set (``test=True``).
-    map_labels : `bool`
-        Whether to map the model labels from the model source domain to the
-        defined ``domain`` in case the domain class labels differ.
-    ds : `dict`
-        The dataset configuration dictionary passed to
-        :py:class:`pysegcnn.core.trainer.DatasetConfig` when evaluating on
-        an explicitly defined dataset, i.e. ``implicit=False``.
-    ds_split : `dict`
-        The dataset split configuration dictionary passed to
-        :py:class:`pysegcnn.core.trainer.SplitConfig` when evaluating on
-        an explicitly defined dataset, i.e. ``implicit=False``.
-    predict_scene : `bool`
-        The model prediction order. If False, the samples (tiles) of a dataset
-        are predicted in any order and the scenes are not reconstructed.
-        If True, the samples (tiles) are ordered according to the scene they
-        belong to and a model prediction for each entire reconstructed scene is
-        returned. The default is `False`.
-    plot_samples : `bool`
-        Whether to save a plot of false color composite, ground truth and model
-        prediction for each sample (tile). Only used if ``predict_scene=False``
-        . The default is `False`.
-    plot_scenes : `bool`
-        Whether to save a plot of false color composite, ground truth and model
-        prediction for each entire scene. Only used if ``predict_scene=True``.
-        The default is `False`.
-    plot_bands : `list` [`str`]
-        The bands to build the false color composite. The default is
-        `['nir', 'red', 'green']`.
-    cm : `bool`
-        Whether to compute and plot the confusion matrix. The default is `True`
-        .
-    figsize : `tuple`
-        The figure size in centimeters. The default is `(10, 10)`.
-    alpha : `int`
-        The level of the percentiles for contrast stretching of the false color
-        compsite. The default is `0`, i.e. no stretching.
-    animate : `bool`
-        Whether to create an animation of (input, ground truth, prediction) for
-        the scenes in the train/validation/test dataset. Only works if
-        ``predict_scene=True`` and ``plot_scene=True``.
-    device : `str`
-        The device to evaluate the model on, i.e. `cpu` or `cuda`.
-    base_path : :py:class:`pathlib.Path`
-        Root path to store model output.
-    sample_path : :py:class:`pathlib.Path`
-        Path to store plots of model predictions for single samples.
-    scenes_path : :py:class:`pathlib.Path`
-        Path to store plots of model predictions for entire scenes.
-    perfmc_path : :py:class:`pathlib.Path`
-        Path to store plots of model performance, e.g. confusion matrix.
-    animtn_path : :py:class:`pathlib.Path`
-        Path to store animations.
-    models_path : :py:class:`pathlib.Path`
-        Path to search for model state files, i.e. pretrained models.
-    kwargs : `dict`
-        Keyword arguments for :py:func:`pysegcnn.core.graphics.plot_sample`
-    basename : `str`
-        Base filename for each plot.
-    model : :py:class:`pysegcnn.core.models.Network`
-        The model to use for inference.
-    model_state : `dict`
-        The model state as saved by
-        :py:class:`pysegcnn.core.trainer.NetworkTrainer`.
-    trg_ds : :py:class:`pysegcnn.core.split.CustomSubset`
-        The dataset to evaluate ``model`` on.
-    src_ds : :py:class:`pysegcnn.core.split.CustomSubset`
-        The model source domain training dataset.
-    fig : :py:class:`matplotlib.figure.Figure`
-        A :py:class:`matplotlib.figure.Figure` instance to iteratively plot to.
-    anim : :py:class:`pysegcnn.core.graphics.Animate`
-        An instance :py:class:`pysegcnn.core.graphics.Animate` Used to create
-        animations if ``animate=True``.
-    conf_mat : :py:class:`numpy.ndarray`
-        The model confusion matrix.
-
+        Path to a model state file.
+    log_path : :py:class:`pathlib.Path`
+        Path to store model logs.
+    log_file : :py:class:`pathlib.Path`
+        Path to the log file of the model ``state_file``.
     """
 
     state_file: pathlib.Path
-    implicit: bool
-    domain: str
-    test: object
-    map_labels: bool
-    ds: dict = dataclasses.field(default_factory={})
-    ds_split: dict = dataclasses.field(default_factory={})
-    predict_scene: bool = False
-    plot_samples: bool = False
-    plot_scenes: bool = False
-    plot_bands: list = dataclasses.field(
-        default_factory=lambda: ['nir', 'red', 'green'])
-    cm: bool = True
-    figsize: tuple = (10, 10)
-    alpha: int = 5
-    animate: bool = False
 
     def __post_init__(self):
         """Check the type of each argument.
 
-        Configure figure output paths.
-
-        Raises
-        ------
-        TypeError
-            Raised if ``test`` is not of type `bool` or `None`.
-        ValueError
-            Raised if ``domain`` is not 'src' or 'trg'.
+        Generate model log file.
 
         """
         super().__post_init__()
 
-        # check whether the test input parameter is correctly specified
-        if self.test not in [None, False, True]:
-            raise TypeError('Expected "test" to be None, True or False, got '
-                            '{}.'.format(self.test))
-
-        # check whether the domain is correctly specified
-        if self.domain not in ['src', 'trg']:
-            raise ValueError('Expected "domain" to be "src" or "trg", got {}.'
-                             .format(self.domain))
-
-        # the device to compute on, use gpu if available
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else
-                                   "cpu")
+        # the path to store model logs
+        self.log_path = pathlib.Path(HERE).joinpath('_logs')
 
-        # the output paths for the different graphics
-        self.base_path = pathlib.Path(HERE)
-        self.sample_path = self.base_path.joinpath('_samples')
-        self.scenes_path = self.base_path.joinpath('_scenes')
-        self.perfmc_path = self.base_path.joinpath('_graphics')
-        self.animtn_path = self.base_path.joinpath('_animations')
+        # the log file of the current model
+        self.log_file = check_filename_length(self.log_path.joinpath(
+            self.state_file.name.replace('.pt', '.log')))
 
-        # input path for model state files
-        self.models_path = self.base_path.joinpath('_models')
-        self.state_file = self.models_path.joinpath(self.state_file)
+    @staticmethod
+    def now():
+        """Return the current date and time.
 
-        # initialize logging
-        log = LogConfig(self.state_file)
-        dictConfig(log_conf(log.log_file))
-        log.init_log('{}: ' + 'Evaluating model: {}.'
-                     .format(self.state_file.name))
+        Returns
+        -------
+        date : `str`
+            The current date and time, formatted as '%Y-%m-%dT%H:%M:%S'.
 
-        # plotting keyword arguments
-        self.kwargs = {'bands': self.plot_bands,
-                       'alpha': self.alpha,
-                       'figsize': self.figsize}
+        """
+        return datetime.datetime.strftime(datetime.datetime.now(),
+                                          '%Y-%m-%dT%H:%M:%S')
 
-        # base filename for each plot
-        self.basename = self.state_file.stem
+    @staticmethod
+    def init_log(init_str):
+        """Generate a string to identify a new model run.
 
-        # load the model state
-        self.model, _, self.model_state = Network.load(self.state_file)
+        Parameters
+        ----------
+        init_str : `str`
+            The string to write to the model log file.
 
-        # load the target dataset: dataset to evaluate the model on
-        self.trg_ds = self.load_dataset()
+        """
+        LOGGER.info(80 * '-')
+        LOGGER.info(init_str.format(LogConfig.now()))
+        LOGGER.info(80 * '-')
 
-        # load the source dataset: dataset the model was trained on
-        self.src_ds = self.model_state['src_train_dl'].dataset.dataset
 
-        # create a figure to use for plotting
-        self.fig, _ = plt.subplots(1, 3, figsize=self.kwargs['figsize'])
+@dataclasses.dataclass
+class ClassificationNetworkTrainer(BaseConfig):
+    """Base model training class for classification problems.
 
-        # check if the animate parameter is correctly specified
-        if self.animate:
-            if not self.plot:
-                LOGGER.warning('animate requires plot_scenes=True or '
-                               'plot_samples=True. Can not create animation.')
-            else:
-                # check whether the output path is valid
-                if not self.anim_path.exists():
-                    # create output path
-                    self.anim_path.mkdir(parents=True, exist_ok=True)
-                    self.anim = Animate(self.anim_path)
+    Train an instance of :py:class:`pysegcnn.core.models.Network` on a
+    classification problem. The `categorical cross-entropy loss`_
+    is used as the loss function in combination with the `softmax`_ output
+    layer activation function.
 
-    @staticmethod
-    def get_scene_tiles(ds, scene_id):
-        """Return the tiles of the scene with id ``scene_id``.
+    In case of a binary classification problem, the categorical cross-entropy
+    loss reduces to the binary cross-entropy loss and the softmax function to
+    the standard `logistic function`_.
 
-        Parameters
-        ----------
-        ds : :py:class:`pysegcnn.core.split.CustomSubset`
-            A instance of a subclass of
-            :py:class:`pysegcnn.core.split.CustomSubset`.
-        scene_id : `str`
-            A valid scene identifier.
-
-        Raises
-        ------
-        ValueError
-            Raised if ``scene_id`` is not a valid scene identifier for the
-            dataset ``ds``.
-
-        Returns
-        -------
-        indices : `list` [`int`]
-            List of indices of the tiles of the scene with id ``scene_id`` in
-            ``ds``.
-        date : :py:class:`datetime.datetime`
-            The date of the scene with id ``scene_id``.
+    Attributes
+    ----------
+    model : :py:class:`pysegcnn.core.models.Network`
+        The model to train. An instance of
+        :py:class:`pysegcnn.core.models.Network`.
+    optimizer : :py:class:`torch.optim.Optimizer`
+        The optimizer to update the model weights. An instance of
+        :py:class:`torch.optim.Optimizer`.
+    state_file : :py:class:`pathlib.Path`
+        Path to save the model state.
+    src_train_dl : :py:class:`torch.utils.data.DataLoader`
+        The source domain training :py:class:`torch.utils.data.DataLoader`
+        instance built from an instance of
+        :py:class:`pysegcnn.core.split.CustomSubset`.
+    src_valid_dl : :py:class:`torch.utils.data.DataLoader`
+        The source domain validation :py:class:`torch.utils.data.DataLoader`
+        instance built from an instance of
+        :py:class:`pysegcnn.core.split.CustomSubset`.
+    src_test_dl : :py:class:`torch.utils.data.DataLoader`
+        The source domain test :py:class:`torch.utils.data.DataLoader`
+        instance built from an instance of
+        :py:class:`pysegcnn.core.split.CustomSubset`.
+    epochs : `int`
+        The maximum number of epochs to train. The default is `1`.
+    nthreads : `int`
+        The number of cpu threads to use during training. The default is
+        :py:func:`torch.get_num_threads()`.
+    early_stop : `bool`
+        Whether to apply `Early Stopping`_. The default is `False`.
+    mode : `str`
+        The early stopping mode. Depends on the metric measuring
+        performance. When using the model loss as metric, use ``mode='min'``;
+        when using accuracy as metric, use ``mode='max'``. For now,
+        only ``mode='max'`` is supported. Only used if ``early_stop=True``.
+        The default is `'max'`.
+    delta : `float`
+        Minimum change in early stopping metric to be considered as an
+        improvement. Only used if ``early_stop=True``. The default is `0`.
+    patience : `int`
+        The number of epochs to wait for an improvement in the early stopping
+        metric. If the model does not improve over more than ``patience``
+        epochs, quit training. Only used if ``early_stop=True``. The default is
+        `10`.
+    checkpoint_state : `dict` [`str`, :py:class:`numpy.ndarray`]
+        A model checkpoint for ``model``. If specified, ``checkpoint_state``
+        should be a dictionary with keys describing the training metric.
+        The default is `{}`.
+    save : `bool`
+        Whether to save the model state to ``state_file``. The default is
+        `True`.
+    device : `str`
+        The device to train the model on, i.e. `cpu` or `cuda`.
+    cla_loss_function : :py:class:`torch.nn.Module`
+        The classification loss function to compute the model error. An
+        instance of :py:class:`torch.nn.CrossEntropyLoss`.
+    tracker : :py:class:`pysegcnn.core.trainer.MetricTracker`
+        A :py:class:`pysegcnn.core.trainer.MetricTracker` instance tracking
+        training metrics, i.e. loss and accuracy.
+    max_accuracy : `float`
+        Maximum accuracy of ``model`` on the validation dataset.
+    es : `None` or :py:class:`pysegcnn.core.trainer.EarlyStopping`
+        The early stopping instance if ``early_stop=True``, else `None`.
+    tmbatch : `int`
+        Number of mini-batches in the training dataset.
+    vmbatch : `int`
+        Number of mini-batches in the validation dataset.
+    training_state : `dict` [`str`, :py:class:`numpy.ndarray`]
+        The training state dictionary. The keys describe the type of the
+        training metric.
+    params_to_save : `dict`
+        The parameters to save in the model ``state_file``, in addition to the
+        model and optimizer weights.
 
-        """
-        # check if the scene id is valid
-        scene_meta = ds.dataset.parse_scene_id(scene_id)
-        if scene_meta is None:
-            raise ValueError('{} is not a valid scene identifier'
-                             .format(scene_id))
+    .. _Early Stopping:
+        https://en.wikipedia.org/wiki/Early_stopping
 
-        # iterate over the scenes of the dataset
-        indices = []
-        for i, scene in enumerate(ds.scenes):
-            # if the scene id matches a given id, save the index of the scene
-            if scene['id'] == scene_id:
-                indices.append(i)
+    .. _categorical cross-entropy loss:
+        https://gombru.github.io/2018/05/23/cross_entropy_loss/
 
-        return indices, scene_meta['date']
+    .. _softmax:
+        https://peterroelants.github.io/posts/cross-entropy-softmax/
 
-    @staticmethod
-    def replace_dataset_path(ds, drive_path):
-        """Replace the path to the datasets.
+    .. _logistic function:
+        https://en.wikipedia.org/wiki/Logistic_function
 
-        Useful to evaluate models on machines, that are different from the
-        machine the model was trained on.
+    """
 
-        .. important::
+    model: Network
+    optimizer: Optimizer
+    state_file: pathlib.Path
+    src_train_dl: DataLoader
+    src_valid_dl: DataLoader
+    src_test_dl: DataLoader
+    epochs: int = 1
+    nthreads: int = torch.get_num_threads()
+    early_stop: bool = False
+    mode: str = 'max'
+    delta: float = 0
+    patience: int = 10
+    checkpoint_state: dict = dataclasses.field(default_factory=dict)
+    save: bool = True
 
-            This function assumes that the datasets are stored in a directory
-            named "Datasets" on each machine.
+    def __post_init__(self):
+        """Check the type of each argument.
 
-        See ``DRIVE_PATH`` in :py:mod:`pysegcnn.main.config`.
+        Configure the device to train the model on, i.e. train on the gpu if
+        available.
 
-        Parameters
-        ----------
-        ds : :py:class:`pysegcnn.core.split.CustomSubset`
-            A subset of an instance of
-            :py:class:`pysegcnn.core.dataset.ImageDataset`.
-        drive_path : `str`
-            Base path to the datasets on the current machine. ``drive_path``
-            should end with `'Datasets'`.
+        Configure early stopping if required.
 
-        Raises
-        ------
-        TypeError
-            Raised if ``ds`` is not an instance of
-            :py:class:`pysegcnn.core.split.CustomSubset` and if ``ds`` is not
-            a subset of an instance of
-            :py:class:`pysegcnn.core.dataset.ImageDataset`.
+        Initialize training metric tracking.
 
         """
-        # check input type
-        if isinstance(ds, CustomSubset):
-            # check the type of the dataset
-            if not isinstance(ds.dataset, ImageDataset):
-                raise TypeError('ds should be a subset created from a {}.'
-                                .format(repr(ImageDataset)))
-        else:
-            raise TypeError('ds should be an instance of {}.'
-                            .format(repr(CustomSubset)))
+        super().__post_init__()
 
-        # iterate over the scenes of the dataset
-        for scene in ds.dataset.scenes:
-            for k, v in scene.items():
-                # do only look for paths
-                if isinstance(v, str) and k != 'id':
+        # the device to train the model on
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else
+                                   'cpu')
+        # set the number of threads
+        torch.set_num_threads(self.nthreads)
 
-                    # drive path: match path before "Datasets"
-                    # dpath = re.search('^(.*)(?=(/.*Datasets))', v)
+        # send the model to the gpu if available
+        self.model = self.model.to(self.device)
 
-                    # drive path: match path up to "Datasets"
-                    dpath = re.search('^(.*?Datasets)', v)[0]
+        # instanciate multiclass classification loss function: multi-class
+        # cross-entropy loss function
+        self.cla_loss_function = nn.CrossEntropyLoss()
+        LOGGER.info('Classification loss function: {}.'
+                    .format(repr(self.cla_loss_function)))
 
-                    # replace drive path
-                    if dpath != drive_path:
-                        scene[k] = v.replace(str(dpath), drive_path)
+        # instanciate metric tracker
+        self.tracker = MetricTracker(
+            train_metrics=['train_loss', 'train_accu'],
+            valid_metrics=['valid_loss', 'valid_accu'])
 
-    def load_dataset(self):
-        """Load the defined dataset.
+        # initialize metric tracker
+        self.tracker.initialize()
 
-        Raises
-        ------
-        ValueError
-            Raised if the requested dataset was not available at training time,
-            if ``implicit=True``.
+        # maximum accuracy on the validation set
+        self.max_accuracy = 0
+        if self.checkpoint_state:
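+            # resume from the best mean validation accuracy observed in the
+            # checkpoint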
+            self.max_accuracy = self.checkpoint_state['valid_accu'].mean(
+                axis=0).max().item()
 
-            Raised if the dataset ``ds`` does not have the same spectral bands
-            as the model to evaluate, if ``implicit=False``.
+        # whether to use early stopping
+        self.es = None
+        if self.early_stop:
+            self.es = EarlyStopping(self.mode, self.max_accuracy, self.delta,
+                                    self.patience)
 
-        Returns
-        -------
-        ds : :py:class:`pysegcnn.core.split.CustomSubset`
-            The dataset to evaluate the model on.
+        # number of mini-batches in the training and validation sets
+        self.tmbatch = len(self.src_train_dl)
+        self.vmbatch = len(self.src_valid_dl)
 
-        """
-        # check whether to evaluate on the datasets defined at training time
-        if self.implicit:
+        # log representation
+        LOGGER.info(repr(self))
 
-            # check whether to evaluate the model on the training, validation
-            # or test set
-            if self.test is None:
-                ds_set = 'train'
-            else:
-                ds_set = 'test' if self.test else 'valid'
+        # initialize training log
+        LOGGER.info(35 * '-' + ' Training ' + 35 * '-')
 
-            # the dataset to evaluate the model on
-            ds = self.model_state[
-                self.domain + '_{}_dl'.format(ds_set)].dataset
-            if ds is None:
-                raise ValueError('Requested dataset "{}" is not available.'
-                                 .format(self.domain + '_{}_dl'.format(ds_set))
-                                 )
+        # log the device and number of threads
+        LOGGER.info('Device: {}'.format(self.device))
+        LOGGER.info('Number of cpu threads: {}'.format(self.nthreads))
 
-            # log dataset representation
-            LOGGER.info('Evaluating on {} set of the {} domain defined at '
-                        'training time.'.format(ds_set, self.domain))
+    def train_source_domain(self, epoch):
+        """Train a model for a single epoch on the source domain.
 
-        else:
-            # explicitly defined dataset
-            ds = DatasetConfig(**self.ds).init_dataset()
+        Parameters
+        ----------
+        epoch : `int`
+            The current epoch.
 
-            # check if the spectral bands match
-            if ds.use_bands != self.model_state['bands']:
-                raise ValueError('The model was trained with bands {}, not '
-                                 'with bands {}.'.format(
-                                     self.model_state['bands'], ds.use_bands))
+        """
+        # iterate over the dataloader object
+        for batch, (inputs, labels) in enumerate(self.src_train_dl):
 
-            # split configuration
-            sc = SplitConfig(**self.ds_split)
-            train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
+            # send the data to the gpu if available
+            inputs = inputs.to(self.device)
+            labels = labels.to(self.device)
 
-            # check whether to evaluate the model on the training, validation
-            # or test set
-            if self.test is None:
-                ds = train_ds
-            else:
-                ds = test_ds if self.test else valid_ds
+            # reset the gradients
+            self.optimizer.zero_grad()
 
-            # log dataset representation
-            LOGGER.info('Evaluating on {} set of explicitly defined dataset: '
-                        '\n {}'.format(ds.name, repr(ds.dataset)))
+            # perform forward pass
+            outputs = self.model(inputs)
 
-        # check the dataset path: replace by path on current machine
-        self.replace_dataset_path(ds, DRIVE_PATH)
+            # compute loss
+            loss = self.cla_loss_function(outputs, labels.long())
 
-        return ds
+            # compute the gradients of the loss function w.r.t.
+            # the network weights
+            loss.backward()
 
-    @property
-    def source_labels(self):
-        """Class labels of the source domain the model was trained on.
+            # update the weights
+            self.optimizer.step()
 
-        Returns
-        -------
-        source_labels : `dict` [`int`, `dict`]
-            The class labels of the source domain.
+            # calculate predicted class labels
+            ypred = F.softmax(outputs, dim=1).argmax(dim=1)
+
+            # calculate accuracy on current batch
+            acc = accuracy_function(ypred, labels)
+
+            # print progress
+            LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
+                        'Loss: {:.2f}, Accuracy: {:.2f}'
+                        .format(epoch + 1, self.epochs, batch + 1,
+                                self.tmbatch, loss.item(), acc))
+
+            # update training metrics
+            self.tracker.batch_update(self.tracker.train_metrics,
+                                      [loss.item(), acc])
+
+    def train_epoch(self, epoch):
+        """Train a model for a single epoch on the source domain.
+
+        Parameters
+        ----------
+        epoch : `int`
+            The current epoch.
 
         """
-        return self.src_ds.labels
+        self.train_source_domain(epoch)
 
-    @property
-    def target_labels(self):
-        """Class labels of the dataset to evaluate.
+    def train(self):
+        """Train the model.
 
         Returns
         -------
-        target_labels :  `dict` [`int`, `dict`]
-            The class labels of the dataset to evaluate.
+        training_state : `dict` [`str`, :py:class:`numpy.ndarray`]
+            The training state dictionary. The keys describe the type of the
+            training metric. See
+            :py:meth:`~pysegcnn.core.trainer.ClassificationNetworkTrainer.training_state`.
 
         """
-        return self.trg_ds.dataset.labels
+        # initialize the training: iterate over the entire training dataset
+        for epoch in range(self.epochs):
 
-    # @property
-    # def label_map(self):
-    #     """Label mapping from the source to the target domain.
+            # set the model to training mode
+            LOGGER.info('Setting model to training mode ...')
+            self.model.train()
 
-    #     See :py:func:`pysegcnn.core.constants.map_labels`.
+            # train model for a single epoch
+            self.train_epoch(epoch)
 
-    #     Returns
-    #     -------
-    #     label_map : `dict` [`int`, `int`]
-    #         Dictionary with source labels as keys and corresponding target
-    #         labels as values.
+            # update the number of epochs trained
+            self.model.epoch += 1
 
-    #     """
-    #     # check whether the source domain labels are the same as the target
-    #     # domain labels
-    #     return map_labels(self.src_ds.get_labels(),
-    #                       self.trg_ds.dataset.get_labels())
+            # whether to evaluate model performance on the validation set and
+            # early stop the training process
+            if self.early_stop:
 
-    @property
-    def source_is_target(self):
-        """Whether the source and target domain labels are the same.
+                # model predictions on the validation set
+                valid_accu, valid_loss = self.predict(self.src_valid_dl)
 
-        Returns
-        -------
-        source_is_target : `bool`
-            `True` if the source and target domain labels are the same, `False`
-            if not.
+                # update validation metrics
+                self.tracker.batch_update(self.tracker.valid_metrics,
+                                          [valid_loss, valid_accu])
 
-        """
-        return self.label_map is None
+                # metric to assess model performance on the validation set
+                epoch_acc = np.mean(valid_accu)
 
-    @property
-    def apply_label_map(self):
-        """Whether to map source labels to target labels.
+                # whether the model improved with respect to the previous epoch
+                if self.es.increased(epoch_acc, self.max_accuracy, self.delta):
+                    self.max_accuracy = epoch_acc
 
-        Returns
-        -------
-        apply_label_map : `bool`
-            `True` if source and target labels differ and label mapping is
-            requested, `False` otherwise.
+                    # save model state if the model improved with
+                    # respect to the previous epoch
+                    if self.save:
+                        self.save_state()
 
-        """
-        return not self.source_is_target and self.map_labels
+                # whether the early stopping criterion is met
+                if self.es.stop(epoch_acc):
+                    break
 
-    @property
-    def use_labels(self):
-        """Labels to be predicted.
+            else:
+                # if no early stopping is required, the model state is
+                # saved after each epoch
+                if self.save:
+                    self.save_state()
 
-        Returns
-        -------
-        use_labels : `dict` [`int`, `dict`]
-            The labels of the classes to be predicted.
+        return self.training_state
 
-        """
-        return (self.target_labels if self.apply_label_map else
-                self.source_labels)
+    def predict(self, dataloader):
+        """Model inference at training time.
 
-    @property
-    def bands(self):
-        """Spectral bands the model was trained with.
+        Parameters
+        ----------
+        dataloader : :py:class:`torch.utils.data.DataLoader`
+            The validation dataloader to evaluate the model predictions.
 
         Returns
         -------
-        bands : `list` [`str`]
-            A list of the named spectral bands used to train the model.
+        accuracy : `list` [`float`]
+            The mean model prediction accuracy on each mini-batch in the
+            validation set.
+        loss : `list` [`float`]
+            The model loss for each mini-batch in the validation set.
 
         """
-        return self.src_ds.use_bands
+        # set the model to evaluation mode
+        LOGGER.info('Setting model to evaluation mode ...')
+        self.model.eval()
+
+        # create arrays of the observed loss and accuracy
+        accuracy = []
+        loss = []
+
+        # iterate over the validation/test set
+        LOGGER.info('Calculating accuracy on the validation set ...')
+        for batch, (inputs, labels) in enumerate(dataloader):
+
+            # send the data to the gpu if available
+            inputs = inputs.to(self.device)
+            labels = labels.to(self.device)
+
+            # calculate network outputs
+            with torch.no_grad():
+                outputs = self.model(inputs)
+
+            # compute loss
+            cla_loss = self.cla_loss_function(outputs, labels.long())
+            loss.append(cla_loss.item())
+
+            # calculate predicted class labels
+            pred = F.softmax(outputs, dim=1).argmax(dim=1)
+
+            # calculate accuracy on current batch
+            acc = accuracy_function(pred, labels)
+            accuracy.append(acc)
+
+            # print progress
+            LOGGER.info('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}'
+                        .format(batch + 1, len(dataloader), acc))
+
+        # calculate overall accuracy on the validation/test set
+        LOGGER.info('Epoch: {:d}, Mean accuracy: {:.2f}%.'
+                    .format(self.model.epoch, np.mean(accuracy) * 100))
+
+        return accuracy, loss
+
+    def save_state(self):
+        """Save the model state."""
+        _ = self.model.save(self.state_file,
+                            self.optimizer,
+                            state=self.training_state,
+                            **self.params_to_save)
 
     @property
-    def compute_cm(self):
-        """Whether to compute the confusion matrix.
+    def training_state(self):
+        """Model training metrics.
 
         Returns
         -------
-        compute_cm : `bool`
-            Whether the confusion matrix can be computed. For datasets with
-            labels different from the source domain labels, the confusion
-            matrix can not be computed.
+        state : `dict` [`str`, :py:class:`numpy.ndarray`]
+            The training state dictionary. The keys describe the type of the
+            training metric and the values are :py:class:`numpy.ndarray`'s of
+            the corresponding metric observed during training with
+            shape=(mini_batch, epoch).
 
         """
-        return (False if not self.source_is_target and not self.map_labels
-                else self.cm)
+        # current training state
+        state = self.tracker.np_state(self.tmbatch, self.vmbatch)
+
+        # optional: training state of the model checkpoint
+        if self.checkpoint_state:
+            # prepend values from checkpoint to current training state
+            state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
+                     zip(self.checkpoint_state.items(), state.items())
+                     if k1 == k2}
+
+        return state
 
     @property
-    def plot(self):
-        """Whether to save plots of (input, ground truth, prediction).
+    def params_to_save(self):
+        """The parameters and variables to save in the model state file."""
+        return {'src_train_dl': self.src_train_dl,
+                'src_valid_dl': self.src_valid_dl,
+                'src_test_dl': self.src_test_dl}
+
+    def _build_model_repr_(self):
+        """Build the model representation.
 
         Returns
         -------
-        plot : `bool`
-            Save plots for each sample or for each scene of the target dataset,
-            depending on ``self.predict_scene``.
+        fs : `str`
+            Representation string.
 
         """
-        return self.plot_scenes if self.predict_scene else self.plot_samples
+        # model
+        fs = '\n    (model):' + '\n' + 8 * ' '
+        fs += ''.join(repr(self.model)).replace('\n', '\n' + 8 * ' ')
 
-    @property
-    def is_scene_subset(self):
-        """Check the type of the target dataset.
+        # optimizer
+        fs += '\n    (optimizer):' + '\n' + 8 * ' '
+        fs += ''.join(repr(self.optimizer)).replace('\n', '\n' + 8 * ' ')
 
-        Whether ``self.trg_ds`` is an instance of
-        :py:class:`pysegcnn.core.split.SceneSubset`, as required when
-        ``self.predict_scene=True``.
+        # loss function
+        fs += '\n    (loss function):' + '\n' + 8 * ' '
+        fs += ''.join(repr(self.cla_loss_function)).replace('\n',
+                                                            '\n' + 8 * ' ')
+
+        # early stopping
+        fs += '\n    (early stop):' + '\n' + 8 * ' '
+        fs += ''.join(repr(self.es)).replace('\n', '\n' + 8 * ' ')
+
+        return fs
+
+    def __repr__(self):
+        """Representation.
 
         Returns
         -------
-        is_scene_subset : `bool`
-            Whether ``self.trg_ds`` is an instance of
-            :py:class:`pysegcnn.core.split.SceneSubset`.
+        fs : `str`
+            Representation string.
 
         """
-        return isinstance(self.trg_ds, SceneSubset)
+        # representation string to print
+        fs = self.__class__.__name__ + '(\n'
 
-    @property
-    def dataloader(self):
-        """Dataloader instance for model inference.
+        # model configuration
+        fs += self._build_model_repr_()
 
-        Returns
-        -------
-        dataloader : :py:class:`torch.utils.data.DataLoader`
-            The dataset for model inference.
+        fs += '\n)'
+        return fs
 
-        """
-        # build the dataloader for model inference
-        return DataLoader(self.trg_ds, batch_size=self._batch_size,
-                          shuffle=False, drop_last=False)
 
-    @property
-    def _original_source_labels(self):
-        """Original source domain labels.
+@dataclasses.dataclass
+class MultispectralImageSegmentationTrainer(ClassificationNetworkTrainer):
+    """Model training class for multispectral image segmentation.
 
-        Since PyTorch requires class labels to be an ascending sequence
-        starting from 0, the actual class labels in the ground truth may differ
-        from the class labels fed to the model.
+    Train an instance of :py:class:`pysegcnn.core.models.EncoderDecoderNetwork`
+    on an instance of :py:class:`pysegcnn.core.dataset.ImageDataset`.
 
-        Returns
-        -------
-        original_source_labels : `dict` [`int`, `dict`]
-            The original class labels of the source domain.
+    Attributes
+    ----------
+    trg_train_dl : :py:class:`torch.utils.data.DataLoader`
+        The target domain training :py:class:`torch.utils.data.DataLoader`
+        instance built from an instance of
+        :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty
+        :py:class:`torch.utils.data.DataLoader`.
+    trg_valid_dl : :py:class:`torch.utils.data.DataLoader`
+        The target domain validation :py:class:`torch.utils.data.DataLoader`
+        instance built from an instance of
+        :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty
+        :py:class:`torch.utils.data.DataLoader`.
+    trg_test_dl : :py:class:`torch.utils.data.DataLoader`
+        The target domain test :py:class:`torch.utils.data.DataLoader`
+        instance built from an instance of
+        :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty
+        :py:class:`torch.utils.data.DataLoader`.
+    uda_loss_function : :py:class:`torch.nn.Module`
+        The domain adaptation loss function. An instance of
+        :py:class:`torch.nn.Module`.
+        The default is :py:class:`pysegcnn.core.uda.CoralLoss`.
+    uda_lambda : `float`
+        The weight of the domain adaptation, trading off adaptation with
+        classification accuracy on the source domain. The default is `0`.
+    uda_pos : `str`
+        The layer where to compute the domain adaptation loss. The default
+        is `'enc'`, i.e. compute the domain adaptation loss using the output of
+        the model encoder.
+    uda : `bool`
+        Whether to train using deep domain adaptation.
 
-        """
-        return self.src_ds._labels
+    """
 
-    @property
-    def _original_target_labels(self):
-        """Original target domain labels.
+    trg_train_dl: DataLoader = DataLoader(None)
+    trg_valid_dl: DataLoader = DataLoader(None)
+    trg_test_dl: DataLoader = DataLoader(None)
+    uda_loss_function: nn.Module = CoralLoss(uda_lambda=0)
+    uda_lambda: float = 0
+    uda_pos: str = 'enc'
 
-        Returns
-        -------
-        original_target_labels : `dict` [`int`, `dict`]
-            The original class labels of the target domain.
+    def __post_init__(self):
+        """Check the type of each argument.
 
-        """
-        return self.trg_ds.dataset._labels
+        Configure the device to train the model on, i.e. train on the gpu if
+        available.
 
-    @property
-    def _label_log(self):
-        """Log if a label mapping is applied.
+        Configure early stopping if required.
 
-        Returns
-        -------
-        log : `str`
-            Represenation of the label mapping.
+        Initialize training metric tracking.
 
         """
-        log = 'Retaining model labels ({})'.format(
-            ', '.join([v['label'] for _, v in self.source_labels.items()]))
-        if self.apply_label_map:
-            log = (log.replace('Retaining', 'Mapping') + ' to target labels '
-                   '({}).'.format(', '.join([v['label'] for _, v in
-                                  self.target_labels.items()])))
-        return log
-
-    @property
-    def _batch_size(self):
-        """Batch size of the inference dataloader.
+        super().__post_init__()
 
-        Returns
-        -------
-        batch_size : `int`
-            The batch size of the dataloader used for model inference. Depends
-            on whether to predict each sample of the target dataset
-            individually or whether to reconstruct each scene in the target
-            dataset.
+        # whether to train using supervised transfer learning or
+        # deep domain adaptation
 
-        """
-        return (self.trg_ds.dataset.tiles if self.predict_scene and
-                self.is_scene_subset else 1)
+        # dummy variables for easy model evaluation
+        self.uda = False
+        if self.trg_train_dl.dataset is not None and self.uda_lambda > 0:
 
-    def _check_long_filename(self, filename):
-        """Modify filenames that exceed Windows' maximum filename length.
+            # set the device for computing domain adaptation loss
+            self.uda_loss_function.device = self.device
 
-        Parameters
-        ----------
-        filename : `str`
-            The filename to check.
+            # adjust metrics and initialize metric tracker
+            self.tracker.train_metrics.extend(['cla_loss', 'uda_loss'])
 
-        Returns
-        -------
-        filename : `str`
-            The modified filename, in case ``filename`` exceeds 255 characters.
+            # train using deep domain adaptation
+            self.uda = True
 
-        """
-        # check for maximum path component length: Windows allows a maximum
-        # of 255 characters for each component in a path
-        if len(filename) >= 255:
-            # try to parse the scene identifier
-            scene_metadata = self.trg_ds.dataset.parse_scene_id(filename)
-            if scene_metadata is not None:
-                # new, shorter name for the current batch
-                batch_name = '_'.join([scene_metadata['satellite'],
-                                       scene_metadata['tile'],
-                                       datetime.datetime.strftime(
-                                           scene_metadata['date'], '%Y%m%d')])
-                filename = filename.replace(scene_metadata['id'], batch_name)
+    def _inp_uda(self, src_input, trg_input):
+        """Domain adaptation at input feature level."""
 
-        return filename
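+        # note: at the 'inp' position the adaptation loss is computed directly
+        # on the raw input bands, i.e. no model features are extracted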
+        # perform forward pass: classified source domain features
+        src_prdctn = self.model(src_input)
 
-    def map_to_target(self, prd):
-        """Map source domain labels to target domain labels.
+        return src_input, trg_input, src_prdctn
 
-        Parameters
-        ----------
-        prd : :py:class:`torch.Tensor`
-            The source domain class labels as predicted by ```self.model``.
+    def _enc_uda(self, src_input, trg_input):
+        """Domain adaptation at encoder feature level."""
 
-        Returns
-        -------
-        prd : :py:class:`torch.Tensor`
-            The predicted target domain labels.
+        # perform forward pass: encoded source domain features
+        src_feature = self.model.encoder(src_input)
+        src_dec_feature = self.model.decoder(src_feature,
+                                             self.model.encoder.cache)
+        # model logits on source domain
+        src_prdctn = self.model.classifier(src_dec_feature)
+        del self.model.encoder.cache  # clear intermediate encoder outputs
 
-        """
-        # map actual source labels to original source labels
-        for aid, oid in zip(self.source_labels.keys(),
-                            self._original_source_labels.keys()):
-            prd[torch.where(prd == aid)] = oid
+        # perform forward pass: encoded target domain features
+        trg_feature = self.model.encoder(trg_input)
 
-        # apply the label mapping
-        for src_label, trg_label in self.label_map.items():
-            prd[torch.where(prd == src_label)] = trg_label
+        return src_feature, trg_feature, src_prdctn
 
-        # map original target labels to actual target labels
-        for oid, aid in zip(self._original_target_labels.keys(),
-                            self.target_labels.keys()):
-            prd[torch.where(prd == oid)] = aid
+    def _dec_uda(self, src_input, trg_input):
+        """Domain adaptation at decoder feature level."""
 
-        return prd
+        # perform forward pass: decoded source domain features
+        src_feature = self.model.encoder(src_input)
+        src_feature = self.model.decoder(src_feature,
+                                         self.model.encoder.cache)
+        # model logits on source domain
+        src_prdctn = self.model.classifier(src_feature)
+        del self.model.encoder.cache  # clear intermediate encoder outputs
 
-    def predict(self):
-        """Classify the samples of the target dataset.
+        # perform forward pass: decoded target domain features
+        trg_feature = self.model.encoder(trg_input)
+        trg_feature = self.model.decoder(trg_feature,
+                                         self.model.encoder.cache)
+        del self.model.encoder.cache
 
-        Returns
-        -------
-        output : `dict` [`str`, `dict`]
-            The inference output dictionary. The keys are either the number of
-            the samples (``self.predict_scene=False``) or the name of the
-            scenes of the target dataset (``self.predict_scene=True``). The
-            values are dictionaries with keys:
-                ``'input'``
-                    Model input data of the sample (:py:class:`numpy.ndarray`).
-                ``'labels'
-                    Ground truth class labels (:py:class:`numpy.ndarray`).
-                ``'prediction'``
-                    Model prediction class labels (:py:class:`numpy.ndarray`).
+        return src_feature, trg_feature, src_prdctn
 
-        """
-        # iterate over the samples of the target dataset
-        output = {}
-        for batch, (inputs, labels) in enumerate(self.dataloader):
+    def _cla_uda(self, src_input, trg_input):
+        """Domain adaptation at classifier feature level."""
 
-            # send inputs and labels to device
-            inputs = inputs.to(self.device)
-            labels = labels.to(self.device)
+        # perform forward pass: classified source domain features
+        src_feature = self.model(src_input)
 
-            # compute model predictions
-            with torch.no_grad():
-                prdctn = F.softmax(self.model(inputs),
-                                   dim=1).argmax(dim=1).squeeze()
+        # perform forward pass: target domain features
+        trg_feature = self.model(trg_input)
 
-            # map source labels to target dataset labels
-            if self.apply_label_map:
-                prdctn = self.map_to_target(prdctn)
+        return src_feature, trg_feature, src_feature
 
-            # update confusion matrix
-            if self.compute_cm:
-                for ytrue, ypred in zip(labels.view(-1), prdctn.view(-1)):
-                    self.conf_mat[ytrue.long(), ypred.long()] += 1
+    def uda_frwd(self, src_input, trg_input):
+        """Forward function for deep domain adaptation.
 
-            # convert torch tensors to numpy arrays
-            inputs = inputs.numpy()
-            labels = labels.numpy()
-            prdctn = prdctn.numpy()
+        Parameters
+        ----------
+        src_input : :py:class:`torch.Tensor`
+            Source domain input features.
+        trg_input : :py:class:`torch.Tensor`
+            Target domain input features.
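+
+        Returns
+        -------
+        `tuple`
+            The source and target domain features from which the domain
+            adaptation loss is computed, and the model logits on the source
+            domain.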
 
-            # progress string to log
-            progress = 'Sample: {:d}/{:d}'.format(batch + 1,
-                                                  len(self.dataloader))
+        """
+        if self.uda_pos == 'inp':
+            return self._inp_uda(src_input, trg_input)
 
-            # check whether to reconstruct the scene
-            date = None
-            if self.dataloader.batch_size > 1:
+        if self.uda_pos == 'enc':
+            return self._enc_uda(src_input, trg_input)
 
-                # id and date of the current scene
-                batch = self.trg_ds.ids[batch]
-                if self.trg_ds.split_mode == 'date':
-                    date = self.trg_ds.dataset.parse_scene_id(batch)['date']
+        if self.uda_pos == 'dec':
+            return self._dec_uda(src_input, trg_input)
 
-                # modify the progress string
-                progress = progress.replace('Sample', 'Scene')
-                progress += ' Id: {}'.format(batch)
+        if self.uda_pos == 'cla':
+            return self._cla_uda(src_input, trg_input)
 
-                # reconstruct the entire scene
-                inputs = reconstruct_scene(inputs)
-                labels = reconstruct_scene(labels)
-                prdctn = reconstruct_scene(prdctn)
+    def train_domain_adaptation(self, epoch):
+        """Train a model for an epoch on the source and target domain.
 
-            # save current batch to output dictionary
-            output[batch] = {'input': inputs, 'labels': labels,
-                             'prediction': prdctn}
-
-            # filename for the plot of the current batch
-            batch_name = self.basename + '_{}_{}.pt'.format(self.trg_ds.name,
-                                                            batch)
+        This function implements deep domain adaptation by extending the
+        standard classification loss with a "domain adaptation loss" calculated
+        from unlabelled target domain samples.
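+
+        The mini-batch objective is ``cla_loss + uda_lambda * uda_loss``,
+        where ``uda_lambda`` is scaled linearly with the training progress.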
 
-            # check if the current batch name exceeds the Windows limit of
-            # 255 characters
-            batch_name = self._check_long_filename(batch_name)
+        Parameters
+        ----------
+        epoch : `int`
+            The current epoch.
 
-            # in case the source and target class labels are the same or a
-            # label mapping is applied, the accuracy of the prediction can be
-            # calculated
-            if self.source_is_target or self.apply_label_map:
-                progress += ', Accuracy: {:.2f}'.format(
-                    accuracy_function(prdctn, labels))
-            LOGGER.info(progress)
+        """
+        # create target domain iterator
+        target = iter(self.trg_train_dl)
 
-            # plot current scene
-            if self.plot:
+        # increase domain adaptation weight with increasing epochs
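+        # (linear schedule: the first epoch uses uda_lambda / epochs, the
+        # final epoch uses the full uda_lambda)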
+        uda_lambda = self.uda_lambda * ((epoch + 1) / self.epochs)
 
-                # in case the source and target class labels are the same or a
-                # label mapping is applied, plot the ground truth, otherwise
-                # just plot the model prediction
-                gt = None
-                if self.source_is_target or self.apply_label_map:
-                    gt = labels
+        # iterate over the mini-batches of the source training dataset
+        for batch, (src_input, src_label) in enumerate(self.src_train_dl):
 
-                # plot inputs, ground truth and model predictions
-                plot_sample(inputs.clip(0, 1),
-                            self.bands,
-                            self.use_labels,
-                            y=gt,
-                            y_pred={self.model.__class__.__name__: prdctn},
-                            accuracy=True,
-                            state=batch_name,
-                            date=date,
-                            fig=self.fig,
-                            plot_path=self.scenes_path,
-                            **self.kwargs)
+            # get the target domain input data
+            try:
+                trg_input, _ = next(target)
+            # in case the iterator is exhausted, re-instantiate it
+            except StopIteration:
+                target = iter(self.trg_train_dl)
+                trg_input, _ = next(target)
 
-                # save current figure state as frame for animation
-                if self.animate:
-                    self.anim.frame(self.fig.axes)
+            # send the data to the gpu if available
+            src_input, src_label = (src_input.to(self.device),
+                                    src_label.to(self.device))
+            trg_input = trg_input.to(self.device)
 
-        # save animation
-        if self.animate:
-            self.anim.animate(self.fig, interval=1000, repeat=True, blit=True)
-            self.anim.save(self.basename + '_{}.gif'.format(self.trg_ds.name),
-                           dpi=200)
+            # reset the gradients
+            self.optimizer.zero_grad()
 
-        return output
+            # forward pass
+            src_feature, trg_feature, src_prdctn = self.uda_frwd(
+                src_input, trg_input)
 
-    def evaluate(self):
-        """Evaluate a pretrained model on a defined dataset.
+            # compute classification loss
+            cla_loss = self.cla_loss_function(src_prdctn, src_label.long())
 
-        Returns
-        -------
-        output : `dict` [`str`, `dict`]
-            The inference output dictionary. The keys are either the number of
-            the samples (``self.predict_scene=False``) or the name of the
-            scenes of the target dataset (``self.predict_scene=True``). The
-            values are dictionaries with keys:
-                ``'input'``
-                    Model input data of the sample (:py:class:`numpy.ndarray`).
-                ``'labels'
-                    Ground truth class labels (:py:class:`numpy.ndarray`).
-                ``'prediction'``
-                    Model prediction class labels (:py:class:`numpy.ndarray`).
+            # compute domain adaptation loss:
+            # the discrepancy between the source and target domain features
+            # extracted at the layer defined by uda_pos
+            uda_loss = self.uda_loss_function(src_feature, trg_feature)
 
-        """
-        # plot loss and accuracy
-        plot_loss(check_filename_length(self.state_file),
-                  outpath=self.perfmc_path)
+            # total loss
+            tot_loss = cla_loss + uda_lambda * uda_loss
 
-        # set the model to evaluation mode
-        LOGGER.info('Setting model to evaluation mode ...')
-        self.model.eval()
-        self.model.to(self.device)
+            # compute the gradients of the loss function w.r.t.
+            # the network weights
+            tot_loss.backward()
 
-        # initialize confusion matrix
-        self.conf_mat = np.zeros(shape=2 * (len(self.use_labels), ))
+            # update the weights
+            self.optimizer.step()
 
-        # log which labels the model predicts
-        LOGGER.info(self._label_log)
+            # calculate predicted class labels
+            ypred = F.softmax(src_prdctn, dim=1).argmax(dim=1)
 
-        # evaluate the model on the target dataset
-        output = self.predict()
+            # calculate accuracy on current batch
+            acc = accuracy_function(ypred, src_label)
 
-        # whether to plot the confusion matrix
-        if self.compute_cm:
-            plot_confusion_matrix(self.conf_mat, self.use_labels,
-                                  state_file=self.state_file,
-                                  subset=self.domain + '_' + self.trg_ds.name,
-                                  outpath=self.perfmc_path)
+            # print progress
+            LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
+                        'Cla_loss: {:.2f}, Uda_loss: {:.2f}, '
+                        'Tot_loss: {:.2f}, Acc: {:.2f}'
+                        .format(epoch + 1, self.epochs, batch + 1,
+                                self.tmbatch, cla_loss.item(),
+                                uda_loss.item(), tot_loss.item(), acc))
 
-        return output
+            # update training metrics
+            self.tracker.batch_update(self.tracker.train_metrics,
+                                      [tot_loss.item(), acc,
+                                       cla_loss.item(), uda_loss.item()])
 
+    def train_epoch(self, epoch):
+        """Wrap the function to train a model for a single epoch.
 
-@dataclasses.dataclass
-class LogConfig(BaseConfig):
-    """Logging configuration class.
+        Depends on whether to apply deep domain adaptation.
 
-    Generate the model log file.
+        Parameters
+        ----------
+        epoch : `int`
+            The current epoch.
 
-    Attributes
-    ----------
-    state_file : :py:class:`pathlib.Path`
-        Path to a model state file.
-    log_path : :py:class:`pathlib.Path`
-        Path to store model logs.
-    log_file : :py:class:`pathlib.Path`
-        Path to the log file of the model ``state_file``.
-    """
 
-    state_file: pathlib.Path
+        """
+        if self.uda:
+            self.train_domain_adaptation(epoch)
+        else:
+            self.train_source_domain(epoch)
 
-    def __post_init__(self):
-        """Check the type of each argument.
+    @property
+    def params_to_save(self):
+        """The parameters and variables to save in the model state file."""
+        return {'src_train_dl': self.src_train_dl,
+                'src_valid_dl': self.src_valid_dl,
+                'src_test_dl': self.src_test_dl,
+                'trg_train_dl': self.trg_train_dl,
+                'trg_valid_dl': self.trg_valid_dl,
+                'trg_test_dl': self.trg_test_dl,
+                'uda': self.uda,
+                'uda_pos': self.uda_pos,
+                'uda_lambda': self.uda_lambda}
+
+    def _build_ds_repr(self, train_dl, valid_dl, test_dl):
+        """Build the dataset representation.
 
-        Generate model log file.
+        Returns
+        -------
+        fs : `str`
+            Representation string.
 
         """
-        super().__post_init__()
-
-        # the path to store model logs
-        self.log_path = pathlib.Path(HERE).joinpath('_logs')
+        # dataset configuration
+        fs = '    (dataset):\n        '
+        fs += ''.join(repr(train_dl.dataset.dataset)).replace('\n',
+                                                              '\n' + 8 * ' ')
+        fs += '\n    (batch):\n        '
+        fs += '- batch size: {}\n        '.format(train_dl.batch_size)
+        fs += '- mini-batch shape (b, c, h, w): {}'.format(
+            ((train_dl.batch_size, len(train_dl.dataset.dataset.use_bands),) +
+              2 * (train_dl.dataset.dataset.tile_size,)))
+
+        # dataset split
+        fs += '\n    (split):'
+        for dl in [train_dl, valid_dl, test_dl]:
+            if dl.dataset is not None:
+                fs += '\n' + 8 * ' ' + repr(dl.dataset)
 
-        # the log file of the current model
-        self.log_file = check_filename_length(self.log_path.joinpath(
-            self.state_file.name.replace('.pt', '.log')))
+        return fs
 
-    @staticmethod
-    def now():
-        """Return the current date and time.
+    def __repr__(self):
+        """Representation.
 
         Returns
         -------
-        date : :py:class:`datetime.datetime`
-            The current date and time.
+        fs : `str`
+            Representation string.
 
         """
-        return datetime.datetime.strftime(datetime.datetime.now(),
-                                          '%Y-%m-%dT%H:%M:%S')
+        # representation string to print
+        fs = self.__class__.__name__ + '(\n'
 
-    @staticmethod
-    def init_log(init_str):
-        """Generate a string to identify a new model run.
+        # source domain
+        fs += '    (source domain)\n    '
+        fs += self._build_ds_repr(
+            self.src_train_dl, self.src_valid_dl, self.src_test_dl).replace(
+                '\n', '\n' + 4 * ' ')
 
-        Parameters
-        ----------
-        init_str : `str`
-            The string to write to the model log file.
+        # target domain
+        if self.uda:
+            fs += '\n    (target domain)\n    '
+            fs += self._build_ds_repr(
+                self.trg_train_dl,
+                self.trg_valid_dl,
+                self.trg_test_dl).replace('\n', '\n' + 4 * ' ')
 
-        """
-        LOGGER.info(80 * '-')
-        LOGGER.info(init_str.format(LogConfig.now()))
-        LOGGER.info(80 * '-')
+        # model configuration
+        fs += self._build_model_repr_()
+
+        # domain adaptation
+        if self.uda:
+            fs += '\n    (adaptation)' + '\n' + 8 * ' '
+            fs += repr(self.uda_loss_function).replace('\n', '\n' + 8 * ' ')
+
+        fs += '\n)'
+        return fs
 
 
 @dataclasses.dataclass
@@ -1950,1017 +1914,1157 @@ class MetricTracker(BaseConfig):
                    for k in self.valid_metrics}}
 
 
-@dataclasses.dataclass
-class NetworkTrainer(BaseConfig):
-    """Model training class.
+class EarlyStopping(object):
+    """`Early Stopping`_ algorithm.
+
+    This implementation of the early stopping algorithm advances a counter
+    each time a metric does not improve over a training epoch. If the metric
+    does not improve over more than ``patience`` epochs, the early stopping
+    criterion is met.
 
-    Train an instance of :py:class:`pysegcnn.core.models.Network` on a dataset
-    of type :py:class:`pysegcnn.core.dataset.ImageDataset`.
+    See the :py:meth:`pysegcnn.core.trainer.NetworkTrainer.train` method for an
+    example implementation.
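+
+    A minimal usage sketch (illustrative values, monitoring accuracy):
+
+    >>> es = EarlyStopping(mode='max', best=0, min_delta=0, patience=2)
+    >>> for accuracy in [0.5, 0.6, 0.55, 0.58]:
+    ...     if es.stop(accuracy):
+    ...         break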
 
-    Supports training a model on a single source domain only and on a source
-    and target domain using deep domain adaptation.
+    .. _Early Stopping:
+        https://en.wikipedia.org/wiki/Early_stopping
 
     Attributes
     ----------
-    model : :py:class:`pysegcnn.core.models.Network`
-        The model to train. An instance of
-        :py:class:`pysegcnn.core.models.Network`.
-    optimizer : :py:class:`torch.optim.Optimizer`
-        The optimizer to update the model weights. An instance of
-        :py:class:`torch.optim.Optimizer`.
-    state_file : :py:class:`pathlib.Path`
-        Path to save the model state.
-    bands : `list` [`str`]
-        The spectral bands used to train ``model``.
-    src_train_dl : :py:class:`torch.utils.data.DataLoader`
-        The source domain training :py:class:`torch.utils.data.DataLoader`
-        instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`.
-    src_valid_dl : :py:class:`torch.utils.data.DataLoader`
-        The source domain validation :py:class:`torch.utils.data.DataLoader`
-        instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`.
-    src_test_dl : :py:class:`torch.utils.data.DataLoader`
-        The source domain test :py:class:`torch.utils.data.DataLoader`
-        instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`.
-    cla_loss_function : :py:class:`torch.nn.Module`
-        The classification loss function to compute the model error. An
-        instance of :py:class:`torch.nn.Module`.
-    trg_train_dl : :py:class:`torch.utils.data.DataLoader`
-        The target domain training :py:class:`torch.utils.data.DataLoader`
-        instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty
-        :py:class:`torch.utils.data.DataLoader`.
-    trg_valid_dl : :py:class:`torch.utils.data.DataLoader`
-        The target domain validation :py:class:`torch.utils.data.DataLoader`
-        instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty
-        :py:class:`torch.utils.data.DataLoader`.
-    trg_test_dl : :py:class:`torch.utils.data.DataLoader`
-        The target domain test :py:class:`torch.utils.data.DataLoader`
-        instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty
-        :py:class:`torch.utils.data.DataLoader`.
-    uda_loss_function : :py:class:`torch.nn.Module`
-        The domain adaptation loss function. An instance of
-        :py:class:`torch.nn.Module`.
-        The default is :py:class:`pysegcnn.core.uda.CoralLoss`.
-    uda_lambda : `float`
-        The weight of the domain adaptation, trading off adaptation with
-        classification accuracy on the source domain. The default is `0`.
-    uda_pos : `str`
-        The layer where to compute the domain adaptation loss. The default
-        is `'enc'`, i.e. compute the domain adaptation loss using the output of
-        the model encoder.
-    epochs : `int`
-        The maximum number of epochs to train. The default is `1`.
-    nthreads : `int`
-        The number of cpu threads to use during training. The default is
-        :py:func:`torch.get_num_threads()`.
-    early_stop : `bool`
-        Whether to apply `Early Stopping`_. The default is `False`.
     mode : `str`
-        The early stopping mode. Depends on the metric measuring
-        performance. When using model loss as metric, use ``mode='min'``,
-        however, when using accuracy as metric, use ``mode='max'``. For now,
-        only ``mode='max'`` is supported. Only used if ``early_stop=True``.
-        The default is `'max'`.
-    delta : `float`
+        The early stopping mode.
+    best : `float`
+        Best metric score.
+    min_delta : `float`
         Minimum change in early stopping metric to be considered as an
-        improvement. Only used if ``early_stop=True``. The default is `0`.
+        improvement.
     patience : `int`
-        The number of epochs to wait for an improvement in the early stopping
-        metric. If the model does not improve over more than ``patience``
-        epochs, quit training. Only used if ``early_stop=True``. The default is
-        `10`.
-    checkpoint_state : `dict` [`str`, :py:class:`numpy.ndarray`]
-        A model checkpoint for ``model``. If specified, ``checkpoint_state``
-        should be a dictionary with keys describing the training metric.
-        The default is `{}`.
-    save : `bool`
-        Whether to save the model state to ``state_file``. The default is
-        `True`.
-    device : `str`
-        The device to train the model on, i.e. `cpu` or `cuda`.
-    tracker : :py:class:`pysegcnn.core.trainer.MetricTracker`
-        A :py:class:`pysegcnn.core.trainer.MetricTracker` instance tracking
-        training metrics, i.e. loss and accuracy.
-    uda : `bool`
-        Whether to apply deep domain adaptation.
-    max_accuracy : `float`
-        Maximum accuracy of ``model`` on the validation dataset.
-    es : `None` or :py:class:`pysegcnn.core.trainer.EarlyStopping`
-        The early stopping instance if ``early_stop=True``, else `None`.
-    tmbatch : `int`
-        Number of mini-batches in the training dataset.
-    vmbatch : `int`
-        Number of mini-batches in the validation dataset.
-    training_state : `dict` [`str`, :py:class:`numpy.ndarray`]
-        The training state dictionary. The keys describe the type of the
-        training metric.
-
-    .. _Early Stopping:
-        https://en.wikipedia.org/wiki/Early_stopping
+        The number of epochs to wait for an improvement.
+    is_better : `function`
+        Function indicating whether the metric improved.
+    early_stop : `bool`
+        Whether the early stopping criterion is met.
+    counter : `int`
+        The counter advancing each time a metric does not improve.
 
     """
 
-    model: Network
-    optimizer: Optimizer
-    state_file: pathlib.Path
-    bands: list
-    src_train_dl: DataLoader
-    src_valid_dl: DataLoader
-    src_test_dl: DataLoader
-    cla_loss_function: nn.Module
-    trg_train_dl: DataLoader = DataLoader(None)
-    trg_valid_dl: DataLoader = DataLoader(None)
-    trg_test_dl: DataLoader = DataLoader(None)
-    uda_loss_function: nn.Module = CoralLoss(uda_lambda=0)
-    uda_lambda: float = 0
-    uda_pos: str = 'enc'
-    epochs: int = 1
-    nthreads: int = torch.get_num_threads()
-    early_stop: bool = False
-    mode: str = 'max'
-    delta: float = 0
-    patience: int = 10
-    checkpoint_state: dict = dataclasses.field(default_factory={})
-    save: bool = True
+    def __init__(self, mode='max', best=0, min_delta=0, patience=10):
+        """Initialize.
+
+        Parameters
+        ----------
+        mode : `str`, optional
+            The early stopping mode. Depends on the metric measuring
+            performance. When using model loss as metric, use ``mode='min'``,
+            however, when using accuracy as metric, use ``mode='max'``. The
+            default is `'max'`.
+        best : `float`, optional
+            Threshold indicating the best metric score. At instantiation, set
+            ``best`` to the worst possible score of the metric. ``best`` will
+            be overwritten during training. The default is `0`.
+        min_delta : `float`, optional
+            Minimum change in early stopping metric to be considered as an
+            improvement. The default is `0`.
+        patience : `int`, optional
+            The number of epochs to wait for an improvement in the early
+            stopping metric. If the model does not improve over more than
+            ``patience`` epochs, quit training. The default is `10`.
+
+        Raises
+        ------
+        ValueError
+            Raised if ``mode`` is not either 'min' or 'max'.
+
+        """
+        # check if mode is correctly specified
+        if mode not in ['min', 'max']:
+            raise ValueError('Mode "{}" not supported. '
+                             'Mode is either "min" (check whether the metric '
+                             'decreased, e.g. loss) or "max" (check whether '
+                             'the metric increased, e.g. accuracy).'
+                             .format(mode))
+
+        # mode to determine if metric improved
+        self.mode = mode
+
+        # whether to check for an increase or a decrease in a given metric
+        self.is_better = self.decreased if mode == 'min' else self.increased
+
+        # minimum change in metric to be considered as an improvement
+        self.min_delta = min_delta
+
+        # number of epochs to wait for improvement
+        self.patience = patience
+
+        # initialize best metric
+        self.best = best
+
+        # initialize early stopping flag
+        self.early_stop = False
+
+        # initialize the early stop counter
+        self.counter = 0
+
+    def stop(self, metric):
+        """Advance early stopping counter.
+
+        Parameters
+        ----------
+        metric : `float`
+            The current metric score.
+
+        Returns
+        -------
+        early_stop : `bool`
+            Whether the early stopping criterion is met.
+
+        """
+        # if the metric improved, reset the epochs counter, else, advance
+        if self.is_better(metric, self.best, self.min_delta):
+            self.counter = 0
+            self.best = metric
+        else:
+            self.counter += 1
+            LOGGER.info('Early stopping counter: {}/{}'.format(
+                self.counter, self.patience))
+
+        # if the metric did not improve over the last patience epochs,
+        # the early stopping criterion is met
+        if self.counter >= self.patience:
+            LOGGER.info('Early stopping criterion met, stopping training.')
+            self.early_stop = True
+
+        return self.early_stop
+
+    def decreased(self, metric, best, min_delta):
+        """Whether a metric decreased with respect to a best score.
+
+        Measure improvement for metrics that are considered as 'better' when
+        they decrease, e.g. model loss, mean squared error, etc.
+
+        Parameters
+        ----------
+        metric : `float`
+            The current score.
+        best : `float`
+            The current best score.
+        min_delta : `float`
+            Minimum change to be considered as an improvement.
+
+        Returns
+        -------
+        `bool`
+            Whether the metric improved.
+
+        """
+        return metric < best - min_delta
+
+    def increased(self, metric, best, min_delta):
+        """Whether a metric increased with respect to a best score.
+
+        Measure improvement for metrics that are considered as 'better' when
+        they increase, e.g. accuracy, precision, recall, etc.
+
+        Parameters
+        ----------
+        metric : `float`
+            The current score.
+        best : `float`
+            The current best score.
+        min_delta : `float`
+            Minimum change to be considered as an improvement.
+
+        Returns
+        -------
+        `bool`
+            Whether the metric improved.
+
+        """
+        return metric > best + min_delta
+
+    def __repr__(self):
+        """Representation.
+
+        Returns
+        -------
+        fs : `str`
+            Representation string.
+
+        """
+        fs = self.__class__.__name__
+        fs += '(mode={}, best={:.2f}, delta={}, patience={})'.format(
+            self.mode, self.best, self.min_delta, self.patience)
+
+        return fs
+
+
+@dataclasses.dataclass
+class NetworkInference(BaseConfig):
+    """Model inference configuration.
+
+    Evaluate a model.
+
+    Attributes
+    ----------
+    state_file : :py:class:`pathlib.Path`
+        Path to the model to evaluate.
+    implicit : `bool`
+        Whether to evaluate the model on the datasets defined at training time.
+    domain : `str`
+        Whether to evaluate on the source domain (``domain='src'``), i.e. the
+        domain the model in ``state_file`` was trained on, or a target domain
+        (``domain='trg'``).
+    test : `bool` or `None`
+        Whether to evaluate the model on the training (``test=None``), the
+        validation (``test=False``) or the test set (``test=True``).
+    map_labels : `bool`
+        Whether to map the model labels from the model source domain to the
+        defined ``domain`` in case the domain class labels differ.
+    ds : `dict`
+        The dataset configuration dictionary passed to
+        :py:class:`pysegcnn.core.trainer.DatasetConfig` when evaluating on
+        an explicitly defined dataset, i.e. ``implicit=False``.
+    ds_split : `dict`
+        The dataset split configuration dictionary passed to
+        :py:class:`pysegcnn.core.trainer.SplitConfig` when evaluating on
+        an explicitly defined dataset, i.e. ``implicit=False``.
+    predict_scene : `bool`
+        The model prediction order. If False, the samples (tiles) of a dataset
+        are predicted in any order and the scenes are not reconstructed.
+        If True, the samples (tiles) are ordered according to the scene they
+        belong to and a model prediction for each entire reconstructed scene is
+        returned. The default is `False`.
+    plot_samples : `bool`
+        Whether to save a plot of false color composite, ground truth and model
+        prediction for each sample (tile). Only used if
+        ``predict_scene=False``. The default is `False`.
+    plot_scenes : `bool`
+        Whether to save a plot of false color composite, ground truth and model
+        prediction for each entire scene. Only used if ``predict_scene=True``.
+        The default is `False`.
+    plot_bands : `list` [`str`]
+        The bands to build the false color composite. The default is
+        `['nir', 'red', 'green']`.
+    cm : `bool`
+        Whether to compute and plot the confusion matrix. The default is
+        `True`.
+    figsize : `tuple`
+        The figure size in centimeters. The default is `(10, 10)`.
+    alpha : `int`
+        The level of the percentiles for contrast stretching of the false color
+        composite. The default is `0`, i.e. no stretching.
+    animate : `bool`
+        Whether to create an animation of (input, ground truth, prediction) for
+        the scenes in the train/validation/test dataset. Only works if
+        ``predict_scene=True`` and ``plot_scenes=True``.
+    device : `str`
+        The device to evaluate the model on, i.e. `cpu` or `cuda`.
+    base_path : :py:class:`pathlib.Path`
+        Root path to store model output.
+    sample_path : :py:class:`pathlib.Path`
+        Path to store plots of model predictions for single samples.
+    scenes_path : :py:class:`pathlib.Path`
+        Path to store plots of model predictions for entire scenes.
+    perfmc_path : :py:class:`pathlib.Path`
+        Path to store plots of model performance, e.g. confusion matrix.
+    animtn_path : :py:class:`pathlib.Path`
+        Path to store animations.
+    models_path : :py:class:`pathlib.Path`
+        Path to search for model state files, i.e. pretrained models.
+    kwargs : `dict`
+        Keyword arguments for :py:func:`pysegcnn.core.graphics.plot_sample`
+    basename : `str`
+        Base filename for each plot.
+    model : :py:class:`pysegcnn.core.models.Network`
+        The model to use for inference.
+    model_state : `dict`
+        A dictionary containing the model and optimizer state, as
+        constructed by :py:meth:`~pysegcnn.core.models.Network.save`.
+    trg_ds : :py:class:`pysegcnn.core.split.CustomSubset`
+        The dataset to evaluate ``model`` on.
+    src_ds : :py:class:`pysegcnn.core.split.CustomSubset`
+        The model source domain training dataset.
+    fig : :py:class:`matplotlib.figure.Figure`
+        A :py:class:`matplotlib.figure.Figure` instance to iteratively plot to.
+    anim : :py:class:`pysegcnn.core.graphics.Animate`
+        An instance of :py:class:`pysegcnn.core.graphics.Animate` used to
+        create animations if ``animate=True``.
+    conf_mat : :py:class:`numpy.ndarray`
+        The model confusion matrix.
+
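+    A hypothetical configuration sketch (the state file name is illustrative
+    and must refer to an existing model state in the ``_models`` directory):
+
+    >>> inference = NetworkInference(
+    ...     state_file=pathlib.Path('model.pt'), implicit=True, domain='src',
+    ...     test=True, map_labels=False)
+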
+    """
+
+    state_file: pathlib.Path
+    implicit: bool
+    domain: str
+    test: object
+    map_labels: bool
+    ds: dict = dataclasses.field(default_factory=dict)
+    ds_split: dict = dataclasses.field(default_factory=dict)
+    predict_scene: bool = False
+    plot_samples: bool = False
+    plot_scenes: bool = False
+    plot_bands: list = dataclasses.field(
+        default_factory=lambda: ['nir', 'red', 'green'])
+    cm: bool = True
+    figsize: tuple = (10, 10)
+    alpha: int = 5
+    animate: bool = False
 
     def __post_init__(self):
         """Check the type of each argument.
 
-        Configure the device to train the model on, i.e. train on the gpu if
-        available.
-
-        Configure early stopping if required.
+        Configure figure output paths.
 
-        Initialize training metric tracking.
+        Raises
+        ------
+        TypeError
+            Raised if ``test`` is not of type `bool` or `None`.
+        ValueError
+            Raised if ``domain`` is not 'src' or 'trg'.
 
         """
         super().__post_init__()
 
-        # the device to train the model on
-        self.device = torch.device('cuda:0' if torch.cuda.is_available() else
-                                   'cpu')
-        # set the number of threads
-        torch.set_num_threads(self.nthreads)
-
-        # send the model to the gpu if available
-        self.model = self.model.to(self.device)
-
-        # instanciate metric tracker
-        self.tracker = MetricTracker(
-            train_metrics=['train_loss', 'train_accu'],
-            valid_metrics=['valid_loss', 'valid_accu'])
-
-        # whether to train using supervised transfer learning or
-        # deep domain adaptation
+        # check whether the test input parameter is correctly specified
+        if self.test not in [None, False, True]:
+            raise TypeError('Expected "test" to be None, True or False, got '
+                            '{}.'.format(self.test))
 
-        # dummy variables for easy model evaluation
-        self.uda = False
-        if self.trg_train_dl.dataset is not None and self.uda_lambda > 0:
+        # check whether the domain is correctly specified
+        if self.domain not in ['src', 'trg']:
+            raise ValueError('Expected "domain" to be "src" or "trg", got {}.'
+                             .format(self.domain))
 
-            # set the device for computing domain adaptation loss
-            self.uda_loss_function.device = self.device
+        # the device to compute on, use gpu if available
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else
+                                   "cpu")
 
-            # adjust metrics and initialize metric tracker
-            self.tracker.train_metrics.extend(['cla_loss', 'uda_loss'])
+        # the output paths for the different graphics
+        self.base_path = pathlib.Path(HERE)
+        self.sample_path = self.base_path.joinpath('_samples')
+        self.scenes_path = self.base_path.joinpath('_scenes')
+        self.perfmc_path = self.base_path.joinpath('_graphics')
+        self.animtn_path = self.base_path.joinpath('_animations')
 
-            # train using deep domain adaptation
-            self.uda = True
+        # input path for model state files
+        self.models_path = self.base_path.joinpath('_models')
+        self.state_file = self.models_path.joinpath(self.state_file)
 
-            # forward function for deep domain adaptation
-            self.uda_forward = self._uda_frwd()
+        # initialize logging
+        log = LogConfig(self.state_file)
+        dictConfig(log_conf(log.log_file))
+        log.init_log('{}: ' + 'Evaluating model: {}.'
+                     .format(self.state_file.name))
 
-        # initialize metric tracker
-        self.tracker.initialize()
+        # plotting keyword arguments
+        self.kwargs = {'bands': self.plot_bands,
+                       'alpha': self.alpha,
+                       'figsize': self.figsize}
 
-        # maximum accuracy on the validation set
-        self.max_accuracy = 0
-        if self.checkpoint_state:
-            self.max_accuracy = self.checkpoint_state['valid_accu'].mean(
-                axis=0).max().item()
+        # base filename for each plot
+        self.basename = self.state_file.stem
 
-        # whether to use early stopping
-        self.es = None
-        if self.early_stop:
-            self.es = EarlyStopping(self.mode, self.max_accuracy, self.delta,
-                                    self.patience)
+        # load the model state
+        self.model, _, self.model_state = Network.load(self.state_file)
 
-        # number of mini-batches in the training and validation sets
-        self.tmbatch = len(self.src_train_dl)
-        self.vmbatch = len(self.src_valid_dl)
+        # load the target dataset: dataset to evaluate the model on
+        self.trg_ds = self.load_dataset()
 
-        # log representation
-        LOGGER.info(repr(self))
+        # load the source dataset: dataset the model was trained on
+        self.src_ds = self.model_state['src_train_dl'].dataset.dataset
 
-        # initialize training log
-        LOGGER.info(35 * '-' + ' Training ' + 35 * '-')
+        # create a figure to use for plotting
+        self.fig, _ = plt.subplots(1, 3, figsize=self.kwargs['figsize'])
 
-        # log the device and number of threads
-        LOGGER.info('Device: {}'.format(self.device))
-        LOGGER.info('Number of cpu threads: {}'.format(self.nthreads))
+        # check if the animate parameter is correctly specified
+        if self.animate:
+            if not self.plot:
+                LOGGER.warning('animate requires plot_scenes=True or '
+                               'plot_samples=True. Can not create animation.')
+            else:
+                # check whether the output path exists, create it otherwise
+                if not self.animtn_path.exists():
+                    self.animtn_path.mkdir(parents=True, exist_ok=True)
+
+                # initialize the animation
+                self.anim = Animate(self.animtn_path)
 
-    def _train_source_domain(self, epoch):
-        """Train a model for an epoch on the source domain.
+    @staticmethod
+    def get_scene_tiles(ds, scene_id):
+        """Return the tiles of the scene with id ``scene_id``.
 
         Parameters
         ----------
-        epoch : `int`
-            The current epoch.
-
-        """
-        # iterate over the dataloader object
-        for batch, (inputs, labels) in enumerate(self.src_train_dl):
-
-            # send the data to the gpu if available
-            inputs = inputs.to(self.device)
-            labels = labels.to(self.device)
-
-            # reset the gradients
-            self.optimizer.zero_grad()
-
-            # perform forward pass
-            outputs = self.model(inputs)
-
-            # compute loss
-            loss = self.cla_loss_function(outputs, labels.long())
-
-            # compute the gradients of the loss function w.r.t.
-            # the network weights
-            loss.backward()
-
-            # update the weights
-            self.optimizer.step()
-
-            # calculate predicted class labels
-            ypred = F.softmax(outputs, dim=1).argmax(dim=1)
-
-            # calculate accuracy on current batch
-            acc = accuracy_function(ypred, labels)
-
-            # print progress
-            LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
-                        'Loss: {:.2f}, Accuracy: {:.2f}'
-                        .format(epoch + 1, self.epochs, batch + 1,
-                                self.tmbatch, loss.item(), acc))
-
-            # update training metrics
-            self.tracker.batch_update(self.tracker.train_metrics,
-                                      [loss.item(), acc])
-
-    def _enc_uda(self, src_input, trg_input):
-
-        # perform forward pass: encoded source domain features
-        src_feature = self.model.encoder(src_input)
-        src_dec_feature = self.model.decoder(src_feature,
-                                             self.model.encoder.cache)
-        # model logits on source domain
-        src_prdctn = self.model.classifier(src_dec_feature)
-        del self.model.encoder.cache  # clear intermediate encoder outputs
-
-        # perform forward pass: encoded target domain features
-        trg_feature = self.model.encoder(trg_input)
-
-        return src_feature, trg_feature, src_prdctn
-
-    def _dec_uda(self, src_input, trg_input):
-
-        # perform forward pass: decoded source domain features
-        src_feature = self.model.encoder(src_input)
-        src_feature = self.model.decoder(src_feature,
-                                         self.model.encoder.cache)
-        # model logits on source domain
-        src_prdctn = self.model.classifier(src_feature)
-        del self.model.encoder.cache  # clear intermediate encoder outputs
-
-        # perform forward pass: decoded target domain features
-        trg_feature = self.model.encoder(trg_input)
-        trg_feature = self.model.decoder(trg_feature,
-                                         self.model.encoder.cache)
-        del self.model.encoder.cache
-
-        return src_feature, trg_feature, src_prdctn
+        ds : :py:class:`pysegcnn.core.split.CustomSubset`
+            An instance of a subclass of
+            :py:class:`pysegcnn.core.split.CustomSubset`.
+        scene_id : `str`
+            A valid scene identifier.
 
-    def _cla_uda(self, src_input, trg_input):
+        Raises
+        ------
+        ValueError
+            Raised if ``scene_id`` is not a valid scene identifier for the
+            dataset ``ds``.
 
-        # perform forward pass: classified source domain features
-        src_feature = self.model(src_input)
+        Returns
+        -------
+        indices : `list` [`int`]
+            List of indices of the tiles of the scene with id ``scene_id`` in
+            ``ds``.
+        date : :py:class:`datetime.datetime`
+            The date of the scene with id ``scene_id``.
 
-        # perform forward pass: target domain features
-        trg_feature = self.model(trg_input)
+        """
+        # check if the scene id is valid
+        scene_meta = ds.dataset.parse_scene_id(scene_id)
+        if scene_meta is None:
+            raise ValueError('{} is not a valid scene identifier'
+                             .format(scene_id))
 
-        return src_feature, trg_feature, src_feature
+        # iterate over the scenes of the dataset
+        indices = []
+        for i, scene in enumerate(ds.scenes):
+            # if the scene id matches a given id, save the index of the scene
+            if scene['id'] == scene_id:
+                indices.append(i)
 
-    def _uda_frwd(self):
-        if self.uda_pos == 'enc':
-            forward = self._enc_uda
+        return indices, scene_meta['date']
 
-        if self.uda_pos == 'dec':
-            forward = self._dec_uda
+    @staticmethod
+    def replace_dataset_path(ds, drive_path):
+        """Replace the path to the datasets.
 
-        if self.uda_pos == 'cla':
-            forward = self._cla_uda
+        Useful to evaluate models on machines that are different from the
+        machine the model was trained on.
 
-        return forward
+        .. important::
 
-    def _train_domain_adaptation(self, epoch):
-        """Train a model for an epoch on the source and target domain.
+            This function assumes that the datasets are stored in a directory
+            named "Datasets" on each machine.
 
-        This function implements deep domain adaptation by extending the
-        standard classification loss by a "domain adaptation loss" calculated
-        from unlabelled target domain samples.
+        See ``DRIVE_PATH`` in :py:mod:`pysegcnn.main.config`.
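+
+        For instance, a stored path such as
+        ``/old/machine/Datasets/scene.tif`` is rewritten to
+        ``<drive_path>/scene.tif`` (paths are purely illustrative).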
 
         Parameters
         ----------
-        epoch : `int`
-            The current epoch.
+        ds : :py:class:`pysegcnn.core.split.CustomSubset`
+            A subset of an instance of
+            :py:class:`pysegcnn.core.dataset.ImageDataset`.
+        drive_path : `str`
+            Base path to the datasets on the current machine. ``drive_path``
+            should end with `'Datasets'`.
+
+        Raises
+        ------
+        TypeError
+            Raised if ``ds`` is not an instance of
+            :py:class:`pysegcnn.core.split.CustomSubset` and if ``ds`` is not
+            a subset of an instance of
+            :py:class:`pysegcnn.core.dataset.ImageDataset`.
 
         """
-        # create target domain iterator
-        target = iter(self.trg_train_dl)
+        # check input type
+        if isinstance(ds, CustomSubset):
+            # check the type of the dataset
+            if not isinstance(ds.dataset, ImageDataset):
+                raise TypeError('ds should be a subset created from a {}.'
+                                .format(repr(ImageDataset)))
+        else:
+            raise TypeError('ds should be an instance of {}.'
+                            .format(repr(CustomSubset)))
 
-        # increase domain adaptation weight with increasing epochs
-        uda_lambda = self.uda_lambda * ((epoch + 1) / self.epochs)
+        # iterate over the scenes of the dataset
+        for scene in ds.dataset.scenes:
+            for k, v in scene.items():
+                # only consider entries that are file paths
+                if isinstance(v, str) and k != 'id':
 
-        # iterate over the number of samples
-        for batch, (src_input, src_label) in enumerate(self.src_train_dl):
+                    # drive path: match path before "Datasets"
+                    # dpath = re.search('^(.*)(?=(/.*Datasets))', v)
 
-            # get the target domain input data
-            try:
-                trg_input, _ = target.next()
-            # in case the iterator is finished, re-instanciate it
-            except StopIteration:
-                target = iter(self.trg_train_dl)
-                trg_input, _ = target.next()
+                    # drive path: match path up to "Datasets"
+                    dpath = re.search('^(.*?Datasets)', v)[0]
 
-            # send the data to the gpu if available
-            src_input, src_label = (src_input.to(self.device),
-                                    src_label.to(self.device))
-            trg_input = trg_input.to(self.device)
+                    # replace drive path
+                    if dpath != drive_path:
+                        scene[k] = v.replace(str(dpath), drive_path)
 
-            # reset the gradients
-            self.optimizer.zero_grad()
+    def load_dataset(self):
+        """Load the defined dataset.
 
-            # forward pass
-            src_feature, trg_feature, src_prdctn = self.uda_forward(src_input,
-                                                                    trg_input)
+        Raises
+        ------
+        ValueError
+            Raised if the requested dataset was not available at training time,
+            if ``implicit=True``.
 
-            # compute classification loss
-            cla_loss = self.cla_loss_function(src_prdctn, src_label.long())
+            Raised if the dataset ``ds`` does not have the same spectral bands
+            as the model to evaluate, if ``implicit=False``.
 
-            # compute domain adaptation loss:
-            # the difference between source and target domain is computed
-            # from the compressed representation of the model encoder
-            uda_loss = self.uda_loss_function(src_feature, trg_feature)
+        Returns
+        -------
+        ds : :py:class:`pysegcnn.core.split.CustomSubset`
+            The dataset to evaluate the model on.
 
-            # total loss
-            tot_loss = cla_loss + uda_lambda * uda_loss
+        """
+        # check whether to evaluate on the datasets defined at training time
+        if self.implicit:
 
-            # compute the gradients of the loss function w.r.t.
-            # the network weights
-            tot_loss.backward()
+            # check whether to evaluate the model on the training, validation
+            # or test set
+            if self.test is None:
+                ds_set = 'train'
+            else:
+                ds_set = 'test' if self.test else 'valid'
 
-            # update the weights
-            self.optimizer.step()
+            # the dataset to evaluate the model on
+            ds = self.model_state[
+                self.domain + '_{}_dl'.format(ds_set)].dataset
+            if ds is None:
+                raise ValueError('Requested dataset "{}" is not available.'
+                                 .format(self.domain + '_{}_dl'.format(ds_set))
+                                 )
 
-            # calculate predicted class labels
-            ypred = F.softmax(src_prdctn, dim=1).argmax(dim=1)
+            # log dataset representation
+            LOGGER.info('Evaluating on {} set of the {} domain defined at '
+                        'training time.'.format(ds_set, self.domain))
 
-            # calculate accuracy on current batch
-            acc = accuracy_function(ypred, src_label)
+        else:
+            # explicitly defined dataset
+            ds = DatasetConfig(**self.ds).init_dataset()
 
-            # print progress
-            LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
-                        'Cla_loss: {:.2f}, Uda_loss: {:.2f}, '
-                        'Tot_loss: {:.2f}, Acc: {:.2f}'
-                        .format(epoch + 1, self.epochs, batch + 1,
-                                self.tmbatch, cla_loss.item(),
-                                uda_loss.item(), tot_loss.item(), acc))
+            # check if the spectral bands match
+            if ds.use_bands != self.model_state['bands']:
+                raise ValueError('The model was trained with bands {}, not '
+                                 'with bands {}.'.format(
+                                     self.model_state['bands'], ds.use_bands))
 
-            # update training metrics
-            self.tracker.batch_update(self.tracker.train_metrics,
-                                      [tot_loss.item(), acc,
-                                       cla_loss.item(), uda_loss.item()])
+            # split configuration
+            sc = SplitConfig(**self.ds_split)
+            train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
 
-    def train_epoch(self, epoch):
-        """Wrap the function to train a model for a single epoch.
+            # check whether to evaluate the model on the training, validation
+            # or test set
+            if self.test is None:
+                ds = train_ds
+            else:
+                ds = test_ds if self.test else valid_ds
 
-        Depends on whether to apply deep domain adaptation.
+            # log dataset representation
+            LOGGER.info('Evaluating on {} set of explicitly defined dataset: '
+                        '\n {}'.format(ds.name, repr(ds.dataset)))
 
-        Parameters
-        ----------
-        epoch : `int`
-            The current epoch.
+        # check the dataset path: replace by path on current machine
+        self.replace_dataset_path(ds, DRIVE_PATH)
+
+        return ds
+
+    @property
+    def source_labels(self):
+        """Class labels of the source domain the model was trained on.
 
         Returns
         -------
-        `function`
-            The function to train a model for a single epoch.
+        source_labels : `dict` [`int`, `dict`]
+            The class labels of the source domain.
 
         """
-        if self.uda:
-            self._train_domain_adaptation(epoch)
-        else:
-            self._train_source_domain(epoch)
+        return self.src_ds.labels
 
-    def train(self):
-        """Train the model.
+    @property
+    def target_labels(self):
+        """Class labels of the dataset to evaluate.
 
         Returns
         -------
-        training_state : `dict` [`str`, :py:class:`numpy.ndarray`]
-            The training state dictionary. The keys describe the type of the
-            training metric. See
-            :py:meth:`~pysegcnn.core.trainer.NetworkTrainer.training_state`.
+        target_labels : `dict` [`int`, `dict`]
+            The class labels of the dataset to evaluate.
 
         """
-        # initialize the training: iterate over the entire training dataset
-        for epoch in range(self.epochs):
-
-            # set the model to training mode
-            LOGGER.info('Setting model to training mode ...')
-            self.model.train()
-
-            # train model for a single epoch
-            self.train_epoch(epoch)
-
-            # update the number of epochs trained
-            self.model.epoch += 1
-
-            # whether to evaluate model performance on the validation set and
-            # early stop the training process
-            if self.early_stop:
-
-                # model predictions on the validation set
-                valid_accu, valid_loss = self.predict(self.src_valid_dl)
-
-                # update validation metrics
-                self.tracker.batch_update(self.tracker.valid_metrics,
-                                          [valid_loss, valid_accu])
+        return self.trg_ds.dataset.labels
 
-                # metric to assess model performance on the validation set
-                epoch_acc = np.mean(valid_accu)
+    # @property
+    # def label_map(self):
+    #     """Label mapping from the source to the target domain.
 
-                # whether the model improved with respect to the previous epoch
-                if self.es.increased(epoch_acc, self.max_accuracy, self.delta):
-                    self.max_accuracy = epoch_acc
+    #     See :py:func:`pysegcnn.core.constants.map_labels`.
 
-                    # save model state if the model improved with
-                    # respect to the previous epoch
-                    self.save_state()
+    #     Returns
+    #     -------
+    #     label_map : `dict` [`int`, `int`]
+    #         Dictionary with source labels as keys and corresponding target
+    #         labels as values.
 
-                # whether the early stopping criterion is met
-                if self.es.stop(epoch_acc):
-                    break
+    #     """
+    #     # check whether the source domain labels are the same as the target
+    #     # domain labels
+    #     return map_labels(self.src_ds.get_labels(),
+    #                       self.trg_ds.dataset.get_labels())
 
-            else:
-                # if no early stopping is required, the model state is
-                # saved after each epoch
-                self.save_state()
+    @property
+    def source_is_target(self):
+        """Whether the source and target domain labels are the same.
 
-        return self.training_state
+        Returns
+        -------
+        source_is_target : `bool`
+            `True` if the source and target domain labels are the same, `False`
+            if not.
 
-    def predict(self, dataloader):
-        """Model inference at training time.
+        """
+        return self.label_map is None
 
-        Parameters
-        ----------
-        dataloader : :py:class:`torch.utils.data.DataLoader`
-            The validation dataloader to evaluate the model predictions.
+    @property
+    def apply_label_map(self):
+        """Whether to map source labels to target labels.
 
         Returns
         -------
-        accuracy : :py:class:`numpy.ndarray`
-            The mean model prediction accuracy on each mini-batch in the
-            validation set.
-        loss : :py:class:`numpy.ndarray`
-            The model loss for each mini-batch in the validation set.
+        apply_label_map : `bool`
+            `True` if source and target labels differ and label mapping is
+            requested, `False` otherwise.
 
         """
-        # set the model to evaluation mode
-        LOGGER.info('Setting model to evaluation mode ...')
-        self.model.eval()
-
-        # create arrays of the observed loss and accuracy
-        accuracy = []
-        loss = []
+        return not self.source_is_target and self.map_labels
 
-        # iterate over the validation/test set
-        LOGGER.info('Calculating accuracy on the validation set ...')
-        for batch, (inputs, labels) in enumerate(dataloader):
+    @property
+    def use_labels(self):
+        """Labels to be predicted.
 
-            # send the data to the gpu if available
-            inputs = inputs.to(self.device)
-            labels = labels.to(self.device)
+        Returns
+        -------
+        use_labels : `dict` [`int`, `dict`]
+            The labels of the classes to be predicted.
 
-            # calculate network outputs
-            with torch.no_grad():
-                outputs = self.model(inputs)
+        """
+        return (self.target_labels if self.apply_label_map else
+                self.source_labels)
 
-            # compute loss
-            cla_loss = self.cla_loss_function(outputs, labels.long())
-            loss.append(cla_loss.item())
+    @property
+    def bands(self):
+        """Spectral bands the model was trained with.
 
-            # calculate predicted class labels
-            pred = F.softmax(outputs, dim=1).argmax(dim=1)
+        Returns
+        -------
+        bands : `list` [`str`]
+            A list of the named spectral bands used to train the model.
 
-            # calculate accuracy on current batch
-            acc = accuracy_function(pred, labels)
-            accuracy.append(acc)
+        """
+        return self.src_ds.use_bands
 
-            # print progress
-            LOGGER.info('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}'
-                        .format(batch + 1, len(dataloader), acc))
+    @property
+    def compute_cm(self):
+        """Whether to compute the confusion matrix.
 
-        # calculate overall accuracy on the validation/test set
-        LOGGER.info('Epoch: {:d}, Mean accuracy: {:.2f}%.'
-                    .format(self.model.epoch, np.mean(accuracy) * 100))
+        Returns
+        -------
+        compute_cm : `bool`
+            Whether the confusion matrix can be computed. The confusion
+            matrix cannot be computed for datasets whose labels differ from
+            the source domain labels, unless a label mapping is applied.
 
-        return accuracy, loss
+        """
+        return (False if not self.source_is_target and not self.map_labels
+                else self.cm)
 
     @property
-    def training_state(self):
-        """Model training metrics.
+    def plot(self):
+        """Whether to save plots of (input, ground truth, prediction).
 
         Returns
         -------
-        state : `dict` [`str`, :py:class:`numpy.ndarray`]
-            The training state dictionary. The keys describe the type of the
-            training metric and the values are :py:class:`numpy.ndarray`'s of
-            the corresponding metric observed during training with
-            shape=(mini_batch, epoch).
+        plot : `bool`
+            Whether to save plots for each sample or for each scene of the
+            target dataset, depending on ``self.predict_scene``.
 
         """
-        # current training state
-        state = self.tracker.np_state(self.tmbatch, self.vmbatch)
+        return self.plot_scenes if self.predict_scene else self.plot_samples
 
-        # optional: training state of the model checkpoint
-        if self.checkpoint_state:
-            # prepend values from checkpoint to current training state
-            state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
-                     zip(self.checkpoint_state.items(), state.items())
-                     if k1 == k2}
+    @property
+    def is_scene_subset(self):
+        """Check the type of the target dataset.
 
-        return state
+        Whether ``self.trg_ds`` is an instance of
+        :py:class:`pysegcnn.core.split.SceneSubset`, as required when
+        ``self.predict_scene=True``.
 
-    def save_state(self):
-        """Save the model state."""
-        if self.save:
-            _ = self.model.save(self.state_file,
-                                self.optimizer,
-                                bands=self.bands,
-                                nclasses=self.model.nclasses,
-                                src_train_dl=self.src_train_dl,
-                                src_valid_dl=self.src_valid_dl,
-                                src_test_dl=self.src_test_dl,
-                                trg_train_dl=self.trg_train_dl,
-                                trg_valid_dl=self.trg_valid_dl,
-                                trg_test_dl=self.trg_test_dl,
-                                state=self.training_state,
-                                uda_lambda=self.uda_lambda
-                                )
+        Returns
+        -------
+        is_scene_subset : `bool`
+            Whether ``self.trg_ds`` is an instance of
+            :py:class:`pysegcnn.core.split.SceneSubset`.
 
-    @staticmethod
-    def init_network_trainer(src_ds_config, src_split_config, trg_ds_config,
-                             trg_split_config, model_config):
-        """Prepare network training.
+        """
+        return isinstance(self.trg_ds, SceneSubset)
 
-        Parameters
-        ----------
-        src_ds_config : :py:class:`pysegcnn.core.trainer.DatasetConfig`
-            The source domain dataset configuration.
-        src_split_config : :py:class:`pysegcnn.core.trainer.SplitConfig`
-            The source domain dataset split configuration.
-        trg_ds_config : :py:class:`pysegcnn.core.trainer.DatasetConfig`
-            The target domain dataset configuration..
-        trg_split_config : :py:class:`pysegcnn.core.trainer.SplitConfig`
-            The target domain dataset split configuration.
-        model_config : :py:class:`pysegcnn.core.trainer.ModelConfig`
-            The model configuration.
+    @property
+    def dataloader(self):
+        """Dataloader instance for model inference.
 
         Returns
         -------
-        trainer : :py:class:`pysegcnn.core.trainer.NetworkTrainer`
-            A network trainer instance.
-
-        See :py:mod:`pysegcnn.main.train.py` for an example on how to
-        instanciate a :py:class:`pysegcnn.core.trainer.NetworkTrainer`
-        instance.
-
-        """
-        # (i) instanciate the source domain configurations
-        src_dc = DatasetConfig(**src_ds_config)   # source domain dataset
-        src_sc = SplitConfig(**src_split_config)  # source domain dataset split
-
-        # (ii) instanciate the target domain configuration
-        trg_dc = DatasetConfig(**trg_ds_config)   # target domain dataset
-        trg_sc = SplitConfig(**trg_split_config)  # target domain dataset split
-
-        # (iii) instanciate the model configuration
-        mdlcfg = ModelConfig(**model_config)
-
-        # (iv) instanciate the model state file
-        sttcfg = StateConfig(src_dc, src_sc, trg_dc, trg_sc, mdlcfg)
-        state_file = sttcfg.init_state()
-
-        # (v) instanciate logging configuration
-        logcfg = LogConfig(state_file)
-        dictConfig(log_conf(logcfg.log_file))
-
-        # (vi) instanciate the source domain dataset
-        src_ds = src_dc.init_dataset()
-
-        # the spectral bands used to train the model
-        bands = src_ds.use_bands
-
-        # (vii) instanciate the training, validation and test datasets and
-        # dataloaders for the source domain
-        (src_train_ds,
-         src_valid_ds,
-         src_test_ds) = src_sc.train_val_test_split(src_ds)
-        (src_train_dl,
-         src_valid_dl,
-         src_test_dl) = src_sc.dataloaders(src_train_ds,
-                                           src_valid_ds,
-                                           src_test_ds,
-                                           batch_size=mdlcfg.batch_size,
-                                           shuffle=True, drop_last=False)
-
-        # (viii) instanciate the loss function
-        cla_loss_function = mdlcfg.init_cla_loss_function()
-
-        # (ix) check whether to apply transfer learning
-        if mdlcfg.transfer:
-
-            # (a) instanciate the target domain dataset
-            trg_ds = trg_dc.init_dataset()
-
-            # (b) instanciate the training, validation and test datasets and
-            # dataloaders for the target domain
-            (trg_train_ds,
-             trg_valid_ds,
-             trg_test_ds) = trg_sc.train_val_test_split(trg_ds)
-            (trg_train_dl,
-             trg_valid_dl,
-             trg_test_dl) = trg_sc.dataloaders(trg_train_ds,
-                                               trg_valid_ds,
-                                               trg_test_ds,
-                                               batch_size=mdlcfg.batch_size,
-                                               shuffle=True, drop_last=False)
-
-            # (c) instanciate the model: supervised transfer learning
-            if mdlcfg.supervised:
-                model, optimizer, checkpoint_state = mdlcfg.init_model(
-                    trg_ds, state_file)
-
-                # (x) instanciate the network trainer
-                trainer = NetworkTrainer(model,
-                                         optimizer,
-                                         state_file,
-                                         bands,
-                                         trg_train_dl,
-                                         trg_valid_dl,
-                                         trg_test_dl,
-                                         cla_loss_function,
-                                         epochs=mdlcfg.epochs,
-                                         nthreads=mdlcfg.nthreads,
-                                         early_stop=mdlcfg.early_stop,
-                                         mode=mdlcfg.mode,
-                                         delta=mdlcfg.delta,
-                                         patience=mdlcfg.patience,
-                                         checkpoint_state=checkpoint_state,
-                                         save=mdlcfg.save)
-
-            # (c) instanciate the model: unsupervised transfer learning
-            else:
-                model, optimizer, checkpoint_state = mdlcfg.init_model(
-                    src_ds, state_file)
-
-                # (x) instanciate the domain adaptation loss
-                uda_loss_function = mdlcfg.init_uda_loss_function(
-                    mdlcfg.uda_lambda)
-
-                # (xi) instanciate the network trainer
-                trainer = NetworkTrainer(model,
-                                         optimizer,
-                                         state_file,
-                                         bands,
-                                         src_train_dl,
-                                         src_valid_dl,
-                                         src_test_dl,
-                                         cla_loss_function,
-                                         trg_train_dl,
-                                         trg_valid_dl,
-                                         trg_test_dl,
-                                         uda_loss_function,
-                                         mdlcfg.uda_lambda,
-                                         mdlcfg.uda_pos,
-                                         mdlcfg.epochs,
-                                         mdlcfg.nthreads,
-                                         mdlcfg.early_stop,
-                                         mdlcfg.mode,
-                                         mdlcfg.delta,
-                                         mdlcfg.patience,
-                                         checkpoint_state,
-                                         mdlcfg.save)
+        dataloader : :py:class:`torch.utils.data.DataLoader`
+            The dataset for model inference.
 
-        else:
-            # (x) instanciate the model
-            model, optimizer, checkpoint_state = mdlcfg.init_model(
-                src_ds, state_file)
-
-            # (xi) instanciate the network trainer
-            trainer = NetworkTrainer(model,
-                                     optimizer,
-                                     state_file,
-                                     bands,
-                                     src_train_dl,
-                                     src_valid_dl,
-                                     src_test_dl,
-                                     cla_loss_function,
-                                     epochs=mdlcfg.epochs,
-                                     nthreads=mdlcfg.nthreads,
-                                     early_stop=mdlcfg.early_stop,
-                                     mode=mdlcfg.mode,
-                                     delta=mdlcfg.delta,
-                                     patience=mdlcfg.patience,
-                                     checkpoint_state=checkpoint_state,
-                                     save=mdlcfg.save)
-
-        return trainer
-
-    # def _build_ds_repr(self, train_dl, valid_dl, test_dl):
-    #     """Build the dataset representation.
+        """
+        # build the dataloader for model inference
+        return DataLoader(self.trg_ds, batch_size=self._batch_size,
+                          shuffle=False, drop_last=False)
 
-    #     Returns
-    #     -------
-    #     fs : `str`
-    #         Representation string.
+    @property
+    def _original_source_labels(self):
+        """Original source domain labels.
 
-    #     """
-    #     # dataset configuration
-    #     fs = '    (dataset):\n        '
-    #     fs += ''.join(repr(train_dl.dataset.dataset)).replace('\n',
-    #                                                           '\n' + 8 * ' ')
-    #     fs += '\n    (batch):\n        '
-    #     fs += '- batch size: {}\n        '.format(train_dl.batch_size)
-    #     fs += '- mini-batch shape (b, c, h, w): {}'.format(
-    #         ((train_dl.batch_size, len(train_dl.dataset.dataset.use_bands),) +
-    #           2 * (train_dl.dataset.dataset.tile_size,)))
-
-    #     # dataset split
-    #     fs += '\n    (split):'
-    #     for dl in [train_dl, valid_dl, test_dl]:
-    #         if dl.dataset is not None:
-    #             fs += '\n' + 8 * ' ' + repr(dl.dataset)
-
-    #     return fs
-
-    # def _build_model_repr_(self):
-    #     """Build the model representation.
+        Since PyTorch requires class labels to be an ascending sequence
+        starting from 0, the actual class labels in the ground truth may differ
+        from the class labels fed to the model.
 
-    #     Returns
-    #     -------
-    #     fs : `str`
-    #         Representation string.
+        Returns
+        -------
+        original_source_labels : `dict` [`int`, `dict`]
+            The original class labels of the source domain.
 
-    #     """
-    #     # model
-    #     fs = '\n    (model):' + '\n' + 8 * ' '
-    #     fs += ''.join(repr(self.model)).replace('\n', '\n' + 8 * ' ')
+        """
+        return self.src_ds._labels
+
+    @property
+    def _original_target_labels(self):
+        """Original target domain labels.
+
+        Returns
+        -------
+        original_target_labels : `dict` [`int`, `dict`]
+            The original class labels of the target domain.
 
-    #     # optimizer
-    #     fs += '\n    (optimizer):' + '\n' + 8 * ' '
-    #     fs += ''.join(repr(self.optimizer)).replace('\n', '\n' + 8 * ' ')
+        """
+        return self.trg_ds.dataset._labels
 
-    #     # early stopping
-    #     fs += '\n    (early stop):' + '\n' + 8 * ' '
-    #     fs += ''.join(repr(self.es)).replace('\n', '\n' + 8 * ' ')
+    @property
+    def _label_log(self):
+        """Log if a label mapping is applied.
 
-    #     # domain adaptation
-    #     if self.uda:
-    #         fs += '\n    (adaptation)' + '\n' + 8 * ' '
-    #         fs += repr(self.uda_loss_function).replace('\n', '\n' + 8 * ' ')
+        Returns
+        -------
+        log : `str`
+            Representation of the label mapping.
 
-    #     return fs
+        """
+        log = 'Retaining model labels ({})'.format(
+            ', '.join([v['label'] for _, v in self.source_labels.items()]))
+        if self.apply_label_map:
+            log = (log.replace('Retaining', 'Mapping') + ' to target labels '
+                   '({}).'.format(', '.join([v['label'] for _, v in
+                                  self.target_labels.items()])))
+        return log
 
-    # def __repr__(self):
-    #     """Representation.
+    @property
+    def _batch_size(self):
+        """Batch size of the inference dataloader.
 
-    #     Returns
-    #     -------
-    #     fs : `str`
-    #         Representation string.
+        Returns
+        -------
+        batch_size : `int`
+            The batch size of the dataloader used for model inference. Depends
+            on whether to predict each sample of the target dataset
+            individually or whether to reconstruct each scene in the target
+            dataset.
 
-    #     """
-    #     # representation string to print
-    #     fs = self.__class__.__name__ + '(\n'
+        """
+        return (self.trg_ds.dataset.tiles if self.predict_scene and
+                self.is_scene_subset else 1)
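+        # i.e. when reconstructing scenes, all tiles of a scene presumably
+        # form a single batch, so one batch corresponds to one full scene;
+        # otherwise each sample is predicted individually with batch size 1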
 
-    #     # source domain
-    #     fs += '    (source domain)\n    '
-    #     fs += self._build_ds_repr(
-    #         self.src_train_dl, self.src_valid_dl, self.src_test_dl).replace(
-    #             '\n', '\n' + 4 * ' ')
+    def _check_long_filename(self, filename):
+        """Modify filenames that exceed Windows' maximum filename length.
 
-    #     # target domain
-    #     if self.uda:
-    #         fs += '\n    (target domain)\n    '
-    #         fs += self._build_ds_repr(
-    #             self.trg_train_dl,
-    #             self.trg_valid_dl,
-    #             self.trg_test_dl).replace('\n', '\n' + 4 * ' ')
+        Parameters
+        ----------
+        filename : `str`
+            The filename to check.
 
-    #     # model configuration
-    #     fs += self._build_model_repr_()
+        Returns
+        -------
+        filename : `str`
+            The modified filename, in case ``filename`` exceeds 255 characters.
 
-    #     fs += '\n)'
-    #     return fs
+        """
+        # check for maximum path component length: Windows allows a maximum
+        # of 255 characters for each component in a path
+        if len(filename) >= 255:
+            # try to parse the scene identifier
+            scene_metadata = self.trg_ds.dataset.parse_scene_id(filename)
+            if scene_metadata is not None:
+                # new, shorter name for the current batch
+                batch_name = '_'.join([scene_metadata['satellite'],
+                                       scene_metadata['tile'],
+                                       datetime.datetime.strftime(
+                                           scene_metadata['date'], '%Y%m%d')])
+                filename = filename.replace(scene_metadata['id'], batch_name)
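+                # illustrative example (hypothetical identifiers): a verbose
+                # scene id is replaced by a compact
+                # '<satellite>_<tile>_<YYYYMMDD>' name, e.g.
+                # 'Sentinel2_T32TPS_20200101'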
 
+        return filename
 
-class EarlyStopping(object):
-    """`Early Stopping`_ algorithm.
+    def map_to_target(self, prd):
+        """Map source domain labels to target domain labels.
 
-    This implementation of the early stopping algorithm advances a counter each
-    time a metric did not improve over a training epoch. If the metric does not
-    improve over more than ``patience`` epochs, the early stopping criterion is
-    met.
+        Parameters
+        ----------
+        prd : :py:class:`torch.Tensor`
+            The source domain class labels as predicted by ``self.model``.
 
-    See the :py:meth:`pysegcnn.core.trainer.NetworkTrainer.train` method for an
-    example implementation.
+        Returns
+        -------
+        prd : :py:class:`torch.Tensor`
+            The predicted target domain labels.
 
-    .. _Early Stopping:
-        https://en.wikipedia.org/wiki/Early_stopping
+        """
+        # map actual source labels to original source labels
+        for aid, oid in zip(self.source_labels.keys(),
+                            self._original_source_labels.keys()):
+            prd[torch.where(prd == aid)] = oid
 
-    Attributes
-    ----------
-    mode : `str`
-        The early stopping mode.
-    best : `float`
-        Best metric score.
-    min_delta : `float`
-        Minimum change in early stopping metric to be considered as an
-        improvement.
-    patience : `int`
-        The number of epochs to wait for an improvement.
-    is_better : `function`
-        Function indicating whether the metric improved.
-    early_stop : `bool`
-        Whether the early stopping criterion is met.
-    counter : `int`
-        The counter advancing each time a metric does not improve.
+        # apply the label mapping
+        for src_label, trg_label in self.label_map.items():
+            prd[torch.where(prd == src_label)] = trg_label
 
-    """
+        # map original target labels to actual target labels
+        for oid, aid in zip(self._original_target_labels.keys(),
+                            self.target_labels.keys()):
+            prd[torch.where(prd == oid)] = aid
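+        # note: the three relabelling passes above modify ``prd`` in place
+        # and run sequentially. Illustrative chain (hypothetical values): a
+        # model index 1 is first mapped to its original source label (e.g.
+        # 20), then via ``label_map`` to the original target label (e.g. 30),
+        # and finally to the corresponding target model index. This assumes
+        # the intermediate label values do not collide with one another.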
 
-    def __init__(self, mode='max', best=0, min_delta=0, patience=10):
-        """Initialize.
+        return prd
 
-        Parameters
-        ----------
-        mode : `str`, optional
-            The early stopping mode. Depends on the metric measuring
-            performance. When using model loss as metric, use ``mode='min'``,
-            however, when using accuracy as metric, use ``mode='max'``. For
-            now, only ``mode='max'`` is supported. Only used if
-            ``early_stop=True``. The default is `'max'`.
-        best : `float`, optional
-            Threshold indicating the best metric score. At instanciation, set
-            ``best`` to the worst possible score of the metric. ``best`` will
-            be overwritten during training. The default is `0`.
-        min_delta : `float`, optional
-            Minimum change in early stopping metric to be considered as an
-            improvement. Only used if ``early_stop=True``. The default is `0`.
-        patience : `int`, optional
-            The number of epochs to wait for an improvement in the early
-            stopping metric. If the model does not improve over more than
-            ``patience`` epochs, quit training. Only used if
-            ``early_stop=True``. The default is `10`.
+    def predict(self):
+        """Classify the samples of the target dataset.
 
-        Raises
-        ------
-        ValueError
-            Raised if ``mode`` is not either 'min' or 'max'.
+        Returns
+        -------
+        output : `dict` [`str`, `dict`]
+            The inference output dictionary. The keys are either the number of
+            the samples (``self.predict_scene=False``) or the name of the
+            scenes of the target dataset (``self.predict_scene=True``). The
+            values are dictionaries with keys:
+                ``'input'``
+                    Model input data of the sample (:py:class:`numpy.ndarray`).
+                ``'labels'``
+                    Ground truth class labels (:py:class:`numpy.ndarray`).
+                ``'prediction'``
+                    Model prediction class labels (:py:class:`numpy.ndarray`).
 
         """
-        # check if mode is correctly specified
-        if mode not in ['min', 'max']:
-            raise ValueError('Mode "{}" not supported. '
-                             'Mode is either "min" (check whether the metric '
-                             'decreased, e.g. loss) or "max" (check whether '
-                             'the metric increased, e.g. accuracy).'
-                             .format(mode))
+        # iterate over the samples of the target dataset
+        output = {}
+        for batch, (inputs, labels) in enumerate(self.dataloader):
 
-        # mode to determine if metric improved
-        self.mode = mode
+            # send inputs and labels to device
+            inputs = inputs.to(self.device)
+            labels = labels.to(self.device)
 
-        # whether to check for an increase or a decrease in a given metric
-        self.is_better = self.decreased if mode == 'min' else self.increased
+            # compute model predictions
+            with torch.no_grad():
+                prdctn = F.softmax(self.model(inputs),
+                                   dim=1).argmax(dim=1).squeeze()
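+            # prdctn holds one class index per pixel: shape (h, w) for a
+            # single sample, or (tiles, h, w) when a full scene is predicted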
 
-        # minimum change in metric to be considered as an improvement
-        self.min_delta = min_delta
+            # map source labels to target dataset labels
+            if self.apply_label_map:
+                prdctn = self.map_to_target(prdctn)
 
-        # number of epochs to wait for improvement
-        self.patience = patience
+            # update confusion matrix
+            if self.compute_cm:
+                for ytrue, ypred in zip(labels.view(-1), prdctn.view(-1)):
+                    self.conf_mat[ytrue.long(), ypred.long()] += 1
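+                # rows of the confusion matrix index the ground truth labels,
+                # columns index the model predictions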
 
-        # initialize best metric
-        self.best = best
+            # move the tensors to the cpu and convert them to numpy arrays
+            inputs = inputs.cpu().numpy()
+            labels = labels.cpu().numpy()
+            prdctn = prdctn.cpu().numpy()
 
-        # initialize early stopping flag
-        self.early_stop = False
+            # progress string to log
+            progress = 'Sample: {:d}/{:d}'.format(batch + 1,
+                                                  len(self.dataloader))
 
-        # initialize the early stop counter
-        self.counter = 0
+            # check whether to reconstruct the scene
+            date = None
+            if self.dataloader.batch_size > 1:
 
-    def stop(self, metric):
-        """Advance early stopping counter.
+                # id and date of the current scene
+                batch = self.trg_ds.ids[batch]
+                if self.trg_ds.split_mode == 'date':
+                    date = self.trg_ds.dataset.parse_scene_id(batch)['date']
 
-        Parameters
-        ----------
-        metric : `float`
-            The current metric score.
+                # modify the progress string
+                progress = progress.replace('Sample', 'Scene')
+                progress += ' Id: {}'.format(batch)
 
-        Returns
-        -------
-        early_stop : `bool`
-            Whether the early stopping criterion is met.
+                # reconstruct the entire scene
+                inputs = reconstruct_scene(inputs)
+                labels = reconstruct_scene(labels)
+                prdctn = reconstruct_scene(prdctn)
 
-        """
-        # if the metric improved, reset the epochs counter, else, advance
-        if self.is_better(metric, self.best, self.min_delta):
-            self.counter = 0
-            self.best = metric
-        else:
-            self.counter += 1
-            LOGGER.info('Early stopping counter: {}/{}'.format(
-                self.counter, self.patience))
+            # save current batch to output dictionary
+            output[batch] = {'input': inputs, 'labels': labels,
+                             'prediction': prdctn}
 
-        # if the metric did not improve over the last patience epochs,
-        # the early stopping criterion is met
-        if self.counter >= self.patience:
-            LOGGER.info('Early stopping criterion met, stopping training.')
-            self.early_stop = True
+            # filename for the plot of the current batch
+            batch_name = self.basename + '_{}_{}.pt'.format(self.trg_ds.name,
+                                                            batch)
 
-        return self.early_stop
+            # check if the current batch name exceeds the Windows limit of
+            # 255 characters
+            batch_name = self._check_long_filename(batch_name)
 
-    def decreased(self, metric, best, min_delta):
-        """Whether a metric decreased with respect to a best score.
+            # in case the source and target class labels are the same or a
+            # label mapping is applied, the accuracy of the prediction can be
+            # calculated
+            if self.source_is_target or self.apply_label_map:
+                progress += ', Accuracy: {:.2f}'.format(
+                    accuracy_function(prdctn, labels))
+            LOGGER.info(progress)
 
-        Measure improvement for metrics that are considered as 'better' when
-        they decrease, e.g. model loss, mean squared error, etc.
+            # plot current scene
+            if self.plot:
 
-        Parameters
-        ----------
-        metric : `float`
-            The current score.
-        best : `float`
-            The current best score.
-        min_delta : `float`
-            Minimum change to be considered as an improvement.
+                # in case the source and target class labels are the same or a
+                # label mapping is applied, plot the ground truth, otherwise
+                # just plot the model prediction
+                gt = None
+                if self.source_is_target or self.apply_label_map:
+                    gt = labels
 
-        Returns
-        -------
-        `bool`
-            Whether the metric improved.
+                # plot inputs, ground truth and model predictions
+                plot_sample(inputs.clip(0, 1),
+                            self.bands,
+                            self.use_labels,
+                            y=gt,
+                            y_pred={self.model.__class__.__name__: prdctn},
+                            accuracy=True,
+                            state=batch_name,
+                            date=date,
+                            fig=self.fig,
+                            plot_path=self.scenes_path,
+                            **self.kwargs)
 
-        """
-        return metric < best - min_delta
+                # save current figure state as frame for animation
+                if self.animate:
+                    self.anim.frame(self.fig.axes)
 
-    def increased(self, metric, best, min_delta):
-        """Whether a metric increased with respect to a best score.
+        # save animation
+        if self.animate:
+            self.anim.animate(self.fig, interval=1000, repeat=True, blit=True)
+            self.anim.save(self.basename + '_{}.gif'.format(self.trg_ds.name),
+                           dpi=200)
 
-        Measure improvement for metrics that are considered as 'better' when
-        they increase, e.g. accuracy, precision, recall, etc.
+        return output
 
-        Parameters
-        ----------
-        metric : `float`
-            The current score.
-        best : `float`
-            The current best score.
-        min_delta : `float`
-            Minimum change to be considered as an improvement.
+    def evaluate(self):
+        """Evaluate a pretrained model on a defined dataset.
 
         Returns
         -------
-        `bool`
-            Whether the metric improved.
+        output : `dict` [`str`, `dict`]
+            The inference output dictionary. The keys are either the number of
+            the samples (``self.predict_scene=False``) or the name of the
+            scenes of the target dataset (``self.predict_scene=True``). The
+            values are dictionaries with keys:
+                ``'input'``
+                    Model input data of the sample (:py:class:`numpy.ndarray`).
+                ``'labels'``
+                    Ground truth class labels (:py:class:`numpy.ndarray`).
+                ``'prediction'``
+                    Model prediction class labels (:py:class:`numpy.ndarray`).
 
         """
-        return metric > best + min_delta
+        # plot loss and accuracy
+        plot_loss(check_filename_length(self.state_file),
+                  outpath=self.perfmc_path)
 
-    def __repr__(self):
-        """Representation.
+        # set the model to evaluation mode
+        LOGGER.info('Setting model to evaluation mode ...')
+        self.model.eval()
+        self.model.to(self.device)
 
-        Returns
-        -------
-        fs : `str`
-            Representation string.
+        # initialize confusion matrix
+        self.conf_mat = np.zeros(shape=2 * (len(self.use_labels), ))
 
-        """
-        fs = self.__class__.__name__
-        fs += '(mode={}, best={:.2f}, delta={}, patience={})'.format(
-            self.mode, self.best, self.min_delta, self.patience)
+        # log which labels the model predicts
+        LOGGER.info(self._label_log)
 
-        return fs
+        # evaluate the model on the target dataset
+        output = self.predict()
+
+        # whether to plot the confusion matrix
+        if self.compute_cm:
+            plot_confusion_matrix(self.conf_mat, self.use_labels,
+                                  state_file=self.state_file,
+                                  subset=self.domain + '_' + self.trg_ds.name,
+                                  outpath=self.perfmc_path)
+
+        return output
+
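+# Illustrative usage sketch (hypothetical instance name ``inference``; the
+# evaluation class above is configured and instantiated elsewhere):
+#
+#     output = inference.evaluate()
+#     for name, res in output.items():
+#         res['input'], res['labels'], res['prediction']
+#     cm = inference.conf_mat  # only populated if compute_cm is True
+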
+# def init_network_trainer(src_ds_config, src_split_config, trg_ds_config,
+#                          trg_split_config, model_config):
+#     """Prepare network training.
+
+#     Parameters
+#     ----------
+#     src_ds_config : :py:class:`pysegcnn.core.trainer.DatasetConfig`
+#         The source domain dataset configuration.
+#     src_split_config : :py:class:`pysegcnn.core.trainer.SplitConfig`
+#         The source domain dataset split configuration.
+#     trg_ds_config : :py:class:`pysegcnn.core.trainer.DatasetConfig`
+#         The target domain dataset configuration.
+#     trg_split_config : :py:class:`pysegcnn.core.trainer.SplitConfig`
+#         The target domain dataset split configuration.
+#     model_config : :py:class:`pysegcnn.core.trainer.ModelConfig`
+#         The model configuration.
+
+#     Returns
+#     -------
+#     trainer : :py:class:`pysegcnn.core.trainer.NetworkTrainer`
+#         A network trainer instance.
+
+#     See :py:mod:`pysegcnn.main.train.py` for an example on how to
+#     instantiate a :py:class:`pysegcnn.core.trainer.NetworkTrainer`
+#     instance.
+
+#     """
+#     # (i) instantiate the source domain configurations
+#     src_dc = DatasetConfig(**src_ds_config)   # source domain dataset
+#     src_sc = SplitConfig(**src_split_config)  # source domain dataset split
+
+#     # (ii) instantiate the target domain configuration
+#     trg_dc = DatasetConfig(**trg_ds_config)   # target domain dataset
+#     trg_sc = SplitConfig(**trg_split_config)  # target domain dataset split
+
+#     # (iii) instantiate the model configuration
+#     mdlcfg = ModelConfig(**model_config)
+
+#     # (iv) instantiate the model state file
+#     sttcfg = StateConfig(src_dc, src_sc, trg_dc, trg_sc, mdlcfg)
+#     state_file = sttcfg.init_state()
+
+#     # (v) instantiate logging configuration
+#     logcfg = LogConfig(state_file)
+#     dictConfig(log_conf(logcfg.log_file))
+
+#     # (vi) instantiate the source domain dataset
+#     src_ds = src_dc.init_dataset()
+
+#     # the spectral bands used to train the model
+#     bands = src_ds.use_bands
+
+#     # (vii) instantiate the training, validation and test datasets and
+#     # dataloaders for the source domain
+#     (src_train_ds,
+#      src_valid_ds,
+#      src_test_ds) = src_sc.train_val_test_split(src_ds)
+#     (src_train_dl,
+#      src_valid_dl,
+#      src_test_dl) = src_sc.dataloaders(src_train_ds,
+#                                        src_valid_ds,
+#                                        src_test_ds,
+#                                        batch_size=mdlcfg.batch_size,
+#                                        shuffle=True, drop_last=False)
+
+#     # (viii) instantiate the loss function
+#     cla_loss_function = mdlcfg.init_cla_loss_function()
+
+#     # (ix) check whether to apply transfer learning
+#     if mdlcfg.transfer:
+
+#         # (a) instantiate the target domain dataset
+#         trg_ds = trg_dc.init_dataset()
+
+#         # (b) instantiate the training, validation and test datasets and
+#         # dataloaders for the target domain
+#         (trg_train_ds,
+#          trg_valid_ds,
+#          trg_test_ds) = trg_sc.train_val_test_split(trg_ds)
+#         (trg_train_dl,
+#          trg_valid_dl,
+#          trg_test_dl) = trg_sc.dataloaders(trg_train_ds,
+#                                            trg_valid_ds,
+#                                            trg_test_ds,
+#                                            batch_size=mdlcfg.batch_size,
+#                                            shuffle=True, drop_last=False)
+
+#         # (c) instantiate the model: supervised transfer learning
+#         if mdlcfg.supervised:
+#             model, optimizer, checkpoint_state = mdlcfg.init_model(
+#                 trg_ds, state_file)
+
+#             # (x) instantiate the network trainer
+#             trainer = NetworkTrainer(model,
+#                                      optimizer,
+#                                      state_file,
+#                                      bands,
+#                                      trg_train_dl,
+#                                      trg_valid_dl,
+#                                      trg_test_dl,
+#                                      cla_loss_function,
+#                                      epochs=mdlcfg.epochs,
+#                                      nthreads=mdlcfg.nthreads,
+#                                      early_stop=mdlcfg.early_stop,
+#                                      mode=mdlcfg.mode,
+#                                      delta=mdlcfg.delta,
+#                                      patience=mdlcfg.patience,
+#                                      checkpoint_state=checkpoint_state,
+#                                      save=mdlcfg.save)
+
+#         # (c) instantiate the model: unsupervised transfer learning
+#         else:
+#             model, optimizer, checkpoint_state = mdlcfg.init_model(
+#                 src_ds, state_file)
+
+#             # (x) instantiate the domain adaptation loss
+#             uda_loss_function = mdlcfg.init_uda_loss_function(
+#                 mdlcfg.uda_lambda)
+
+#             # (xi) instantiate the network trainer
+#             trainer = NetworkTrainer(model,
+#                                      optimizer,
+#                                      state_file,
+#                                      bands,
+#                                      src_train_dl,
+#                                      src_valid_dl,
+#                                      src_test_dl,
+#                                      cla_loss_function,
+#                                      trg_train_dl,
+#                                      trg_valid_dl,
+#                                      trg_test_dl,
+#                                      uda_loss_function,
+#                                      mdlcfg.uda_lambda,
+#                                      mdlcfg.uda_pos,
+#                                      mdlcfg.epochs,
+#                                      mdlcfg.nthreads,
+#                                      mdlcfg.early_stop,
+#                                      mdlcfg.mode,
+#                                      mdlcfg.delta,
+#                                      mdlcfg.patience,
+#                                      checkpoint_state,
+#                                      mdlcfg.save)
+
+#     else:
+#         # (x) instantiate the model
+#         model, optimizer, checkpoint_state = mdlcfg.init_model(
+#             src_ds, state_file)
+
+#         # (xi) instantiate the network trainer
+#         trainer = NetworkTrainer(model,
+#                                  optimizer,
+#                                  state_file,
+#                                  bands,
+#                                  src_train_dl,
+#                                  src_valid_dl,
+#                                  src_test_dl,
+#                                  cla_loss_function,
+#                                  epochs=mdlcfg.epochs,
+#                                  nthreads=mdlcfg.nthreads,
+#                                  early_stop=mdlcfg.early_stop,
+#                                  mode=mdlcfg.mode,
+#                                  delta=mdlcfg.delta,
+#                                  patience=mdlcfg.patience,
+#                                  checkpoint_state=checkpoint_state,
+#                                  save=mdlcfg.save)
+
+#     return trainer