diff --git a/pysegcnn/core/trainer_old.py b/pysegcnn/core/trainer_old.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa8715ed8882478792d8095bc07ad55ba0ce1546
--- /dev/null
+++ b/pysegcnn/core/trainer_old.py
@@ -0,0 +1,586 @@
+# !/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Jun 26 16:31:36 2020
+
+@author: Daniel
+"""
+# builtins
+import os
+
+# externals
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+
+# locals
+from pysegcnn.core.dataset import SupportedDatasets
+from pysegcnn.core.layers import Conv2dSame
+from pysegcnn.core.utils import img2np, accuracy_function
+from pysegcnn.core.split import (RandomTileSplit, RandomSceneSplit, DateSplit,
+                                 VALID_SPLIT_MODES)
+
+
+class NetworkTrainer(object):
+
+    def __init__(self, config):
+
+        # the configuration file as defined in pysegcnn.main.config.py
+        for k, v in config.items():
+            setattr(self, k, v)
+
+        # whether to use the gpu
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else
+                                   "cpu")
+
+        # initialize the dataset to train the model on
+        self._init_dataset()
+
+        # initialize the model state files
+        self._init_state()
+
+        # initialize the model
+        self._init_model()
+
+    def from_pretrained(self):
+
+        # load the pretrained model
+        model_state = os.path.join(self.state_path, self.pretrained_model)
+        if not os.path.exists(model_state):
+            raise FileNotFoundError('Pretrained model {} does not exist.'
+                                    .format(model_state))
+
+        # load the model state
+        model_state = torch.load(model_state)
+
+        # get the input bands of the pretrained model
+        bands = model_state['bands']
+
+        # get the number of convolutional filters
+        filters = model_state['params']['filters']
+
+        # check whether the current dataset uses the correct spectral bands
+        if self.bands != bands:
+            raise ValueError('The bands of the pretrained network do not '
+                             'match the specified bands: {}'
+                             .format(self.bands))
+
+        # instanciate pretrained model architecture
+        model = self.model(**model_state['params'], **model_state['kwargs'])
+
+        # load pretrained model weights
+        model.load(self.pretrained_model, inpath=self.state_path)
+
+        # reset model epoch to 0, since the model is trained on a different
+        # dataset
+        model.epoch = 0
+
+        # adjust the number of classes in the model
+        model.nclasses = len(self.dataset.labels)
+
+        # adjust the classification layer to the number of classes of the
+        # current dataset
+        model.classifier = Conv2dSame(in_channels=filters[0],
+                                      out_channels=model.nclasses,
+                                      kernel_size=1)
+
+
+        return model
+
+    def from_checkpoint(self):
+
+        # whether to resume training from an existing model
+        if not os.path.exists(self.state):
+            raise FileNotFoundError('Model checkpoint {} does not exist.'
+                                    .format(self.state))
+
+        # load the model state
+        state = self.model.load(self.state_file, self.optimizer,
+                                self.state_path)
+        print('Resuming training from {} ...'.format(state))
+        print('Model epoch: {:d}'.format(self.model.epoch))
+
+        # load the model loss and accuracy
+        checkpoint_state = torch.load(self.loss_state)
+
+        # get all non-zero elements, i.e. get number of epochs trained
+        # before the early stop
+        checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for
+                            k, v in checkpoint_state.items()}
+
+        # maximum accuracy on the validation set
+        max_accuracy = checkpoint_state['va'][:, -1].mean().item()
+
+        return checkpoint_state, max_accuracy
+
+    def train(self):
+
+        print('------------------------- Training ---------------------------')
+
+        # set the number of threads
+        torch.set_num_threads(self.nthreads)
+
+        # instanciate early stopping class
+        if self.early_stop:
+            es = EarlyStopping(self.mode, self.delta, self.patience)
+            print('Initializing early stopping ...')
+            print('mode = {}, delta = {}, patience = {} epochs ...'
+                  .format(self.mode, self.delta, self.patience))
+
+        # create dictionary of the observed losses and accuracies on the
+        # training and validation dataset
+        tshape = (len(self.train_dl), self.epochs)
+        vshape = (len(self.valid_dl), self.epochs)
+        training_state = {'tl': np.zeros(shape=tshape),
+                          'ta': np.zeros(shape=tshape),
+                          'vl': np.zeros(shape=vshape),
+                          'va': np.zeros(shape=vshape)
+                          }
+
+        # send the model to the gpu if available
+        self.model = self.model.to(self.device)
+
+        # initialize the training: iterate over the entire training data set
+        for epoch in range(self.epochs):
+
+            # set the model to training mode
+            print('Setting model to training mode ...')
+            self.model.train()
+
+            # iterate over the dataloader object
+            for batch, (inputs, labels) in enumerate(self.train_dl):
+
+                # send the data to the gpu if available
+                inputs = inputs.to(self.device)
+                labels = labels.to(self.device)
+
+                # reset the gradients
+                self.optimizer.zero_grad()
+
+                # perform forward pass
+                outputs = self.model(inputs)
+
+                # compute loss
+                loss = self.loss_function(outputs, labels.long())
+                observed_loss = loss.detach().numpy().item()
+                training_state['tl'][batch, epoch] = observed_loss
+
+                # compute the gradients of the loss function w.r.t.
+                # the network weights
+                loss.backward()
+
+                # update the weights
+                self.optimizer.step()
+
+                # calculate predicted class labels
+                ypred = F.softmax(outputs, dim=1).argmax(dim=1)
+
+                # calculate accuracy on current batch
+                observed_accuracy = accuracy_function(ypred, labels)
+                training_state['ta'][batch, epoch] = observed_accuracy
+
+                # print progress
+                print('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, Loss: {:.2f}, '
+                      'Accuracy: {:.2f}'.format(epoch + 1,
+                                                self.epochs,
+                                                batch + 1,
+                                                len(self.train_dl),
+                                                observed_loss,
+                                                observed_accuracy))
+
+            # update the number of epochs trained
+            self.model.epoch += 1
+
+            # whether to evaluate model performance on the validation set and
+            # early stop the training process
+            if self.early_stop:
+
+                # model predictions on the validation set
+                vacc, vloss = self.predict()
+
+                # append observed accuracy and loss to arrays
+                training_state['va'][:, epoch] = vacc.squeeze()
+                training_state['vl'][:, epoch] = vloss.squeeze()
+
+                # metric to assess model performance on the validation set
+                epoch_acc = vacc.squeeze().mean()
+
+                # whether the model improved with respect to the previous epoch
+                if es.increased(epoch_acc, self.max_accuracy, self.delta):
+                    self.max_accuracy = epoch_acc
+                    # save model state if the model improved with
+                    # respect to the previous epoch
+                    _ = self.model.save(self.state_file,
+                                        self.optimizer,
+                                        self.bands,
+                                        self.state_path)
+
+                    # save losses and accuracy
+                    self._save_loss(training_state,
+                                    self.checkpoint,
+                                    self.checkpoint_state)
+
+                # whether the early stopping criterion is met
+                if es.stop(epoch_acc):
+                    break
+
+            else:
+                # if no early stopping is required, the model state is saved
+                # after each epoch
+                _ = self.model.save(self.state_file,
+                                    self.optimizer,
+                                    self.bands,
+                                    self.state_path)
+
+                # save losses and accuracy after each epoch
+                self._save_loss(training_state,
+                                self.checkpoint,
+                                self.checkpoint_state)
+
+        return training_state
+
+    def predict(self):
+
+        print('------------------------ Predicting --------------------------')
+
+        # send the model to the gpu if available
+        self.model = self.model.to(self.device)
+
+        # set the model to evaluation mode
+        print('Setting model to evaluation mode ...')
+        self.model.eval()
+
+        # create arrays of the observed losses and accuracies
+        accuracies = np.zeros(shape=(len(self.valid_dl), 1))
+        losses = np.zeros(shape=(len(self.valid_dl), 1))
+
+        # iterate over the validation/test set
+        print('Calculating accuracy on the validation set ...')
+        for batch, (inputs, labels) in enumerate(self.valid_dl):
+
+            # send the data to the gpu if available
+            inputs = inputs.to(self.device)
+            labels = labels.to(self.device)
+
+            # calculate network outputs
+            with torch.no_grad():
+                outputs = self.model(inputs)
+
+            # compute loss
+            loss = self.loss_function(outputs, labels.long())
+            losses[batch, 0] = loss.detach().numpy().item()
+
+            # calculate predicted class labels
+            pred = F.softmax(outputs, dim=1).argmax(dim=1)
+
+            # calculate accuracy on current batch
+            acc = accuracy_function(pred, labels)
+            accuracies[batch, 0] = acc
+
+            # print progress
+            print('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}'
+                  .format(batch + 1, len(self.valid_dl), acc))
+
+        # calculate overall accuracy on the validation/test set
+        print('After training for {:d} epochs, we achieved an overall '
+              'accuracy of {:.2f}%  on the validation set!'
+              .format(self.model.epoch, accuracies.mean() * 100))
+
+        return accuracies, losses
+
+    def _init_state(self):
+
+        # file to save model state to
+        # format: network_dataset_seed_tilesize_batchsize_bands.pt
+
+        # get the band numbers
+        bformat = ''.join(band[0] +
+                          str(self.dataset.sensor.__members__[band].value) for
+                          band in self.bands)
+
+        # model state filename
+        self.state_file = ('{}_{}_s{}_t{}_b{}_{}.pt'
+                           .format(self.model.__name__,
+                                   self.dataset.__class__.__name__,
+                                   self.seed,
+                                   self.tile_size,
+                                   self.batch_size,
+                                   bformat))
+
+        # check whether a pretrained model was used and change state filename
+        # accordingly
+        if self.pretrained:
+            # add the configuration of the pretrained model to the state name
+            self.state_file = (self.state_file.replace('.pt', '_') +
+                               'pretrained_' + self.pretrained_model)
+
+        # path to model state
+        self.state = os.path.join(self.state_path, self.state_file)
+
+        # path to model loss/accuracy
+        self.loss_state = self.state.replace('.pt', '_loss.pt')
+
+    def _init_dataset(self):
+
+        # the dataset name
+        self.dataset_name = os.path.basename(self.root_dir)
+
+        # check whether the dataset is currently supported
+        if self.dataset_name not in SupportedDatasets.__members__:
+            raise ValueError('{} is not a valid dataset. '
+                             .format(self.dataset_name) +
+                             'Available datasets are: \n' +
+                             '\n'.join(name for name, _ in
+                                       SupportedDatasets.__members__.items()))
+        else:
+            self.dataset_class = SupportedDatasets.__members__[
+                self.dataset_name].value
+
+        # instanciate the dataset
+        self.dataset = self.dataset_class(
+                    self.root_dir,
+                    use_bands=self.bands,
+                    tile_size=self.tile_size,
+                    sort=self.sort,
+                    transforms=self.transforms,
+                    pad=self.pad,
+                    cval=self.cval,
+                    gt_pattern=self.gt_pattern
+                    )
+
+        # the mode to split
+        if self.split_mode not in VALID_SPLIT_MODES:
+            raise ValueError('{} is not supported. Valid modes are {}, see '
+                             'pysegcnn.main.config.py for a description of '
+                             'each mode.'.format(self.split_mode,
+                                                 VALID_SPLIT_MODES))
+        if self.split_mode == 'random':
+            self.subset = RandomTileSplit(self.dataset,
+                                          self.ttratio,
+                                          self.tvratio,
+                                          self.seed)
+        if self.split_mode == 'scene':
+            self.subset = RandomSceneSplit(self.dataset,
+                                           self.ttratio,
+                                           self.tvratio,
+                                           self.seed)
+        if self.split_mode == 'date':
+            self.subset = DateSplit(self.dataset,
+                                    self.date,
+                                    self.dateformat)
+
+        # the training, validation and test dataset
+        self.train_ds, self.valid_ds, self.test_ds = self.subset.split()
+
+        # whether to drop training samples with a fraction of pixels equal to
+        # the constant padding value self.cval >= self.drop
+        if self.pad and self.drop:
+            self._drop(self.train_ds)
+
+        # the shape of a single batch
+        self.batch_shape = (len(self.bands), self.tile_size, self.tile_size)
+
+        # the training dataloader
+        self.train_dl = None
+        if len(self.train_ds) > 0:
+            self.train_dl = DataLoader(self.train_ds,
+                                       self.batch_size,
+                                       shuffle=True,
+                                       drop_last=False)
+        # the validation dataloader
+        self.valid_dl = None
+        if len(self.valid_ds) > 0:
+            self.valid_dl = DataLoader(self.valid_ds,
+                                       self.batch_size,
+                                       shuffle=True,
+                                       drop_last=False)
+
+        # the test dataloader
+        self.test_dl = None
+        if len(self.test_ds) > 0:
+            self.test_dl = DataLoader(self.test_ds,
+                                      self.batch_size,
+                                      shuffle=True,
+                                      drop_last=False)
+
+    def _init_model(self):
+
+        # initial accuracy on the validation set
+        self.max_accuracy = 0
+
+        # set the model checkpoint to None, overwritten when resuming
+        # training from an existing model checkpoint
+        self.checkpoint_state = None
+
+        # case (1): build a model for the specified dataset
+        if not self.pretrained and not self.checkpoint:
+
+            # instanciate the model
+            self.model = self.model(in_channels=len(self.dataset.use_bands),
+                                    nclasses=len(self.dataset.labels),
+                                    filters=self.filters,
+                                    skip=self.skip_connection,
+                                    **self.kwargs)
+
+            # the optimizer used to update the model weights
+            self.optimizer = self.optimizer(self.model.parameters(), self.lr)
+
+        # case (2): using a pretrained model withouth existing checkpoint on
+        #           a new dataset, i.e. transfer learning
+        if self.pretrained and not self.checkpoint:
+            # load pretrained model
+            self.model = self.from_pretrained()
+
+            # the optimizer used to update the model weights
+            self.optimizer = self.optimizer(self.model.parameters(), self.lr)
+
+        # case (3): using a pretrained model with existing checkpoint on the
+        #           same dataset the pretrained model was trained on
+        elif self.checkpoint:
+
+            # instanciate the model
+            self.model = self.model(in_channels=len(self.dataset.use_bands),
+                                    nclasses=len(self.dataset.labels),
+                                    filters=self.filters,
+                                    skip=self.skip_connection,
+                                    **self.kwargs)
+
+            # the optimizer used to update the model weights
+            self.optimizer = self.optimizer(self.model.parameters(), self.lr)
+
+            # whether to resume training from an existing model checkpoint
+            if self.checkpoint:
+                (self.checkpoint_state,
+                 self.max_accuracy) = self.from_checkpoint()
+
+    # function to drop samples with a fraction of pixels equal to the constant
+    # padding value self.cval >= self.drop
+    def _drop(self, ds):
+
+        # iterate over the scenes returned by self.compose_scenes()
+        self.dropped = []
+        for pos, i in enumerate(ds.indices):
+
+            # the current scene
+            s = ds.dataset.scenes[i]
+
+            # the current tile in the ground truth
+            tile_gt = img2np(s['gt'], self.tile_size, s['tile'],
+                             self.pad, self.cval)
+
+            # percent of pixels equal to the constant padding value
+            npixels = (tile_gt[tile_gt == self.cval].size / tile_gt.size)
+
+            # drop samples where npixels >= self.drop
+            if npixels >= self.drop:
+                print('Skipping scene {}, tile {}: {:.2f}% padded pixels ...'
+                      .format(s['id'], s['tile'], npixels * 100))
+                self.dropped.append(s)
+                _ = ds.indices.pop(pos)
+
+    def _save_loss(self, training_state, checkpoint=False,
+                   checkpoint_state=None):
+
+        # save losses and accuracy
+        if checkpoint and checkpoint_state is not None:
+
+            # append values from checkpoint to current training
+            # state
+            torch.save({
+                k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
+                zip(checkpoint_state.items(), training_state.items())
+                if k1 == k2},
+                self.loss_state)
+        else:
+            torch.save(training_state, self.loss_state)
+
+    def __repr__(self):
+
+        # representation string to print
+        fs = self.__class__.__name__ + '(\n'
+
+        # dataset
+        fs += '    (dataset):\n        '
+        fs += ''.join(self.dataset.__repr__()).replace('\n', '\n        ')
+
+        # batch size
+        fs += '\n    (batch):\n        '
+        fs += '- batch size: {}\n        '.format(self.batch_size)
+        fs += '- batch shape (b, h, w): {}'.format(self.batch_shape)
+
+        # dataset split
+        fs += '\n    (split):\n        '
+        fs += ''.join(self.subset.__repr__()).replace('\n', '\n        ')
+
+        # model
+        fs += '\n    (model):\n        '
+        fs += ''.join(self.model.__repr__()).replace('\n', '\n        ')
+
+        # optimizer
+        fs += '\n    (optimizer):\n        '
+        fs += ''.join(self.optimizer.__repr__()).replace('\n', '\n        ')
+        fs += '\n)'
+
+        return fs
+
+
+class EarlyStopping(object):
+
+    def __init__(self, mode='max', min_delta=0, patience=10):
+
+        # check if mode is correctly specified
+        if mode not in ['min', 'max']:
+            raise ValueError('Mode "{}" not supported. '
+                             'Mode is either "min" (check whether the metric '
+                             'decreased, e.g. loss) or "max" (check whether '
+                             'the metric increased, e.g. accuracy).'
+                             .format(mode))
+
+        # mode to determine if metric improved
+        self.mode = mode
+
+        # whether to check for an increase or a decrease in a given metric
+        self.is_better = self.decreased if mode == 'min' else self.increased
+
+        # minimum change in metric to be classified as an improvement
+        self.min_delta = min_delta
+
+        # number of epochs to wait for improvement
+        self.patience = patience
+
+        # initialize best metric
+        self.best = None
+
+        # initialize early stopping flag
+        self.early_stop = False
+
+    def stop(self, metric):
+
+        if self.best is not None:
+
+            # if the metric improved, reset the epochs counter, else, advance
+            if self.is_better(metric, self.best, self.min_delta):
+                self.counter = 0
+                self.best = metric
+            else:
+                self.counter += 1
+                print('Early stopping counter: {}/{}'.format(self.counter,
+                                                             self.patience))
+
+            # if the metric did not improve over the last patience epochs,
+            # the early stopping criterion is met
+            if self.counter >= self.patience:
+                print('Early stopping criterion met, exiting training ...')
+                self.early_stop = True
+
+        else:
+            self.best = metric
+
+        return self.early_stop
+
+    def decreased(self, metric, best, min_delta):
+        return metric < best - min_delta
+
+    def increased(self, metric, best, min_delta):
+        return metric > best + min_delta