diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index fa8715ed8882478792d8095bc07ad55ba0ce1546..73d3a4399f3bfa6ca3ce4c011e48cf71a3ad143c 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -1,53 +1,338 @@
-# !/usr/bin/env python
 # -*- coding: utf-8 -*-
 """
-Created on Fri Jun 26 16:31:36 2020
+Created on Wed Aug 12 10:24:34 2020
 
 @author: Daniel
 """
 # builtins
-import os
+import dataclasses
+import pathlib
 
 # externals
 import numpy as np
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
+from torch.optim import Optimizer
+
 
 # locals
-from pysegcnn.core.dataset import SupportedDatasets
+from pysegcnn.core.dataset import SupportedDatasets, ImageDataset
+from pysegcnn.core.transforms import Augment
+from pysegcnn.core.utils import img2np, item_in_enum, accuracy_function
+from pysegcnn.core.split import SupportedSplits
+from pysegcnn.core.models import (SupportedModels, SupportedOptimizers,
+                                  SupportedLossFunctions, Network)
 from pysegcnn.core.layers import Conv2dSame
-from pysegcnn.core.utils import img2np, accuracy_function
-from pysegcnn.core.split import (RandomTileSplit, RandomSceneSplit, DateSplit,
-                                 VALID_SPLIT_MODES)
+from pysegcnn.main.config import HERE
+
+
+@dataclasses.dataclass
+class BaseConfig:
+
+    def __post_init__(self):
+        # check input types
+        for field in dataclasses.fields(self):
+            # the value of the current field
+            value = getattr(self, field.name)
+
+            # check whether the value is of the correct type
+            if not isinstance(value, field.type):
+                # try to convert the value to the correct type
+                try:
+                    setattr(self, field.name, field.type(value))
+                except TypeError:
+                    # raise an exception if the conversion fails
+                    raise TypeError('Expected {} to be {}, got {}.'
+                                    .format(field.name, field.type,
+                                            type(value)))
+
+
+@dataclasses.dataclass
+class DatasetConfig(BaseConfig):
+    dataset_name: str
+    root_dir: pathlib.Path
+    bands: list
+    tile_size: int
+    gt_pattern: str
+    seed: int
+    sort: bool = False
+    transforms: list = dataclasses.field(default_factory=list)
+    pad: bool = False
+    cval: int = 99
+
+    def __post_init__(self):
+        # check input types
+        super().__post_init__()
 
+        # check whether the dataset is currently supported
+        self.dataset_class = item_in_enum(self.dataset_name, SupportedDatasets)
 
-class NetworkTrainer(object):
+        # check whether the root directory exists
+        if not self.root_dir.exists():
+            raise FileNotFoundError('{} does not exist.'.format(self.root_dir))
 
-    def __init__(self, config):
+        # check whether the transformations inherit from the correct class
+        if not all([isinstance(t, Augment) for t in self.transforms if
+                    self.transforms]):
+            raise TypeError('Each transformation is expected to be an instance'
+                            ' of {}.'.format('.'.join([Augment.__module__,
+                                                       Augment.__name__])))
 
-        # the configuration file as defined in pysegcnn.main.config.py
-        for k, v in config.items():
-            setattr(self, k, v)
+        # check whether the constant padding value is within the valid range
+        if not 0 < self.cval < 255:
+            raise ValueError('Expecting 0 <= cval <= 255, got cval={}.'
+                             .format(self.cval))
 
-        # whether to use the gpu
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else
-                                   "cpu")
+    def init_dataset(self):
+
+        # instanciate the dataset
+        dataset = self.dataset_class(
+                    root_dir=str(self.root_dir),
+                    use_bands=self.bands,
+                    tile_size=self.tile_size,
+                    seed=self.seed,
+                    sort=self.sort,
+                    transforms=self.transforms,
+                    pad=self.pad,
+                    cval=self.cval,
+                    gt_pattern=self.gt_pattern
+                    )
+
+        return dataset
+
+
+@dataclasses.dataclass
+class SplitConfig(BaseConfig):
+    split_mode: str
+    ttratio: float
+    tvratio: float
+    date: str = 'yyyymmdd'
+    dateformat: str = '%Y%m%d'
+    drop: float = 0
+
+    def __post_init__(self):
+        # check input types
+        super().__post_init__()
+
+        # check if the split mode is valid
+        self.split_class = item_in_enum(self.split_mode, SupportedSplits)
+
+    # function to drop samples with a fraction of pixels equal to the constant
+    # padding value self.cval >= self.drop
+    def _drop_samples(self, ds, drop_threshold=1):
+
+        # iterate over the scenes returned by self.compose_scenes()
+        dropped = []
+        for pos, i in enumerate(ds.indices):
+
+            # the current scene
+            s = ds.dataset.scenes[i]
+
+            # the current tile in the ground truth
+            tile_gt = img2np(s['gt'], ds.dataset.tile_size, s['tile'],
+                             ds.dataset.pad, ds.dataset.cval)
 
-        # initialize the dataset to train the model on
-        self._init_dataset()
+            # percent of pixels equal to the constant padding value
+            npixels = (tile_gt[tile_gt == ds.dataset.cval].size / tile_gt.size)
+
+            # drop samples where npixels >= self.drop
+            if npixels >= drop_threshold:
+                print('Skipping scene {}, tile {}: {:.2f}% padded pixels ...'
+                      .format(s['id'], s['tile'], npixels * 100))
+                dropped.append(s)
+                _ = ds.indices.pop(pos)
+
+        return dropped
+
+    def train_val_test_split(self, ds):
+
+        if not isinstance(ds, ImageDataset):
+            raise TypeError('Expected "ds" to be {}.'
+                            .format('.'.join([ImageDataset.__module__,
+                                              ImageDataset.__name__])))
+
+        if self.split_mode == 'random' or self.split_mode == 'scene':
+            subset = self.split_class(ds,
+                                      self.ttratio,
+                                      self.tvratio,
+                                      ds.seed)
+
+        else:
+            subset = self.split_class(ds, self.date, self.dateformat)
+
+        # the training, validation and test dataset
+        train_ds, valid_ds, test_ds = subset.split()
+
+        # whether to drop training samples with a fraction of pixels equal to
+        # the constant padding value cval >= drop
+        if ds.pad and self.drop > 0:
+            self.dropped = self._drop_samples(train_ds, self.drop)
+
+        return train_ds, valid_ds, test_ds
+
+    @staticmethod
+    def dataloaders(*args, **kwargs):
+        # check whether each dataset in args has the correct type
+        loaders = []
+        for ds in args:
+            if not isinstance(ds, Dataset):
+                raise TypeError('Expected {}, got {}.'
+                                .format(repr(Dataset), type(ds)))
+
+            # check if the dataset is not empty
+            if len(ds) > 0:
+                # build the dataloader
+                loader = DataLoader(ds, **kwargs)
+            else:
+                loader = None
+            loaders.append(loader)
 
-        # initialize the model state files
-        self._init_state()
+        return loaders
 
-        # initialize the model
-        self._init_model()
 
-    def from_pretrained(self):
+@dataclasses.dataclass
+class ModelConfig(BaseConfig):
+    model_name: str
+    filters: list
+    torch_seed: int
+    skip_connection: bool = True
+    kwargs: dict = dataclasses.field(
+        default_factory=lambda: {'kernel_size': 3, 'stride': 1, 'dilation': 1})
+    state_path: pathlib.Path = pathlib.Path(HERE).joinpath('_models/')
+    batch_size: int = 64
+    checkpoint: bool = False
+    pretrained: bool = False
+    pretrained_model: str = ''
+
+    def __post_init__(self):
+        # check input types
+        super().__post_init__()
+
+        # check whether the model is currently supported
+        self.model_class = item_in_enum(self.model_name, SupportedModels)
+
+    def init_state(self, ds, sc, tc):
+
+        # file to save model state to:
+        # network_dataset_optim_split_splitparams_tilesize_batchsize_bands.pt
+
+        # model state filename
+        state_file = '{}_{}_{}_{}Split_{}_t{}_b{}_{}.pt'
+
+        # get the band numbers
+        bformat = ''.join(band[0] +
+                          str(ds.sensor.__members__[band].value) for
+                          band in ds.use_bands)
+
+        # check which split mode was used
+        if sc.split_mode == 'date':
+            # store the date that was used to split the dataset
+            state_file = state_file.format(self.model_class.__name__,
+                                           ds.__class__.__name__,
+                                           tc.optim_name,
+                                           sc.split_mode.capitalize(),
+                                           sc.date,
+                                           ds.tile_size,
+                                           self.batch_size,
+                                           bformat)
+        else:
+            # store the random split parameters
+            split_params = 's{}_t{}v{}'.format(
+                ds.seed, str(sc.ttratio).replace('.', ''),
+                str(sc.tvratio).replace('.', ''))
+
+            # model state filename
+            state_file = state_file.format(self.model_class.__name__,
+                                           ds.__class__.__name__,
+                                           tc.optim_name,
+                                           sc.split_mode.capitalize(),
+                                           split_params,
+                                           ds.tile_size,
+                                           self.batch_size,
+                                           bformat)
+
+        # check whether a pretrained model was used and change state filename
+        # accordingly
+        if self.pretrained:
+            # add the configuration of the pretrained model to the state name
+            state_file = (state_file.replace('.pt', '_') +
+                          'pretrained_' + self.pretrained_model)
+
+        # path to model state
+        state = self.state_path.joinpath(state_file)
+
+        # path to model loss/accuracy
+        loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt'))
+
+        return state, loss_state
+
+    def init_model(self, ds):
+
+        # case (1): build a new model
+        if not self.pretrained:
+
+            # set the random seed for reproducibility
+            torch.manual_seed(self.torch_seed)
+
+            # instanciate the model
+            model = self.model_class(
+                in_channels=len(ds.use_bands),
+                nclasses=len(ds.labels),
+                filters=self.filters,
+                skip=self.skip_connection,
+                **self.kwargs)
+
+        # case (2): load a pretrained model
+        else:
+
+            # load pretrained model
+            model = self.load_pretrained()
+
+        return model
+
+    def load_checkpoint(self, state_file, loss_state, model, optimizer):
+
+        # initial accuracy on the validation set
+        max_accuracy = 0
+
+        # set the model checkpoint to None, overwritten when resuming
+        # training from an existing model checkpoint
+        checkpoint_state = {}
+
+        # whether to resume training from an existing model
+        if self.checkpoint:
+
+            # check if a model checkpoint exists
+            if not state_file.exists():
+                raise FileNotFoundError('Model checkpoint {} does not exist.'
+                                        .format(state_file))
+
+            # load the model state
+            state = model.load(state_file.name, optimizer, self.state_path)
+            print('Found checkpoint: {}'.format(state))
+            print('Resuming training from checkpoint ...'.format(state))
+            print('Model epoch: {:d}'.format(model.epoch))
+
+            # load the model loss and accuracy
+            checkpoint_state = torch.load(loss_state)
+
+            # get all non-zero elements, i.e. get number of epochs trained
+            # before the early stop
+            checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
+                                for k, v in checkpoint_state.items()}
+
+            # maximum accuracy on the validation set
+            max_accuracy = checkpoint_state['va'][:, -1].mean().item()
+
+        return checkpoint_state, max_accuracy
+
+    def load_pretrained(self, ds):
 
         # load the pretrained model
-        model_state = os.path.join(self.state_path, self.pretrained_model)
-        if not os.path.exists(model_state):
+        model_state = self.state_path.joinpath(self.pretrained_model)
+        if not model_state.exists():
             raise FileNotFoundError('Pretrained model {} does not exist.'
                                     .format(model_state))
 
@@ -61,23 +346,24 @@ class NetworkTrainer(object):
         filters = model_state['params']['filters']
 
         # check whether the current dataset uses the correct spectral bands
-        if self.bands != bands:
+        if ds.use_bands != bands:
             raise ValueError('The bands of the pretrained network do not '
                              'match the specified bands: {}'
-                             .format(self.bands))
+                             .format(bands))
 
         # instanciate pretrained model architecture
-        model = self.model(**model_state['params'], **model_state['kwargs'])
+        model = self.model_class(**model_state['params'],
+                                 **model_state['kwargs'])
 
         # load pretrained model weights
-        model.load(self.pretrained_model, inpath=self.state_path)
+        model.load(self.pretrained_model, inpath=str(self.state_path))
 
         # reset model epoch to 0, since the model is trained on a different
         # dataset
         model.epoch = 0
 
         # adjust the number of classes in the model
-        model.nclasses = len(self.dataset.labels)
+        model.nclasses = len(ds.labels)
 
         # adjust the classification layer to the number of classes of the
         # current dataset
@@ -85,49 +371,103 @@ class NetworkTrainer(object):
                                       out_channels=model.nclasses,
                                       kernel_size=1)
 
-
         return model
 
-    def from_checkpoint(self):
-
-        # whether to resume training from an existing model
-        if not os.path.exists(self.state):
-            raise FileNotFoundError('Model checkpoint {} does not exist.'
-                                    .format(self.state))
-
-        # load the model state
-        state = self.model.load(self.state_file, self.optimizer,
-                                self.state_path)
-        print('Resuming training from {} ...'.format(state))
-        print('Model epoch: {:d}'.format(self.model.epoch))
-
-        # load the model loss and accuracy
-        checkpoint_state = torch.load(self.loss_state)
 
-        # get all non-zero elements, i.e. get number of epochs trained
-        # before the early stop
-        checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for
-                            k, v in checkpoint_state.items()}
+@dataclasses.dataclass
+class TrainConfig(BaseConfig):
+    optim_name: str
+    loss_name: str
+    lr: float = 0.001
+    early_stop: bool = False
+    mode: str = 'max'
+    delta: float = 0
+    patience: int = 10
+    epochs: int = 50
+    nthreads: int = torch.get_num_threads()
+    save: bool = True
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        # check whether the optimizer is currently supported
+        self.optim_class = item_in_enum(self.optim_name, SupportedOptimizers)
+
+        # check whether the loss function is currently supported
+        self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions)
+
+    def init_optimizer(self, model):
+
+        # initialize the optimizer for the specified model
+        optimizer = self.optim_class(model.parameters(), self.lr)
+
+        return optimizer
+
+    def init_loss_function(self):
+
+        loss_function = self.loss_class()
+
+        return loss_function
+
+
+@dataclasses.dataclass
+class EvalConfig(BaseConfig):
+    test: object
+    predict_scene: bool = False
+    plot_samples: bool = False
+    plot_scenes: bool = False
+    plot_bands: list = dataclasses.field(
+        default_factory=lambda: ['nir', 'red', 'green'])
+    cm: bool = True
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        # check whether the test input parameter is correctly specified
+        if self.test not in [None, False, True]:
+            raise TypeError('Expected "test" to be None, True or False, got '
+                            '{}.'.format(self.test))
+
+@dataclasses.dataclass
+class NetworkTrainer(BaseConfig):
+    model: Network
+    optimizer: Optimizer
+    loss_function: nn.Module
+    train_dl: DataLoader
+    valid_dl: DataLoader
+    state_file: pathlib.Path
+    loss_state: pathlib.Path
+    epochs: int = 1
+    nthreads: int = torch.get_num_threads()
+    early_stop: bool = False
+    mode: str = 'max'
+    delta: float = 0
+    patience: int = 10
+    max_accuracy: float = 0
+    checkpoint_state: dict = dataclasses.field(default_factory=dict)
+    save: bool = True
+
+    def __post_init__(self):
+        super().__post_init__()
 
-        # maximum accuracy on the validation set
-        max_accuracy = checkpoint_state['va'][:, -1].mean().item()
+        # whether to use the gpu
+        self.device = torch.device("cuda:0" if torch.cuda.is_available()
+                                   else "cpu")
 
-        return checkpoint_state, max_accuracy
+        # whether to use early stopping
+        self.es = None
+        if self.early_stop:
+            self.es = EarlyStopping(self.mode, self.delta, self.patience)
 
     def train(self):
 
         print('------------------------- Training ---------------------------')
 
         # set the number of threads
+        print('Device: {}'.format(self.device))
+        print('Number of cpu threads: {}'.format(self.nthreads))
         torch.set_num_threads(self.nthreads)
 
-        # instanciate early stopping class
-        if self.early_stop:
-            es = EarlyStopping(self.mode, self.delta, self.patience)
-            print('Initializing early stopping ...')
-            print('mode = {}, delta = {}, patience = {} epochs ...'
-                  .format(self.mode, self.delta, self.patience))
-
         # create dictionary of the observed losses and accuracies on the
         # training and validation dataset
         tshape = (len(self.train_dl), self.epochs)
@@ -207,36 +547,22 @@ class NetworkTrainer(object):
                 epoch_acc = vacc.squeeze().mean()
 
                 # whether the model improved with respect to the previous epoch
-                if es.increased(epoch_acc, self.max_accuracy, self.delta):
+                if self.es.increased(epoch_acc, self.max_accuracy, self.delta):
                     self.max_accuracy = epoch_acc
+
                     # save model state if the model improved with
                     # respect to the previous epoch
-                    _ = self.model.save(self.state_file,
-                                        self.optimizer,
-                                        self.bands,
-                                        self.state_path)
-
-                    # save losses and accuracy
-                    self._save_loss(training_state,
-                                    self.checkpoint,
-                                    self.checkpoint_state)
+                    self.save_state(training_state)
 
                 # whether the early stopping criterion is met
-                if es.stop(epoch_acc):
+                if self.es.stop(epoch_acc):
                     break
 
             else:
-                # if no early stopping is required, the model state is saved
-                # after each epoch
-                _ = self.model.save(self.state_file,
-                                    self.optimizer,
-                                    self.bands,
-                                    self.state_path)
+                # if no early stopping is required, the model state is
+                # saved after each epoch
+                self.save_state(training_state)
 
-                # save losses and accuracy after each epoch
-                self._save_loss(training_state,
-                                self.checkpoint,
-                                self.checkpoint_state)
 
         return training_state
 
@@ -283,217 +609,37 @@ class NetworkTrainer(object):
                   .format(batch + 1, len(self.valid_dl), acc))
 
         # calculate overall accuracy on the validation/test set
-        print('After training for {:d} epochs, we achieved an overall '
-              'accuracy of {:.2f}%  on the validation set!'
+        print('Epoch {:d}, Overall accuracy: {:.2f}%.'
               .format(self.model.epoch, accuracies.mean() * 100))
 
         return accuracies, losses
 
-    def _init_state(self):
-
-        # file to save model state to
-        # format: network_dataset_seed_tilesize_batchsize_bands.pt
-
-        # get the band numbers
-        bformat = ''.join(band[0] +
-                          str(self.dataset.sensor.__members__[band].value) for
-                          band in self.bands)
-
-        # model state filename
-        self.state_file = ('{}_{}_s{}_t{}_b{}_{}.pt'
-                           .format(self.model.__name__,
-                                   self.dataset.__class__.__name__,
-                                   self.seed,
-                                   self.tile_size,
-                                   self.batch_size,
-                                   bformat))
-
-        # check whether a pretrained model was used and change state filename
-        # accordingly
-        if self.pretrained:
-            # add the configuration of the pretrained model to the state name
-            self.state_file = (self.state_file.replace('.pt', '_') +
-                               'pretrained_' + self.pretrained_model)
-
-        # path to model state
-        self.state = os.path.join(self.state_path, self.state_file)
-
-        # path to model loss/accuracy
-        self.loss_state = self.state.replace('.pt', '_loss.pt')
-
-    def _init_dataset(self):
-
-        # the dataset name
-        self.dataset_name = os.path.basename(self.root_dir)
-
-        # check whether the dataset is currently supported
-        if self.dataset_name not in SupportedDatasets.__members__:
-            raise ValueError('{} is not a valid dataset. '
-                             .format(self.dataset_name) +
-                             'Available datasets are: \n' +
-                             '\n'.join(name for name, _ in
-                                       SupportedDatasets.__members__.items()))
-        else:
-            self.dataset_class = SupportedDatasets.__members__[
-                self.dataset_name].value
-
-        # instanciate the dataset
-        self.dataset = self.dataset_class(
-                    self.root_dir,
-                    use_bands=self.bands,
-                    tile_size=self.tile_size,
-                    sort=self.sort,
-                    transforms=self.transforms,
-                    pad=self.pad,
-                    cval=self.cval,
-                    gt_pattern=self.gt_pattern
-                    )
-
-        # the mode to split
-        if self.split_mode not in VALID_SPLIT_MODES:
-            raise ValueError('{} is not supported. Valid modes are {}, see '
-                             'pysegcnn.main.config.py for a description of '
-                             'each mode.'.format(self.split_mode,
-                                                 VALID_SPLIT_MODES))
-        if self.split_mode == 'random':
-            self.subset = RandomTileSplit(self.dataset,
-                                          self.ttratio,
-                                          self.tvratio,
-                                          self.seed)
-        if self.split_mode == 'scene':
-            self.subset = RandomSceneSplit(self.dataset,
-                                           self.ttratio,
-                                           self.tvratio,
-                                           self.seed)
-        if self.split_mode == 'date':
-            self.subset = DateSplit(self.dataset,
-                                    self.date,
-                                    self.dateformat)
-
-        # the training, validation and test dataset
-        self.train_ds, self.valid_ds, self.test_ds = self.subset.split()
-
-        # whether to drop training samples with a fraction of pixels equal to
-        # the constant padding value self.cval >= self.drop
-        if self.pad and self.drop:
-            self._drop(self.train_ds)
-
-        # the shape of a single batch
-        self.batch_shape = (len(self.bands), self.tile_size, self.tile_size)
-
-        # the training dataloader
-        self.train_dl = None
-        if len(self.train_ds) > 0:
-            self.train_dl = DataLoader(self.train_ds,
-                                       self.batch_size,
-                                       shuffle=True,
-                                       drop_last=False)
-        # the validation dataloader
-        self.valid_dl = None
-        if len(self.valid_ds) > 0:
-            self.valid_dl = DataLoader(self.valid_ds,
-                                       self.batch_size,
-                                       shuffle=True,
-                                       drop_last=False)
-
-        # the test dataloader
-        self.test_dl = None
-        if len(self.test_ds) > 0:
-            self.test_dl = DataLoader(self.test_ds,
-                                      self.batch_size,
-                                      shuffle=True,
-                                      drop_last=False)
-
-    def _init_model(self):
-
-        # initial accuracy on the validation set
-        self.max_accuracy = 0
-
-        # set the model checkpoint to None, overwritten when resuming
-        # training from an existing model checkpoint
-        self.checkpoint_state = None
-
-        # case (1): build a model for the specified dataset
-        if not self.pretrained and not self.checkpoint:
-
-            # instanciate the model
-            self.model = self.model(in_channels=len(self.dataset.use_bands),
-                                    nclasses=len(self.dataset.labels),
-                                    filters=self.filters,
-                                    skip=self.skip_connection,
-                                    **self.kwargs)
-
-            # the optimizer used to update the model weights
-            self.optimizer = self.optimizer(self.model.parameters(), self.lr)
-
-        # case (2): using a pretrained model withouth existing checkpoint on
-        #           a new dataset, i.e. transfer learning
-        if self.pretrained and not self.checkpoint:
-            # load pretrained model
-            self.model = self.from_pretrained()
-
-            # the optimizer used to update the model weights
-            self.optimizer = self.optimizer(self.model.parameters(), self.lr)
-
-        # case (3): using a pretrained model with existing checkpoint on the
-        #           same dataset the pretrained model was trained on
-        elif self.checkpoint:
-
-            # instanciate the model
-            self.model = self.model(in_channels=len(self.dataset.use_bands),
-                                    nclasses=len(self.dataset.labels),
-                                    filters=self.filters,
-                                    skip=self.skip_connection,
-                                    **self.kwargs)
-
-            # the optimizer used to update the model weights
-            self.optimizer = self.optimizer(self.model.parameters(), self.lr)
-
-            # whether to resume training from an existing model checkpoint
-            if self.checkpoint:
-                (self.checkpoint_state,
-                 self.max_accuracy) = self.from_checkpoint()
-
-    # function to drop samples with a fraction of pixels equal to the constant
-    # padding value self.cval >= self.drop
-    def _drop(self, ds):
+    def save_state(self, training_state):
 
-        # iterate over the scenes returned by self.compose_scenes()
-        self.dropped = []
-        for pos, i in enumerate(ds.indices):
+        # whether to save the model state
+        if self.save:
+            # save model state
+            state = self.model.save(self.state_file.name,
+                                    self.optimizer,
+                                    self.train_dl.dataset.dataset.use_bands,
+                                    self.state_file.parent)
 
-            # the current scene
-            s = ds.dataset.scenes[i]
+            # save losses and accuracy
+            self._save_loss(training_state)
 
-            # the current tile in the ground truth
-            tile_gt = img2np(s['gt'], self.tile_size, s['tile'],
-                             self.pad, self.cval)
+    def _save_loss(self, training_state):
 
-            # percent of pixels equal to the constant padding value
-            npixels = (tile_gt[tile_gt == self.cval].size / tile_gt.size)
-
-            # drop samples where npixels >= self.drop
-            if npixels >= self.drop:
-                print('Skipping scene {}, tile {}: {:.2f}% padded pixels ...'
-                      .format(s['id'], s['tile'], npixels * 100))
-                self.dropped.append(s)
-                _ = ds.indices.pop(pos)
+        # save losses and accuracy
+        state = training_state
+        if self.checkpoint_state:
 
-    def _save_loss(self, training_state, checkpoint=False,
-                   checkpoint_state=None):
+            # append values from checkpoint to current training state
+            state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
+                     zip(self.checkpoint_state.items(), training_state.items())
+                     if k1 == k2}
 
-        # save losses and accuracy
-        if checkpoint and checkpoint_state is not None:
-
-            # append values from checkpoint to current training
-            # state
-            torch.save({
-                k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
-                zip(checkpoint_state.items(), training_state.items())
-                if k1 == k2},
-                self.loss_state)
-        else:
-            torch.save(training_state, self.loss_state)
+        # save the model loss and accuracies to file
+        torch.save(state, self.loss_state)
 
     def __repr__(self):
 
@@ -502,26 +648,37 @@ class NetworkTrainer(object):
 
         # dataset
         fs += '    (dataset):\n        '
-        fs += ''.join(self.dataset.__repr__()).replace('\n', '\n        ')
+        fs += ''.join(
+            repr(self.train_dl.dataset.dataset)).replace('\n','\n        ')
 
         # batch size
         fs += '\n    (batch):\n        '
-        fs += '- batch size: {}\n        '.format(self.batch_size)
-        fs += '- batch shape (b, h, w): {}'.format(self.batch_shape)
+        fs += '- batch size: {}\n        '.format(self.train_dl.batch_size)
+        fs += '- mini-batch shape (b, c, h, w): {}'.format(
+            (self.train_dl.batch_size,
+             len(self.train_dl.dataset.dataset.use_bands),
+             self.train_dl.dataset.dataset.tile_size,
+             self.train_dl.dataset.dataset.tile_size)
+            )
 
         # dataset split
-        fs += '\n    (split):\n        '
-        fs += ''.join(self.subset.__repr__()).replace('\n', '\n        ')
+        fs += '\n    (split):'
+        fs += '\n        ' + repr(self.train_dl.dataset)
+        fs += '\n        ' + repr(self.valid_dl.dataset)
 
         # model
         fs += '\n    (model):\n        '
-        fs += ''.join(self.model.__repr__()).replace('\n', '\n        ')
+        fs += ''.join(repr(self.model)).replace('\n', '\n        ')
 
         # optimizer
         fs += '\n    (optimizer):\n        '
-        fs += ''.join(self.optimizer.__repr__()).replace('\n', '\n        ')
-        fs += '\n)'
+        fs += ''.join(repr(self.optimizer)).replace('\n', '\n        ')
+
+        # early stopping
+        fs += '\n    (early stop):\n        '
+        fs += ''.join(repr(self.es)).replace('\n', '\n        ')
 
+        fs += '\n)'
         return fs
 
 
@@ -555,6 +712,9 @@ class EarlyStopping(object):
         # initialize early stopping flag
         self.early_stop = False
 
+        # initialize the early stop counter
+        self.counter = 0
+
     def stop(self, metric):
 
         if self.best is not None:
@@ -584,3 +744,8 @@ class EarlyStopping(object):
 
     def increased(self, metric, best, min_delta):
         return metric > best + min_delta
+
+    def __repr__(self):
+        fs = (self.__class__.__name__ + '(mode={}, delta={}, patience={})'
+              .format(self.mode, self.min_delta, self.patience))
+        return fs