diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 73d3a4399f3bfa6ca3ce4c011e48cf71a3ad143c..757fb3f37e2492be8db4b394997d1ba717232ee1 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -7,6 +7,7 @@ Created on Wed Aug 12 10:24:34 2020
 # builtins
 import dataclasses
 import pathlib
+import logging
 
 # externals
 import numpy as np
@@ -16,7 +17,6 @@ import torch.nn.functional as F
 from torch.utils.data import DataLoader, Dataset
 from torch.optim import Optimizer
 
-
 # locals
 from pysegcnn.core.dataset import SupportedDatasets, ImageDataset
 from pysegcnn.core.transforms import Augment
@@ -27,6 +27,9 @@ from pysegcnn.core.models import (SupportedModels, SupportedOptimizers,
 from pysegcnn.core.layers import Conv2dSame
 from pysegcnn.main.config import HERE
 
+# module level logger
+LOGGER = logging.getLogger(__name__)
+
 
 @dataclasses.dataclass
 class BaseConfig:
@@ -60,7 +63,6 @@ class DatasetConfig(BaseConfig):
     sort: bool = False
     transforms: list = dataclasses.field(default_factory=list)
     pad: bool = False
-    cval: int = 99
 
     def __post_init__(self):
         # check input types
@@ -80,11 +82,6 @@ class DatasetConfig(BaseConfig):
                             ' of {}.'.format('.'.join([Augment.__module__,
                                                        Augment.__name__])))
 
-        # check whether the constant padding value is within the valid range
-        if not 0 < self.cval < 255:
-            raise ValueError('Expecting 0 <= cval <= 255, got cval={}.'
-                             .format(self.cval))
-
     def init_dataset(self):
 
         # instantiate the dataset
@@ -96,7 +93,6 @@ class DatasetConfig(BaseConfig):
                     sort=self.sort,
                     transforms=self.transforms,
                     pad=self.pad,
-                    cval=self.cval,
                     gt_pattern=self.gt_pattern
                     )
 
@@ -121,7 +117,8 @@ class SplitConfig(BaseConfig):
 
     # function to drop samples with a fraction of pixels equal to the constant
     # padding value self.cval >= self.drop
-    def _drop_samples(self, ds, drop_threshold=1):
+    @staticmethod
+    def _drop_samples(ds, drop_threshold=1):
 
         # iterate over the scenes returned by self.compose_scenes()
         dropped = []
@@ -139,8 +136,8 @@ class SplitConfig(BaseConfig):
 
             # drop samples where npixels >= self.drop
             if npixels >= drop_threshold:
-                print('Skipping scene {}, tile {}: {:.2f}% padded pixels ...'
-                      .format(s['id'], s['tile'], npixels * 100))
+                LOGGER.info('Skipping scene {}, tile {}: {:.2f}% padded pixels'
+                            ' ...'.format(s['id'], s['tile'], npixels * 100))
                 dropped.append(s)
                 _ = ds.indices.pop(pos)
 
@@ -197,14 +194,24 @@ class ModelConfig(BaseConfig):
     model_name: str
     filters: list
     torch_seed: int
+    optim_name: str
+    loss_name: str
     skip_connection: bool = True
     kwargs: dict = dataclasses.field(
         default_factory=lambda: {'kernel_size': 3, 'stride': 1, 'dilation': 1})
     state_path: pathlib.Path = pathlib.Path(HERE).joinpath('_models/')
     batch_size: int = 64
     checkpoint: bool = False
-    pretrained: bool = False
+    transfer: bool = False
     pretrained_model: str = ''
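+    # training hyperparameters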
+    lr: float = 0.001
+    early_stop: bool = False
+    mode: str = 'max'
+    delta: float = 0
+    patience: int = 10
+    epochs: int = 50
+    nthreads: int = torch.get_num_threads()
+    save: bool = True
 
     def __post_init__(self):
         # check input types
@@ -213,65 +220,32 @@ class ModelConfig(BaseConfig):
         # check whether the model is currently supported
         self.model_class = item_in_enum(self.model_name, SupportedModels)
 
-    def init_state(self, ds, sc, tc):
-
-        # file to save model state to:
-        # network_dataset_optim_split_splitparams_tilesize_batchsize_bands.pt
+        # check whether the optimizer is currently supported
+        self.optim_class = item_in_enum(self.optim_name, SupportedOptimizers)
 
-        # model state filename
-        state_file = '{}_{}_{}_{}Split_{}_t{}_b{}_{}.pt'
+        # check whether the loss function is currently supported
+        self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions)
 
-        # get the band numbers
-        bformat = ''.join(band[0] +
-                          str(ds.sensor.__members__[band].value) for
-                          band in ds.use_bands)
+        # path to pretrained model
+        self.pretrained_path = self.state_path.joinpath(self.pretrained_model)
 
-        # check which split mode was used
-        if sc.split_mode == 'date':
-            # store the date that was used to split the dataset
-            state_file = state_file.format(self.model_class.__name__,
-                                           ds.__class__.__name__,
-                                           tc.optim_name,
-                                           sc.split_mode.capitalize(),
-                                           sc.date,
-                                           ds.tile_size,
-                                           self.batch_size,
-                                           bformat)
-        else:
-            # store the random split parameters
-            split_params = 's{}_t{}v{}'.format(
-                ds.seed, str(sc.ttratio).replace('.', ''),
-                str(sc.tvratio).replace('.', ''))
+    def init_optimizer(self, model):
 
-            # model state filename
-            state_file = state_file.format(self.model_class.__name__,
-                                           ds.__class__.__name__,
-                                           tc.optim_name,
-                                           sc.split_mode.capitalize(),
-                                           split_params,
-                                           ds.tile_size,
-                                           self.batch_size,
-                                           bformat)
+        # initialize the optimizer for the specified model
+        optimizer = self.optim_class(model.parameters(), self.lr)
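+        # note: only the learning rate is configured here; any other
+        # optimizer hyperparameters fall back to their defaults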
 
-        # check whether a pretrained model was used and change state filename
-        # accordingly
-        if self.pretrained:
-            # add the configuration of the pretrained model to the state name
-            state_file = (state_file.replace('.pt', '_') +
-                          'pretrained_' + self.pretrained_model)
+        return optimizer
 
-        # path to model state
-        state = self.state_path.joinpath(state_file)
+    def init_loss_function(self):
 
-        # path to model loss/accuracy
-        loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt'))
+        loss_function = self.loss_class()
 
-        return state, loss_state
+        return loss_function
 
     def init_model(self, ds):
 
         # case (1): build a new model
-        if not self.pretrained:
+        if not self.transfer:
 
             # set the random seed for reproducibility
             torch.manual_seed(self.torch_seed)
@@ -284,130 +258,172 @@ class ModelConfig(BaseConfig):
                 skip=self.skip_connection,
                 **self.kwargs)
 
-        # case (2): load a pretrained model
+        # case (2): load a pretrained model for transfer learning
         else:
-
             # load pretrained model
-            model = self.load_pretrained()
+            model, _ = self.load_pretrained(self.pretrained_path, new_ds=ds)
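+            # no optimizer state is restored here: transfer learning starts
+            # from a freshly initialized optimizer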
 
         return model
 
-    def load_checkpoint(self, state_file, loss_state, model, optimizer):
+    def from_checkpoint(self, model, optimizer, state_file, loss_state):
 
-        # initial accuracy on the validation set
+        # whether to resume training from an existing model checkpoint
+        checkpoint_state = {}
         max_accuracy = 0
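+        # max_accuracy tracks the best validation accuracy achieved so far
+        # and serves as the initial best metric for early stopping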
+        if self.checkpoint:
 
-        # set the model checkpoint to None, overwritten when resuming
-        # training from an existing model checkpoint
-        checkpoint_state = {}
+            # check whether the checkpoint exists
+            if state_file.exists() and loss_state.exists():
+                # load model checkpoint
+                model, optimizer = self.load_pretrained(state_file, optimizer,
+                                                        new_ds=None)
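+                # restore the recorded loss/accuracy history and the best
+                # validation accuracy achieved before the checkpoint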
+                (checkpoint_state, max_accuracy) = self.load_checkpoint(
+                    loss_state)
+            else:
+                LOGGER.info('Checkpoint for model {} does not exist. '
+                            'Initializing new model.'.format(state_file.name))
 
-        # whether to resume training from an existing model
-        if self.checkpoint:
+        return model, optimizer, checkpoint_state, max_accuracy
 
-            # check if a model checkpoint exists
-            if not state_file.exists():
-                raise FileNotFoundError('Model checkpoint {} does not exist.'
-                                        .format(state_file))
+    @staticmethod
+    def load_pretrained(state_file, optimizer=None, new_ds=None):
 
-            # load the model state
-            state = model.load(state_file.name, optimizer, self.state_path)
-            print('Found checkpoint: {}'.format(state))
-            print('Resuming training from checkpoint ...'.format(state))
-            print('Model epoch: {:d}'.format(model.epoch))
+        # load the pretrained model
+        if not state_file.exists():
+            raise FileNotFoundError('Pretrained model {} does not exist.'
+                                    .format(state_file))
 
-            # load the model loss and accuracy
-            checkpoint_state = torch.load(loss_state)
+        LOGGER.info('Loading pretrained model: {}'.format(state_file.name))
 
-            # get all non-zero elements, i.e. get number of epochs trained
-            # before the early stop
-            checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
-                                for k, v in checkpoint_state.items()}
+        # load the model state
+        model_state = torch.load(state_file)
 
-            # maximum accuracy on the validation set
-            max_accuracy = checkpoint_state['va'][:, -1].mean().item()
+        # the model class
+        model_class = model_state['cls']
 
-        return checkpoint_state, max_accuracy
+        # instantiate pretrained model architecture
+        model = model_class(**model_state['params'], **model_state['kwargs'])
 
-    def load_pretrained(self, ds):
+        # load pretrained model weights
+        _ = model.load(state_file.name, optimizer=optimizer,
+                       inpath=str(state_file.parent))
+        LOGGER.info('Model epoch: {:d}'.format(model.epoch))
 
-        # load the pretrained model
-        model_state = self.state_path.joinpath(self.pretrained_model)
-        if not model_state.exists():
-            raise FileNotFoundError('Pretrained model {} does not exist.'
-                                    .format(model_state))
+        # check whether to apply pretrained model on a new dataset
+        if new_ds is not None:
+            LOGGER.info('Configuring model for new dataset: {}.'
+                        .format(new_ds.__class__.__name__))
 
-        # load the model state
-        model_state = torch.load(model_state)
+            # the bands the model was trained with
+            bands = model_state['bands']
 
-        # get the input bands of the pretrained model
-        bands = model_state['bands']
+            # check whether the current dataset uses the correct spectral bands
+            if new_ds.use_bands != bands:
+                raise ValueError('The pretrained network was trained with '
+                                 'bands {}, but the current dataset uses '
+                                 'bands: {}.'.format(bands, new_ds.use_bands))
 
-        # get the number of convolutional filters
-        filters = model_state['params']['filters']
+            # get the number of convolutional filters
+            filters = model_state['params']['filters']
 
-        # check whether the current dataset uses the correct spectral bands
-        if ds.use_bands != bands:
-            raise ValueError('The bands of the pretrained network do not '
-                             'match the specified bands: {}'
-                             .format(bands))
+            # reset model epoch to 0, since the model is trained on a different
+            # dataset
+            model.epoch = 0
 
-        # instanciate pretrained model architecture
-        model = self.model_class(**model_state['params'],
-                                 **model_state['kwargs'])
+            # adjust the number of classes in the model
+            model.nclasses = len(new_ds.labels)
+            LOGGER.info('Replacing classification layer for classes: {}.'
+                        .format(', '.join('({}, {})'.format(k, v['label'])
+                                          for k, v in new_ds.labels.items())))
 
-        # load pretrained model weights
-        model.load(self.pretrained_model, inpath=str(self.state_path))
+            # adjust the classification layer to the number of classes of the
+            # current dataset
+            model.classifier = Conv2dSame(in_channels=filters[0],
+                                          out_channels=model.nclasses,
+                                          kernel_size=1)
 
-        # reset model epoch to 0, since the model is trained on a different
-        # dataset
-        model.epoch = 0
+        return model, optimizer
 
-        # adjust the number of classes in the model
-        model.nclasses = len(ds.labels)
+    @staticmethod
+    def load_checkpoint(loss_state):
 
-        # adjust the classification layer to the number of classes of the
-        # current dataset
-        model.classifier = Conv2dSame(in_channels=filters[0],
-                                      out_channels=model.nclasses,
-                                      kernel_size=1)
+        # load the model loss and accuracy
+        checkpoint_state = torch.load(loss_state)
 
-        return model
+        # get all non-zero elements, i.e. get number of epochs trained
+        # before the early stop
+        checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
+                            for k, v in checkpoint_state.items()}
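+        # the metric arrays have shape (nbatches, nepochs); epochs never
+        # trained due to an early stop are all-zero columns and are dropped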
+
+        # maximum accuracy on the validation set
+        max_accuracy = checkpoint_state['va'][:, -1].mean().item()
+
+        return checkpoint_state, max_accuracy
 
 
 @dataclasses.dataclass
-class TrainConfig(BaseConfig):
-    optim_name: str
-    loss_name: str
-    lr: float = 0.001
-    early_stop: bool = False
-    mode: str = 'max'
-    delta: float = 0
-    patience: int = 10
-    epochs: int = 50
-    nthreads: int = torch.get_num_threads()
-    save: bool = True
+class StateConfig(BaseConfig):
+    ds: ImageDataset
+    sc: SplitConfig
+    mc: ModelConfig
 
     def __post_init__(self):
         super().__post_init__()
 
-        # check whether the optimizer is currently supported
-        self.optim_class = item_in_enum(self.optim_name, SupportedOptimizers)
+    def init_state(self):
 
-        # check whether the loss function is currently supported
-        self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions)
+        # file to save model state to:
+        # network_dataset_optim_split_splitparams_tilesize_batchsize_bands.pt
 
-    def init_optimizer(self, model):
+        # model state filename
+        state_file = '{}_{}_{}_{}Split_{}_t{}_b{}_{}.pt'
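+        # a hypothetical example of a resulting filename:
+        # UNet_SparcsDataset_Adam_DateSplit_20160101_t125_b64_b1g2r3n4.pt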
 
-        # initialize the optimizer for the specified model
-        optimizer = self.optim_class(model.parameters(), self.lr)
+        # get the band numbers
+        bformat = ''.join(band[0] +
+                          str(self.ds.sensor.__members__[band].value) for
+                          band in self.ds.use_bands)
 
-        return optimizer
+        # check which split mode was used
+        if self.sc.split_mode == 'date':
+            # store the date that was used to split the dataset
+            state_file = state_file.format(self.mc.model_name,
+                                           self.ds.__class__.__name__,
+                                           self.mc.optim_name,
+                                           self.sc.split_mode.capitalize(),
+                                           self.sc.date,
+                                           self.ds.tile_size,
+                                           self.mc.batch_size,
+                                           bformat)
+        else:
+            # store the random split parameters
+            split_params = 's{}_t{}v{}'.format(
+                self.ds.seed, str(self.sc.ttratio).replace('.', ''),
+                str(self.sc.tvratio).replace('.', ''))
 
-    def init_loss_function(self):
+            # model state filename
+            state_file = state_file.format(self.mc.model_name,
+                                           self.ds.__class__.__name__,
+                                           self.mc.optim_name,
+                                           self.sc.split_mode.capitalize(),
+                                           split_params,
+                                           self.ds.tile_size,
+                                           self.mc.batch_size,
+                                           bformat)
 
-        loss_function = self.loss_class()
+        # check whether a pretrained model was used and change state filename
+        # accordingly
+        if self.mc.transfer:
+            # add the configuration of the pretrained model to the state name
+            state_file = (state_file.replace('.pt', '_') +
+                          'pretrained_' + self.mc.pretrained_model)
 
-        return loss_function
+        # path to model state
+        state = self.mc.state_path.joinpath(state_file)
+
+        # path to model loss/accuracy
+        loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt'))
+
+        return state, loss_state
 
 
 @dataclasses.dataclass
@@ -428,6 +444,7 @@ class EvalConfig(BaseConfig):
             raise TypeError('Expected "test" to be None, True or False, got '
                             '{}.'.format(self.test))
 
+
 @dataclasses.dataclass
 class NetworkTrainer(BaseConfig):
     model: Network
@@ -457,15 +474,16 @@ class NetworkTrainer(BaseConfig):
         # whether to use early stopping
         self.es = None
         if self.early_stop:
-            self.es = EarlyStopping(self.mode, self.delta, self.patience)
+            self.es = EarlyStopping(self.mode, self.max_accuracy, self.delta,
+                                    self.patience)
 
     def train(self):
 
-        print('------------------------- Training ---------------------------')
+        LOGGER.info(30 * '-' + ' Training ' + 30 * '-')
 
         # set the number of threads
-        print('Device: {}'.format(self.device))
-        print('Number of cpu threads: {}'.format(self.nthreads))
+        LOGGER.info('Device: {}'.format(self.device))
+        LOGGER.info('Number of cpu threads: {}'.format(self.nthreads))
         torch.set_num_threads(self.nthreads)
 
         # create dictionary of the observed losses and accuracies on the
@@ -485,7 +503,7 @@ class NetworkTrainer(BaseConfig):
         for epoch in range(self.epochs):
 
             # set the model to training mode
-            print('Setting model to training mode ...')
+            LOGGER.info('Setting model to training mode ...')
             self.model.train()
 
             # iterate over the dataloader object
@@ -521,13 +539,14 @@ class NetworkTrainer(BaseConfig):
                 training_state['ta'][batch, epoch] = observed_accuracy
 
                 # print progress
-                print('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, Loss: {:.2f}, '
-                      'Accuracy: {:.2f}'.format(epoch + 1,
-                                                self.epochs,
-                                                batch + 1,
-                                                len(self.train_dl),
-                                                observed_loss,
-                                                observed_accuracy))
+                LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
+                            'Loss: {:.2f}, Accuracy: {:.2f}'.format(
+                                epoch + 1,
+                                self.epochs,
+                                batch + 1,
+                                len(self.train_dl),
+                                observed_loss,
+                                observed_accuracy))
 
             # update the number of epochs trained
             self.model.epoch += 1
@@ -568,13 +587,13 @@ class NetworkTrainer(BaseConfig):
 
     def predict(self):
 
-        print('------------------------ Predicting --------------------------')
+        LOGGER.info(30 * '-' + ' Predicting ' + 30 * '-')
 
         # send the model to the gpu if available
         self.model = self.model.to(self.device)
 
         # set the model to evaluation mode
-        print('Setting model to evaluation mode ...')
+        LOGGER.info('Setting model to evaluation mode ...')
         self.model.eval()
 
         # create arrays of the observed losses and accuracies
@@ -582,7 +601,7 @@ class NetworkTrainer(BaseConfig):
         losses = np.zeros(shape=(len(self.valid_dl), 1))
 
         # iterate over the validation/test set
-        print('Calculating accuracy on the validation set ...')
+        LOGGER.info('Calculating accuracy on the validation set ...')
         for batch, (inputs, labels) in enumerate(self.valid_dl):
 
             # send the data to the gpu if available
@@ -605,12 +624,12 @@ class NetworkTrainer(BaseConfig):
             accuracies[batch, 0] = acc
 
             # print progress
-            print('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}'
-                  .format(batch + 1, len(self.valid_dl), acc))
+            LOGGER.info('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}'
+                        .format(batch + 1, len(self.valid_dl), acc))
 
         # calculate overall accuracy on the validation/test set
-        print('Epoch {:d}, Overall accuracy: {:.2f}%.'
-              .format(self.model.epoch, accuracies.mean() * 100))
+        LOGGER.info('Epoch {:d}, Overall accuracy: {:.2f}%.'
+                    .format(self.model.epoch, accuracies.mean() * 100))
 
         return accuracies, losses
 
@@ -649,7 +668,7 @@ class NetworkTrainer(BaseConfig):
         # dataset
         fs += '    (dataset):\n        '
         fs += ''.join(
-            repr(self.train_dl.dataset.dataset)).replace('\n','\n        ')
+            repr(self.train_dl.dataset.dataset)).replace('\n', '\n        ')
 
         # batch size
         fs += '\n    (batch):\n        '
@@ -684,7 +703,7 @@ class NetworkTrainer(BaseConfig):
 
 class EarlyStopping(object):
 
-    def __init__(self, mode='max', min_delta=0, patience=10):
+    def __init__(self, mode='max', best=0, min_delta=0, patience=10):
 
         # check if mode is correctly specified
         if mode not in ['min', 'max']:
@@ -707,7 +726,7 @@ class EarlyStopping(object):
         self.patience = patience
 
         # initialize best metric
-        self.best = None
+        self.best = best
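+        # when training resumes from a checkpoint, best is the previously
+        # achieved metric, so only genuine improvements reset the counter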
 
         # initialize early stopping flag
         self.early_stop = False
@@ -717,25 +736,20 @@ class EarlyStopping(object):
 
     def stop(self, metric):
 
-        if self.best is not None:
-
-            # if the metric improved, reset the epochs counter, else, advance
-            if self.is_better(metric, self.best, self.min_delta):
-                self.counter = 0
-                self.best = metric
-            else:
-                self.counter += 1
-                print('Early stopping counter: {}/{}'.format(self.counter,
-                                                             self.patience))
-
-            # if the metric did not improve over the last patience epochs,
-            # the early stopping criterion is met
-            if self.counter >= self.patience:
-                print('Early stopping criterion met, exiting training ...')
-                self.early_stop = True
-
-        else:
+        # if the metric improved, reset the epochs counter, else, advance
+        if self.is_better(metric, self.best, self.min_delta):
+            self.counter = 0
             self.best = metric
+        else:
+            self.counter += 1
+            LOGGER.info('Early stopping counter: {}/{}'.format(
+                self.counter, self.patience))
+
+        # if the metric did not improve over the last patience epochs,
+        # the early stopping criterion is met
+        if self.counter >= self.patience:
+            LOGGER.info('Early stopping criterion met, stopping training.')
+            self.early_stop = True
 
         return self.early_stop
 
@@ -746,6 +760,7 @@ class EarlyStopping(object):
         return metric > best + min_delta
 
     def __repr__(self):
-        fs = (self.__class__.__name__ + '(mode={}, delta={}, patience={})'
-              .format(self.mode, self.min_delta, self.patience))
+        fs = self.__class__.__name__
+        fs += '(mode={}, best={}, delta={}, patience={})'.format(
+            self.mode, self.best, self.min_delta, self.patience)
         return fs
diff --git a/pysegcnn/main/train.py b/pysegcnn/main/train.py
index 6c036ea6e10ae589595af6c98ff4246eae19afd0..bcffb54dd38220e5e308df477dc7428fe65db614 100644
--- a/pysegcnn/main/train.py
+++ b/pysegcnn/main/train.py
@@ -5,11 +5,14 @@ Created on Tue Jun 30 09:33:38 2020
 
 @author: Daniel
 """
+# builtins
+import logging.config
+
 # locals
 from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig,
-                                   TrainConfig, NetworkTrainer)
-from pysegcnn.main.config import (dataset_config, split_config,
-                                  model_config, train_config)
+                                   StateConfig, NetworkTrainer)
+from pysegcnn.core.logging import log_conf
+from pysegcnn.main.config import (dataset_config, split_config, model_config)
 
 
 if __name__ == '__main__':
@@ -20,35 +23,36 @@ if __name__ == '__main__':
     dc = DatasetConfig(**dataset_config)
     sc = SplitConfig(**split_config)
     mc = ModelConfig(**model_config)
-    tc = TrainConfig(**train_config)
 
     # (ii) instantiate the dataset
     ds = dc.init_dataset()
     ds
 
-    # (iii) instanciate the training, validation and test datasets and
+    # (iii) instantiate the model state
+    state = StateConfig(ds, sc, mc)
+    state_file, loss_state = state.init_state()
+
+    # initialize logging
+    log_file = str(state_file).replace('.pt', '_train.log')
+    logging.config.dictConfig(log_conf(log_file))
+
+    # (iv) instantiate the training, validation and test datasets and
     # dataloaders
     train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
-    train_dl, valid_dl, test_dl = sc.dataloaders(train_ds,
-                                                 valid_ds,
-                                                 test_ds,
-                                                 batch_size=mc.batch_size,
-                                                 shuffle=True,
-                                                 drop_last=False)
-
-    # (iv) instanciate the model state files
-    state_file, loss_state = mc.init_state(ds, sc, tc)
+    train_dl, valid_dl, test_dl = sc.dataloaders(
+        train_ds, valid_ds, test_ds, batch_size=mc.batch_size, shuffle=True,
+        drop_last=False)
 
-    # (v) instanciate the model
+    # (v) instantiate the model
     model = mc.init_model(ds)
 
     # (vi) instantiate the optimizer and the loss function
-    optimizer = tc.init_optimizer(model)
-    loss_function = tc.init_loss_function()
+    optimizer = mc.init_optimizer(model)
+    loss_function = mc.init_loss_function()
 
     # (vii) resume training from an existing model checkpoint
-    checkpoint_state, max_accuracy = mc.load_checkpoint(state_file, loss_state,
-                                                        model, optimizer)
+    (model, optimizer, checkpoint_state, max_accuracy) = mc.from_checkpoint(
+        model, optimizer, state_file, loss_state)
 
     # (viii) initialize network trainer class for easy model training
     trainer = NetworkTrainer(model=model,
@@ -58,15 +62,15 @@ if __name__ == '__main__':
                              valid_dl=valid_dl,
                              state_file=state_file,
                              loss_state=loss_state,
-                             epochs=tc.epochs,
-                             nthreads=tc.nthreads,
-                             early_stop=tc.early_stop,
-                             mode=tc.mode,
-                             delta=tc.delta,
-                             patience=tc.patience,
+                             epochs=mc.epochs,
+                             nthreads=mc.nthreads,
+                             early_stop=mc.early_stop,
+                             mode=mc.mode,
+                             delta=mc.delta,
+                             patience=mc.patience,
                              max_accuracy=max_accuracy,
                              checkpoint_state=checkpoint_state,
-                             save=tc.save
+                             save=mc.save
                              )
 
     # (ix) train model