diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index cbe4d0fadd2fef89537832bb7e9eb443f1c62df4..855b76b1559f878e0ea935d1451adc72833ad756 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -8,6 +8,7 @@ Created on Wed Aug 12 10:24:34 2020
 import dataclasses
 import pathlib
 import logging
+import datetime
 
 # externals
 import numpy as np
@@ -199,7 +200,6 @@ class ModelConfig(BaseConfig):
     skip_connection: bool = True
     kwargs: dict = dataclasses.field(
         default_factory=lambda: {'kernel_size': 3, 'stride': 1, 'dilation': 1})
-    state_path: pathlib.Path = pathlib.Path(HERE).joinpath('_models/')
     batch_size: int = 64
     checkpoint: bool = False
     transfer: bool = False
@@ -226,11 +226,16 @@ class ModelConfig(BaseConfig):
         # check whether the loss function is currently supported
         self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions)
 
+        # path to model states
+        self.state_path = pathlib.Path(HERE).joinpath('_models/')
+
         # path to pretrained model
         self.pretrained_path = self.state_path.joinpath(self.pretrained_model)
 
     def init_optimizer(self, model):
 
+        LOGGER.info('Optimizer: {}.'.format(repr(self.optim_class)))
+
         # initialize the optimizer for the specified model
         optimizer = self.optim_class(model.parameters(), self.lr)
 
@@ -238,17 +243,28 @@ class ModelConfig(BaseConfig):
 
     def init_loss_function(self):
 
+        LOGGER.info('Loss function: {}.'.format(repr(self.loss_class)))
+
+        # instanciate the loss function
         loss_function = self.loss_class()
 
         return loss_function
 
-    def init_model(self, ds):
+    def init_model(self, ds, state_file):
+
+        # write an initialization string to the log file
+        # now = datetime.datetime.strftime(datetime.datetime.now(),
+        #                                  '%Y-%m-%dT%H:%M:%S')
+        # LOGGER.info(80 * '-')
+        # LOGGER.info('{}: Initializing model run. '.format(now) + 35 * '-')
+        # LOGGER.info(80 * '-')
 
         # case (1): build a new model
         if not self.transfer:
 
             # set the random seed for reproducibility
             torch.manual_seed(self.torch_seed)
+            LOGGER.info('Initializing model: {}'.format(state_file.name))
 
             # instanciate the model
             model = self.model_class(
@@ -261,104 +277,86 @@ class ModelConfig(BaseConfig):
         # case (2): load a pretrained model for transfer learning
         else:
             # load pretrained model
-            model, _ = self.load_pretrained(self.pretrained_path, new_ds=ds)
+            LOGGER.info('Loading pretrained model for transfer learning from: '
+                        '{}'.format(self.pretrained_path))
+            model = self.transfer_model(self.pretrained_path, ds)
 
-        return model
-
-    def from_checkpoint(self, model, optimizer, state_file, loss_state):
+        # initialize the optimizer
+        optimizer = self.init_optimizer(model)
 
         # whether to resume training from an existing model checkpoint
         checkpoint_state = {}
-        max_accuracy = 0
         if self.checkpoint:
+            model, optimizer, checkpoint_state = self.load_checkpoint(
+                model, optimizer, state_file)
 
-            # check whether the checkpoint exists
-            if state_file.exists() and loss_state.exists():
-                # load model checkpoint
-                model, optimizer = self.load_pretrained(state_file, optimizer,
-                                                        new_ds=None)
-                (checkpoint_state, max_accuracy) = self.load_checkpoint(
-                    loss_state)
-            else:
-                LOGGER.info('Checkpoint for model {} does not exist. '
-                            'Initializing new model.'.format(state_file.name))
-
-        return model, optimizer, checkpoint_state, max_accuracy
+        return model, optimizer, checkpoint_state
 
     @staticmethod
-    def load_pretrained(state_file, optimizer=None, new_ds=None):
-
-        # load the pretrained model
-        if not state_file.exists():
-            raise FileNotFoundError('Pretrained model {} does not exist.'
-                                    .format(state_file))
-
-        LOGGER.info('Loading pretrained model: {}'.format(state_file.name))
-
-        # load the model state
-        model_state = torch.load(state_file)
+    def load_checkpoint(model, optimizer, state_file):
 
-        # the model class
-        model_class = model_state['cls']
-
-        # instanciate pretrained model architecture
-        model = model_class(**model_state['params'], **model_state['kwargs'])
-
-        # load pretrained model weights
-        _ = model.load(state_file.name, optimizer=optimizer,
-                       inpath=str(state_file.parent))
-        LOGGER.info('Model epoch: {:d}'.format(model.epoch))
+        # whether to resume training from an existing model checkpoint
+        checkpoint_state = {}
 
-        # check whether to apply pretrained model on a new dataset
-        if new_ds is not None:
-            LOGGER.info('Configuring model for new dataset: {}.'
-                        .format(new_ds.__class__.__name__))
+        # if no checkpoint exists, file a warning and continue with a model
+        # initialized from scratch
+        if not state_file.exists():
+            LOGGER.warning('Checkpoint for model {} does not exist. '
+                           'Initializing new model.'
+                           .format(state_file.name))
+        else:
+            # load model checkpoint
+            model, optimizer, model_state = Network.load(state_file, optimizer)
 
-            # the bands the model was trained with
-            bands = model_state['bands']
+            # load model loss and accuracy
 
-            # check whether the current dataset uses the correct spectral bands
-            if new_ds.use_bands != bands:
-                raise ValueError('The pretrained network was trained with the '
-                                 'bands {}, not with: {}'
-                                 .format(bands, new_ds.use_bands))
+            # get all non-zero elements, i.e. get number of epochs trained
+            # before the early stop
+            checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
+                                for k, v in model_state['state'].items()}
 
-            # get the number of convolutional filters
-            filters = model_state['params']['filters']
+        return model, optimizer, checkpoint_state
 
-            # reset model epoch to 0, since the model is trained on a different
-            # dataset
-            model.epoch = 0
+    @staticmethod
+    def transfer_model(state_file, ds):
 
-            # adjust the number of classes in the model
-            model.nclasses = len(new_ds.labels)
-            LOGGER.info('Replacing classification layer to classes: {}.'
-                        .format(', '.join('({}, {})'.format(k, v['label'])
-                                          for k, v in new_ds.labels.items())))
+        # check input type
+        if not isinstance(ds, ImageDataset):
+            raise TypeError('Expected "ds" to be {}.'
+                            .format('.'.join([ImageDataset.__module__,
+                                              ImageDataset.__name__])))
 
-            # adjust the classification layer to the number of classes of the
-            # current dataset
-            model.classifier = Conv2dSame(in_channels=filters[0],
-                                          out_channels=model.nclasses,
-                                          kernel_size=1)
+        # load the pretrained model
+        model, _, model_state = Network.load(state_file)
+        LOGGER.info('Configuring model for new dataset: {}.'.format(
+            ds.__class__.__name__))
 
-        return model, optimizer
+        # check whether the current dataset uses the correct spectral bands
+        if new_ds.use_bands != model_state['bands']:
+            raise ValueError('The pretrained network was trained with '
+                             'bands {}, not with bands {}.'
+                             .format(model_state['bands'], new_ds.use_bands))
 
-    @staticmethod
-    def load_checkpoint(loss_state):
+        # get the number of convolutional filters
+        filters = model_state['params']['filters']
 
-        # load the model loss and accuracy
-        checkpoint_state = torch.load(loss_state)
+        # reset model epoch to 0, since the model is trained on a different
+        # dataset
+        model.epoch = 0
 
-        # get all non-zero elements, i.e. get number of epochs trained
-        # before the early stop
-        checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
-                            for k, v in checkpoint_state.items()}
+        # adjust the number of classes in the model
+        model.nclasses = len(ds.labels)
+        LOGGER.info('Replacing classification layer to classes: {}.'
+                    .format(', '.join('({}, {})'.format(k, v['label'])
+                                      for k, v in ds.labels.items())))
 
-        # maximum accuracy on the validation set
-        max_accuracy = checkpoint_state['va'][:, -1].mean().item()
+        # adjust the classification layer to the number of classes of the
+        # current dataset
+        model.classifier = Conv2dSame(in_channels=filters[0],
+                                      out_channels=model.nclasses,
+                                      kernel_size=1)
 
-        return checkpoint_state, max_accuracy
+        return model
 
 
 @dataclasses.dataclass
@@ -420,14 +418,12 @@ class StateConfig(BaseConfig):
         # path to model state
         state = self.mc.state_path.joinpath(state_file)
 
-        # path to model loss/accuracy
-        loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt'))
-
-        return state, loss_state
+        return state
 
 
 @dataclasses.dataclass
 class EvalConfig(BaseConfig):
+    state_file: pathlib.Path
     test: object
     predict_scene: bool = False
     plot_samples: bool = False
@@ -446,6 +442,32 @@ class EvalConfig(BaseConfig):
             raise TypeError('Expected "test" to be None, True or False, got '
                             '{}.'.format(self.test))
 
+        # the output paths for the different graphics
+        self.base_path = pathlib.Path(HERE)
+        self.sample_path = self.base_path.joinpath('_samples')
+        self.scenes_path = self.base_path.joinpath('_scenes')
+        self.models_path = self.base_path.joinpath('_graphics')
+
+        # write initialization string to log file
+        # LOGGER.info(80 * '-')
+        # LOGGER.info('{}')
+        # LOGGER.info(80 * '-')
+
+
+@dataclasses.dataclass
+class LogConfig(BaseConfig):
+    state_file: pathlib.Path
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        # the path to store model logs
+        self.log_path = pathlib.Path(HERE).joinpath('_logs')
+
+        # the log file of the current model
+        self.log_file = self.log_path.joinpath(
+            self.state_file.name.replace('.pt', '.log'))
+
 
 @dataclasses.dataclass
 class NetworkTrainer(BaseConfig):
@@ -454,15 +476,14 @@ class NetworkTrainer(BaseConfig):
     loss_function: nn.Module
     train_dl: DataLoader
     valid_dl: DataLoader
+    test_dl: DataLoader
     state_file: pathlib.Path
-    loss_state: pathlib.Path
     epochs: int = 1
     nthreads: int = torch.get_num_threads()
     early_stop: bool = False
     mode: str = 'max'
     delta: float = 0
     patience: int = 10
-    max_accuracy: float = 0
     checkpoint_state: dict = dataclasses.field(default_factory=dict)
     save: bool = True
 
@@ -473,16 +494,21 @@ class NetworkTrainer(BaseConfig):
         self.device = torch.device("cuda:0" if torch.cuda.is_available()
                                    else "cpu")
 
+        # maximum accuracy on the validation dataset
+        self.max_accuracy = 0
+        if self.checkpoint_state:
+            self.max_accuracy = self.checkpoint_state['va'].mean(
+                axis=0).max().item()
+
         # whether to use early stopping
         self.es = None
         if self.early_stop:
             self.es = EarlyStopping(self.mode, self.max_accuracy, self.delta,
                                     self.patience)
 
-
     def train(self):
 
-        LOGGER.info(30 * '-' + ' Training ' + 30 * '-')
+        LOGGER.info(35 * '-' + ' Training ' + 35 * '-')
 
         # set the number of threads
         LOGGER.info('Device: {}'.format(self.device))
@@ -493,11 +519,11 @@ class NetworkTrainer(BaseConfig):
         # training and validation dataset
         tshape = (len(self.train_dl), self.epochs)
         vshape = (len(self.valid_dl), self.epochs)
-        training_state = {'tl': np.zeros(shape=tshape),
-                          'ta': np.zeros(shape=tshape),
-                          'vl': np.zeros(shape=vshape),
-                          'va': np.zeros(shape=vshape)
-                          }
+        self.training_state = {'tl': np.zeros(shape=tshape),
+                               'ta': np.zeros(shape=tshape),
+                               'vl': np.zeros(shape=vshape),
+                               'va': np.zeros(shape=vshape)
+                               }
 
         # send the model to the gpu if available
         self.model = self.model.to(self.device)
@@ -525,7 +551,7 @@ class NetworkTrainer(BaseConfig):
                 # compute loss
                 loss = self.loss_function(outputs, labels.long())
                 observed_loss = loss.detach().numpy().item()
-                training_state['tl'][batch, epoch] = observed_loss
+                self.training_state['tl'][batch, epoch] = observed_loss
 
                 # compute the gradients of the loss function w.r.t.
                 # the network weights
@@ -539,7 +565,7 @@ class NetworkTrainer(BaseConfig):
 
                 # calculate accuracy on current batch
                 observed_accuracy = accuracy_function(ypred, labels)
-                training_state['ta'][batch, epoch] = observed_accuracy
+                self.training_state['ta'][batch, epoch] = observed_accuracy
 
                 # print progress
                 LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
@@ -562,8 +588,8 @@ class NetworkTrainer(BaseConfig):
                 vacc, vloss = self.predict()
 
                 # append observed accuracy and loss to arrays
-                training_state['va'][:, epoch] = vacc.squeeze()
-                training_state['vl'][:, epoch] = vloss.squeeze()
+                self.training_state['va'][:, epoch] = vacc.squeeze()
+                self.training_state['vl'][:, epoch] = vloss.squeeze()
 
                 # metric to assess model performance on the validation set
                 epoch_acc = vacc.squeeze().mean()
@@ -574,7 +600,7 @@ class NetworkTrainer(BaseConfig):
 
                     # save model state if the model improved with
                     # respect to the previous epoch
-                    self.save_state(training_state)
+                    self.save_state()
 
                 # whether the early stopping criterion is met
                 if self.es.stop(epoch_acc):
@@ -583,15 +609,13 @@ class NetworkTrainer(BaseConfig):
             else:
                 # if no early stopping is required, the model state is
                 # saved after each epoch
-                self.save_state(training_state)
+                self.save_state()
 
 
-        return training_state
+        return self.training_state
 
     def predict(self):
 
-        LOGGER.info(30 * '-' + ' Predicting ' + 30 * '-')
-
         # send the model to the gpu if available
         self.model = self.model.to(self.device)
 
@@ -631,37 +655,38 @@ class NetworkTrainer(BaseConfig):
                         .format(batch + 1, len(self.valid_dl), acc))
 
         # calculate overall accuracy on the validation/test set
-        LOGGER.info('Epoch {:d}, Overall accuracy: {:.2f}%.'
+        LOGGER.info('Epoch: {:d}, Mean accuracy: {:.2f}%.'
                     .format(self.model.epoch, accuracies.mean() * 100))
 
         return accuracies, losses
 
-    def save_state(self, training_state):
+    def save_state(self):
 
         # whether to save the model state
         if self.save:
-            # save model state
-            state = self.model.save(self.state_file.name,
-                                    self.optimizer,
-                                    self.train_dl.dataset.dataset.use_bands,
-                                    self.state_file.parent)
-
-            # save losses and accuracy
-            self._save_loss(training_state)
 
-    def _save_loss(self, training_state):
+            # append the model performance before the checkpoint to the model
+            # state, if a checkpoint is passed
+            if self.checkpoint_state:
 
-        # save losses and accuracy
-        state = training_state
-        if self.checkpoint_state:
+                # append values from checkpoint to current training state
+                state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
+                         zip(self.checkpoint_state.items(),
+                             self.training_state.items()) if k1 == k2}
+            else:
+                state = self.training_state
 
-            # append values from checkpoint to current training state
-            state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
-                     zip(self.checkpoint_state.items(), training_state.items())
-                     if k1 == k2}
+            # save model state
+            _ = self.model.save(
+                self.state_file,
+                self.optimizer,
+                bands=self.train_dl.dataset.dataset.use_bands,
+                train_ds=self.train_dl.dataset,
+                valid_ds=self.valid_dl.dataset,
+                test_ds=self.test_dl.dataset,
+                state=state,
+                )
 
-        # save the model loss and accuracies to file
-        torch.save(state, self.loss_state)
 
     def __repr__(self):
 
@@ -687,6 +712,7 @@ class NetworkTrainer(BaseConfig):
         fs += '\n    (split):'
         fs += '\n        ' + repr(self.train_dl.dataset)
         fs += '\n        ' + repr(self.valid_dl.dataset)
+        fs += '\n        ' + repr(self.test_dl.dataset)
 
         # model
         fs += '\n    (model):\n        '
@@ -764,6 +790,6 @@ class EarlyStopping(object):
 
     def __repr__(self):
         fs = self.__class__.__name__
-        fs += '(mode={}, best={}, delta={}, patience={})'.format(
+        fs += '(mode={}, best={:.2f}, delta={}, patience={})'.format(
             self.mode, self.best, self.min_delta, self.patience)
         return fs