diff --git a/pytorch/trainer.py b/pytorch/trainer.py
index 7376ccdee5ef2e0b9cc396a47de4647d369f46cb..3bf269588f7fd5aed77e44387d2be1dfedcf02ce 100755
--- a/pytorch/trainer.py
+++ b/pytorch/trainer.py
@@ -18,6 +18,7 @@ from torch.utils.data import random_split, DataLoader
 # local modules
 from pytorch.dataset import SparcsDataset, Cloud95Dataset
 from pytorch.layers import Conv2dSame
+from pytorch.constants import SupportedDatasets
 
 
 class NetworkTrainer(object):
@@ -28,105 +29,26 @@ class NetworkTrainer(object):
         for k, v in config.items():
             setattr(self, k, v)
 
-        # check which dataset the model is trained on
-        if self.dataset_name == 'Sparcs':
-            # instanciate the SparcsDataset
-            self.dataset = SparcsDataset(self.dataset_path,
-                                         use_bands=self.bands,
-                                         tile_size=self.tile_size)
-        elif self.dataset_name == 'Cloud95':
-            # instanciate the Cloud95Dataset
-            self.dataset = Cloud95Dataset(self.dataset_path,
-                                          use_bands=self.bands,
-                                          tile_size=self.tile_size,
-                                          exclude=self.patches)
-        else:
-            raise ValueError('{} is not a valid dataset. Available datasets '
-                             'are "Sparcs" and "Cloud95".'
-                             .format(self.dataset_name))
-
-        # print the bands used for the segmentation
-        print('------------------------ Input bands -------------------------')
-        print(*['Band {}: {}'.format(i, b) for i, b in
-                enumerate(self.dataset.use_bands)], sep='\n')
-        print('--------------------------------------------------------------')
-
-        # print the classes of interest
-        print('-------------------------- Classes ---------------------------')
-        print(*['Class {}: {}'.format(k, v['label']) for k, v in
-                self.dataset.labels.items()], sep='\n')
-        print('--------------------------------------------------------------')
-
-        # instanciate the segmentation network
-        print('------------------- Network architecture ---------------------')
-        if self.pretrained:
-            self.model = self.from_pretrained()
-        else:
-            self.model = self.net(in_channels=len(self.dataset.use_bands),
-                                  nclasses=len(self.dataset.labels),
-                                  filters=self.filters,
-                                  skip=self.skip_connection,
-                                  **self.kwargs)
-        print(self.model)
-        print('--------------------------------------------------------------')
-
-        # the training and validation dataset
-        print('------------------------ Dataset split -----------------------')
-        self.train_ds, self.valid_ds, self.test_ds = self.train_val_test_split(
-            self.dataset, self.tvratio, self.ttratio, self.seed)
-        print('--------------------------------------------------------------')
-
-        # number of batches in the validation set
-        self.nvbatches = int(len(self.valid_ds) / self.batch_size)
-
-        # number of batches in the training set
-        self.nbatches = int(len(self.train_ds) / self.batch_size)
-
-        # the training and validation dataloaders
-        self.train_dl = DataLoader(self.train_ds,
-                                   self.batch_size,
-                                   shuffle=True,
-                                   drop_last=True)
-        self.valid_dl = DataLoader(self.valid_ds,
-                                   self.batch_size,
-                                   shuffle=True,
-                                   drop_last=True)
-
-        # the optimizer used to update the model weights
-        self.optimizer = self.optimizer(self.model.parameters(), self.lr)
-
         # whether to use the gpu
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else
                                    "cpu")
 
-        # file to save model state to
-        # format: networkname_datasetname_t(tilesize)_b(batchsize)_bands.pt
-        bformat = ''.join([b[0] for b in self.bands]) if self.bands else 'all'
-        self.state_file = ('{}_{}_t{}_b{}_{}.pt'
-                           .format(self.model.__class__.__name__,
-                                   self.dataset.__class__.__name__,
-                                   self.tile_size,
-                                   self.batch_size,
-                                   bformat))
-
-        # check whether a pretrained model was used and change state filename
-        # accordingly
-        if self.pretrained:
-            # add the configuration of the pretrained model to the state name
-            self.state_file = (self.state_file.replace('.pt', '_') +
-                               'pretrained_' + self.pretrained_model)
-
-        # path to model state
-        self.state = os.path.join(self.state_path, self.state_file)
+        # initialize the dataset to train the model on
+        self._init_dataset()
 
-        # path to model loss/accuracy
-        self.loss_state = self.state.replace('.pt', '_loss.pt')
+        # initialize the model
+        self._init_model()
 
     def from_pretrained(self):
 
         # load the pretrained model
-        model_state = torch.load(
-            os.path.join(self.state_path, self.pretrained_model))
+        model_state = os.path.join(self.state_path, self.pretrained_model)
+        if not os.path.exists(model_state):
+            raise FileNotFoundError('Pretrained model {} does not exist.'
+                                    .format(model_state))
+
+        # load the model state
+        model_state = torch.load(model_state)
 
         # get the input bands of the pretrained model
         bands = model_state['bands']
@@ -152,8 +74,10 @@ class NetworkTrainer(object):
                                       out_channels=len(self.dataset.labels),
                                       kernel_size=1)
 
-        return model
+        # adjust the number of classes in the model
+        model.nclasses = len(self.dataset.labels)
 
+        return model
 
     def ds_len(self, ds, ratio):
         return int(np.round(len(ds) * ratio))
@@ -193,22 +117,6 @@ class NetworkTrainer(object):
     def accuracy_function(self, outputs, labels):
         return (outputs == labels).float().mean()
 
-    def _save_loss(self, training_state, checkpoint=False,
-                   checkpoint_state=None):
-
-        # save losses and accuracy
-        if checkpoint and checkpoint_state is not None:
-
-            # append values from checkpoint to current training
-            # state
-            torch.save({
-                k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
-                zip(checkpoint_state.items(), training_state.items())
-                if k1 == k2},
-                self.loss_state)
-        else:
-            torch.save(training_state, self.loss_state)
-
     def train(self):
 
         # set the number of threads
@@ -418,6 +326,117 @@ class NetworkTrainer(object):
 
         return cm, accuracies, losses
 
+    def _init_dataset(self):
+
+        # check whether the dataset is currently supported
+        self.dataset = None
+        for dataset in SupportedDatasets:
+            if self.dataset_name == dataset.name:
+                self.dataset = dataset.value['class'](self.dataset_path,
+                                                      self.bands,
+                                                      self.tile_size,
+                                                      self.sort,
+                                                      self.transforms)
+        if self.dataset is None:
+            print('{} is not a valid dataset.'.format(self.dataset_name))
+            print('Available datasets are:')
+            for name, _ in SupportedDatasets.__members__.items():
+                print(name)
+            raise ValueError('Dataset not supported.')
+
+        # print the bands used for the segmentation
+        print('------------------------ Input bands -------------------------')
+        print(*['Band {}: {}'.format(i, b) for i, b in
+                enumerate(self.dataset.use_bands)], sep='\n')
+        print('--------------------------------------------------------------')
+
+        # print the classes of interest
+        print('-------------------------- Classes ---------------------------')
+        print(*['Class {}: {}'.format(k, v['label']) for k, v in
+                self.dataset.labels.items()], sep='\n')
+        print('--------------------------------------------------------------')
+
+        # the training and validation dataset
+        print('------------------------ Dataset split -----------------------')
+        self.train_ds, self.valid_ds, self.test_ds = self.train_val_test_split(
+            self.dataset, self.tvratio, self.ttratio, self.seed)
+
+        # number of batches in the validation set
+        self.nvbatches = int(len(self.valid_ds) / self.batch_size)
+
+        # number of batches in the training set
+        self.nbatches = int(len(self.train_ds) / self.batch_size)
+        print('Number of training batches: {}'.format(self.nbatches))
+        print('Number of validation batches: {}'.format(self.nvbatches))
+        print('--------------------------------------------------------------')
+
+        # the training and validation dataloaders
+        self.train_dl = DataLoader(self.train_ds,
+                                   self.batch_size,
+                                   shuffle=True,
+                                   drop_last=True)
+        self.valid_dl = DataLoader(self.valid_ds,
+                                   self.batch_size,
+                                   shuffle=True,
+                                   drop_last=True)
+
+    def _init_model(self):
+
+        # instanciate the segmentation network
+        print('------------------- Network architecture ---------------------')
+        if self.pretrained:
+            self.model = self.from_pretrained()
+        else:
+            self.model = self.net(in_channels=len(self.dataset.use_bands),
+                                  nclasses=len(self.dataset.labels),
+                                  filters=self.filters,
+                                  skip=self.skip_connection,
+                                  **self.kwargs)
+        print(self.model)
+        print('--------------------------------------------------------------')
+
+        # the optimizer used to update the model weights
+        self.optimizer = self.optimizer(self.model.parameters(), self.lr)
+
+        # file to save model state to
+        # format: networkname_datasetname_t(tilesize)_b(batchsize)_bands.pt
+        bformat = ''.join([b[0] for b in self.bands]) if self.bands else 'all'
+        self.state_file = ('{}_{}_t{}_b{}_{}.pt'
+                           .format(self.model.__class__.__name__,
+                                   self.dataset.__class__.__name__,
+                                   self.tile_size,
+                                   self.batch_size,
+                                   bformat))
+
+        # check whether a pretrained model was used and change state filename
+        # accordingly
+        if self.pretrained:
+            # add the configuration of the pretrained model to the state name
+            self.state_file = (self.state_file.replace('.pt', '_') +
+                               'pretrained_' + self.pretrained_model)
+
+        # path to model state
+        self.state = os.path.join(self.state_path, self.state_file)
+
+        # path to model loss/accuracy
+        self.loss_state = self.state.replace('.pt', '_loss.pt')
+
+    def _save_loss(self, training_state, checkpoint=False,
+                   checkpoint_state=None):
+
+        # save losses and accuracy
+        if checkpoint and checkpoint_state is not None:
+
+            # append values from checkpoint to current training
+            # state
+            torch.save({
+                k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
+                zip(checkpoint_state.items(), training_state.items())
+                if k1 == k2},
+                self.loss_state)
+        else:
+            torch.save(training_state, self.loss_state)
+
 
 class EarlyStopping(object):