diff --git a/pytorch/trainer.py b/pytorch/trainer.py
index e6ee801df9cd6ff8866305fbb3021e1eecd7e7f2..dbcfb805442d2ad20cf40e10aeb2fc7db3f84262 100755
--- a/pytorch/trainer.py
+++ b/pytorch/trainer.py
@@ -20,7 +20,7 @@ sys.path.append('..')
 
 # local modules
 from pytorch.dataset import SparcsDataset, Cloud95Dataset
-from pytorch.constants import SparcsLabels, Cloud95Labels
+from pytorch.layers import Conv2dSame
 
 
 class NetworkTrainer(object):
@@ -31,8 +31,6 @@ class NetworkTrainer(object):
         for k, v in config.items():
             setattr(self, k, v)
 
-    def initialize(self):
-
         # check which dataset the model is trained on
         if self.dataset_name == 'Sparcs':
             # instanciate the SparcsDataset
@@ -81,6 +79,12 @@ class NetworkTrainer(object):
             self.dataset, self.tvratio, self.ttratio, self.seed)
         print('--------------------------------------------------------------')
 
+        # number of batches in the validation set
+        self.nvbatches = int(len(self.valid_ds) / self.batch_size)
+
+        # number of batches in the training set
+        self.nbatches = int(len(self.train_ds) / self.batch_size)
+
         # the training and validation dataloaders
         self.train_dl = DataLoader(self.train_ds,
                                    self.batch_size,
@@ -108,48 +112,49 @@ class NetworkTrainer(object):
                                    self.batch_size,
                                    bformat))
 
+        # check whether a pretrained model was used and change state filename
+        # accordingly
+        if self.pretrained:
+            # add the configuration of the pretrained model to the state name
+            self.state_file = (self.state_file.replace('.pt', '_') +
+                               'pretrained_' + self.pretrained_model)
+
         # path to model state
         self.state = os.path.join(self.state_path, self.state_file)
 
         # path to model loss/accuracy
         self.loss_state = self.state.replace('.pt', '_loss.pt')
 
-
     def from_pretrained(self):
 
-        # name of the dataset the pretrained model was trained on
-        dataset_name = self.pretrained_model.split('_')[1]
+        # load the pretrained model
+        model_state = torch.load(
+            os.path.join(self.state_path, self.pretrained_model))
 
-        # input bands of the pretrained model
-        bands = self.pretrained_model.split('_')[-1].split('.')[0]
+        # get the input bands of the pretrained model
+        bands = model_state['bands']
 
-        if dataset_name == SparcsDataset.__name__:
+        # get the number of convolutional filters
+        filters = model_state['params']['filters']
 
-            # number of input channels
-            in_channels = len(bands) if bands != 'all' else 10
+        # check whether the current dataset uses the correct spectral bands
+        if self.bands != bands:
+            raise ValueError('The bands of the pretrained network do not '
+                             'match the specified bands: {}'
+                             .format(self.bands))
 
-            # instanciate pretrained model architecture
-            model = self.net(in_channels=in_channels,
-                             nclasses=len(SparcsLabels),
-                             filters=self.filters,
-                             skip=self.skip_connection,
-                             **self.kwargs)
-
-        if dataset_name == Cloud95Dataset.__name__:
-
-            # number of input channels
-            in_channels = len(bands) if bands != 'all' else 4
-
-            # instanciate pretrained model architecture
-            model = self.net(in_channels=in_channels,
-                             nclasses=len(Cloud95Labels),
-                             filters=self.filters,
-                             skip=self.skip_connection,
-                             **self.kwargs)
+        # instanciate pretrained model architecture
+        model = self.net(**model_state['params'], **model_state['kwargs'])
 
         # load pretrained model weights
         model.load(self.pretrained_model, inpath=self.state_path)
 
+        # adjust the classification layer to the number of classes of the
+        # current dataset
+        model.classifier = Conv2dSame(in_channels=filters[0],
+                                      out_channels=len(self.dataset.labels),
+                                      kernel_size=1)
+
         return model
 
 
@@ -191,6 +196,22 @@ class NetworkTrainer(object):
     def accuracy_function(self, outputs, labels):
         return (outputs == labels).float().mean()
 
+    def _save_loss(self, training_state, checkpoint=False,
+                   checkpoint_state=None):
+
+        # save losses and accuracy
+        if checkpoint and checkpoint_state is not None:
+
+            # append values from checkpoint to current training
+            # state
+            torch.save({
+                k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
+                zip(checkpoint_state.items(), training_state.items())
+                if k1 == k2},
+                self.loss_state)
+        else:
+            torch.save(training_state, self.loss_state)
+
     def train(self):
 
         # set the number of threads
@@ -206,31 +227,36 @@ class NetworkTrainer(object):
             # initial accuracy on the validation set
             max_accuracy = 0
 
+        # create dictionary of the observed losses and accuracies on the
+        # training and validation dataset
+        training_state = {'tl': np.zeros(shape=(self.nbatches, self.epochs)),
+                          'ta': np.zeros(shape=(self.nbatches, self.epochs)),
+                          'vl': np.zeros(shape=(self.nvbatches, self.epochs)),
+                          'va': np.zeros(shape=(self.nvbatches, self.epochs))
+                          }
+
         # whether to resume training from an existing model
+        checkpoint_state = None
         if os.path.exists(self.state) and self.checkpoint:
+            # load the model state
             state = self.model.load(self.state_file, self.optimizer,
                                     self.state_path)
             print('Resuming training from {} ...'.format(state))
             print('Model epoch: {:d}'.format(self.model.epoch))
 
-        # send the model to the gpu if available
-        self.model = self.model.to(self.device)
+            # load the model loss and accuracy
+            checkpoint_state = torch.load(self.loss_state)
 
-        # number of batches in the validation set
-        nvbatches = int(len(self.valid_ds) / self.batch_size)
-
-        # number of batches in the training set
-        nbatches = int(len(self.train_ds) / self.batch_size)
+            # get all non-zero elements, i.e. get number of epochs trained
+            # before the early stop
+            checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for
+                                k, v in checkpoint_state.items()}
 
-        # create arrays of the observed losses and accuracies on the
-        # training set
-        losses = np.zeros(shape=(nbatches, self.epochs))
-        accuracies = np.zeros(shape=(nbatches, self.epochs))
+            # maximum accuracy on the validation set
+            max_accuracy = checkpoint_state['va'][:, -1].mean().item()
 
-        # create arrays of the observed losses and accuracies on the
-        # validation set
-        vlosses = np.zeros(shape=(nvbatches, self.epochs))
-        vaccuracies = np.zeros(shape=(nvbatches, self.epochs))
+        # send the model to the gpu if available
+        self.model = self.model.to(self.device)
 
         # initialize the training: iterate over the entire training data set
         for epoch in range(self.epochs):
@@ -254,7 +280,8 @@ class NetworkTrainer(object):
 
                 # compute loss
                 loss = self.loss_function(outputs, labels.long())
-                losses[batch, epoch] = loss.detach().numpy().item()
+                observed_loss = loss.detach().numpy().item()
+                training_state['tl'][batch, epoch] = observed_loss
 
                 # compute the gradients of the loss function w.r.t.
                 # the network weights
@@ -267,14 +294,17 @@ class NetworkTrainer(object):
                 ypred = F.softmax(outputs, dim=1).argmax(dim=1)
 
                 # calculate accuracy on current batch
-                acc = self.accuracy_function(ypred, labels)
-                accuracies[batch, epoch] = acc
+                observed_accuracy = self.accuracy_function(ypred, labels)
+                training_state['ta'][batch, epoch] = observed_accuracy
 
                 # print progress
                 print('Epoch: {:d}/{:d}, Batch: {:d}/{:d}, Loss: {:.2f}, '
-                      'Accuracy: {:.2f}'.format(epoch, self.epochs, batch,
-                                                nbatches, losses[batch, epoch],
-                                                acc))
+                      'Accuracy: {:.2f}'.format(epoch,
+                                                self.epochs,
+                                                batch,
+                                                self.nbatches,
+                                                observed_loss,
+                                                observed_accuracy))
 
             # update the number of epochs trained
             self.model.epoch += 1
@@ -287,8 +317,8 @@ class NetworkTrainer(object):
                 _, vacc, vloss = self.predict()
 
                 # append observed accuracy and loss to arrays
-                vaccuracies[:, epoch] = vacc.squeeze()
-                vlosses[:, epoch] = vloss.squeeze()
+                training_state['va'][:, epoch] = vacc.squeeze()
+                training_state['vl'][:, epoch] = vloss.squeeze()
 
                 # metric to assess model performance on the validation set
                 epoch_acc = vacc.squeeze().mean()
@@ -298,37 +328,34 @@ class NetworkTrainer(object):
                     max_accuracy = epoch_acc
                     # save model state if the model improved with
                     # respect to the previous epoch
-                    _ = self.model.save(self.optimizer, self.state_file,
+                    _ = self.model.save(self.state_file,
+                                        self.optimizer,
+                                        self.bands,
                                         self.state_path)
 
+                    # save losses and accuracy
+                    self._save_loss(training_state,
+                                    self.checkpoint,
+                                    checkpoint_state)
+
                 # whether the early stopping criterion is met
                 if es.stop(epoch_acc):
-
-                    # save losses and accuracy before exiting training
-                    torch.save({'epoch': epoch,
-                                'training_loss': losses,
-                                'training_accuracy': accuracies,
-                                'validation_loss': vlosses,
-                                'validation_accuracy': vaccuracies},
-                               self.loss_state)
-
                     break
 
             else:
                 # if no early stopping is required, the model state is saved
                 # after each epoch
-                _ = self.model.save(self.optimizer, self.state_file,
+                _ = self.model.save(self.state_file,
+                                    self.optimizer,
+                                    self.bands,
                                     self.state_path)
 
-            # save losses and accuracy after each epoch to file
-            torch.save({'epoch': epoch,
-                        'training_loss': losses,
-                        'training_accuracy': accuracies,
-                        'validation_loss': vlosses,
-                        'validation_accuracy': vaccuracies},
-                       self.loss_state)
+                # save losses and accuracy after each epoch
+                self._save_loss(training_state,
+                                self.checkpoint,
+                                checkpoint_state)
 
-        return losses, accuracies, vlosses, vaccuracies
+        return training_state
 
     def predict(self, pretrained=False, confusion=False):
 
@@ -347,12 +374,9 @@ class NetworkTrainer(object):
         # initialize confusion matrix
         cm = torch.zeros(self.model.nclasses, self.model.nclasses)
 
-        # number of batches in the validation set
-        nbatches = int(len(self.valid_ds) / self.batch_size)
-
         # create arrays of the observed losses and accuracies
-        accuracies = np.zeros(shape=(nbatches, 1))
-        losses = np.zeros(shape=(nbatches, 1))
+        accuracies = np.zeros(shape=(self.nvbatches, 1))
+        losses = np.zeros(shape=(self.nvbatches, 1))
 
         # iterate over the validation/test set
         print('Calculating accuracy on validation set ...')
@@ -379,7 +403,8 @@ class NetworkTrainer(object):
 
             # print progress
             print('Batch: {:d}/{:d}, Accuracy: {:.2f}'.format(batch,
-                                                              nbatches, acc))
+                                                              self.nvbatches,
+                                                              acc))
 
             # update confusion matrix
             if confusion: