From 6e330d114a7ca8f2575dc900952a161e739af211 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Tue, 11 Aug 2020 16:41:15 +0200
Subject: [PATCH] Improved initialization of dataset and model
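
The dataset is now initialized before the model state files and the
model itself, since both depend on the dataset configuration. Resuming
from a missing checkpoint now raises a FileNotFoundError instead of
silently falling back to default values, and model initialization
distinguishes three cases: building a new model, transfer learning from
a pretrained model, and resuming training from an existing checkpoint.

The new initialization order in NetworkTrainer.__init__ is:

    self._init_dataset()  # the dataset the model is trained on
    self._init_state()    # the model state file names
    self._init_model()    # the model, optimizer and checkpoint state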

---
 pysegcnn/core/trainer.py | 166 +++++++++++++++++++++++----------------
 1 file changed, 100 insertions(+), 66 deletions(-)

diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 96bb55e..09239c9 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -22,7 +22,6 @@ from pysegcnn.core.split import (RandomTileSplit, RandomSceneSplit, DateSplit,
                                  VALID_SPLIT_MODES)
 
 
-
 class NetworkTrainer(object):
 
     def __init__(self, config):
@@ -38,12 +37,11 @@ class NetworkTrainer(object):
         # initialize the dataset to train the model on
         self._init_dataset()
 
-        # initialize the model
-        self._init_model()
-
         # initialize the model state files
         self._init_state()
 
+        # initialize the model
+        self._init_model()
 
     def from_pretrained(self):
 
@@ -69,7 +67,7 @@ class NetworkTrainer(object):
                              .format(self.bands))
 
         # instanciate pretrained model architecture
-        model = self.net(**model_state['params'], **model_state['kwargs'])
+        model = self.model(**model_state['params'], **model_state['kwargs'])
 
         # load pretrained model weights
         model.load(self.pretrained_model, inpath=self.state_path)
@@ -78,39 +76,41 @@ class NetworkTrainer(object):
         # dataset
         model.epoch = 0
 
+        # adjust the number of classes in the model
+        model.nclasses = len(self.dataset.labels)
+
         # adjust the classification layer to the number of classes of the
         # current dataset
         model.classifier = Conv2dSame(in_channels=filters[0],
-                                      out_channels=len(self.dataset.labels),
+                                      out_channels=model.nclasses,
                                       kernel_size=1)
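+        # note: the replaced classifier layer is initialized with random
+        # weights and has to be retrained on the current dataset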
 
-        # adjust the number of classes in the model
-        model.nclasses = len(self.dataset.labels)
 
         return model
 
     def from_checkpoint(self):
 
         # whether to resume training from an existing model
-        checkpoint_state = None
-        max_accuracy = 0
-        if os.path.exists(self.state) and self.checkpoint:
-            # load the model state
-            state = self.model.load(self.state_file, self.optimizer,
-                                    self.state_path)
-            print('Resuming training from {} ...'.format(state))
-            print('Model epoch: {:d}'.format(self.model.epoch))
+        if not os.path.exists(self.state):
+            raise FileNotFoundError('Model checkpoint {} does not exist.'
+                                    .format(self.state))
+
+        # load the model state
+        state = self.model.load(self.state_file, self.optimizer,
+                                self.state_path)
+        print('Resuming training from {} ...'.format(state))
+        print('Model epoch: {:d}'.format(self.model.epoch))
 
-            # load the model loss and accuracy
-            checkpoint_state = torch.load(self.loss_state)
+        # load the model loss and accuracy
+        checkpoint_state = torch.load(self.loss_state)
 
-            # get all non-zero elements, i.e. get number of epochs trained
-            # before the early stop
-            checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for
-                                k, v in checkpoint_state.items()}
+        # get all non-zero elements, i.e. the number of epochs trained
+        # before the early stop
+        checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for
+                            k, v in checkpoint_state.items()}
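+        # (each array now has shape (mini-batches, completed epochs))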
 
-            # maximum accuracy on the validation set
-            max_accuracy = checkpoint_state['va'][:, -1].mean().item()
+        # maximum accuracy on the validation set, i.e. the mean validation
+        # accuracy of the last epoch stored in the checkpoint
+        max_accuracy = checkpoint_state['va'][:, -1].mean().item()
 
         return checkpoint_state, max_accuracy
 
@@ -128,9 +128,6 @@ class NetworkTrainer(object):
             print('mode = {}, delta = {}, patience = {} epochs ...'
                   .format(self.mode, self.delta, self.patience))
 
-            # initial accuracy on the validation set
-            max_accuracy = 0
-
         # create dictionary of the observed losses and accuracies on the
         # training and validation dataset
         tshape = (len(self.train_dl), self.epochs)
@@ -141,9 +138,6 @@ class NetworkTrainer(object):
                           'va': np.zeros(shape=vshape)
                           }
 
-        # whether to resume training from an existing model
-        checkpoint_state, max_accuracy = self.from_checkpoint()
-
         # send the model to the gpu if available
         self.model = self.model.to(self.device)
 
@@ -188,7 +182,8 @@ class NetworkTrainer(object):
 
                 # print progress
                 print('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, Loss: {:.2f}, '
-                      'Accuracy: {:.2f}'.format(epoch + 1, self.epochs,
+                      'Accuracy: {:.2f}'.format(epoch + 1,
+                                                self.epochs,
                                                 batch + 1,
                                                 len(self.train_dl),
                                                 observed_loss,
@@ -212,8 +207,8 @@ class NetworkTrainer(object):
                 epoch_acc = vacc.squeeze().mean()
 
                 # whether the model improved with respect to the previous epoch
-                if es.increased(epoch_acc, max_accuracy, self.delta):
-                    max_accuracy = epoch_acc
+                if es.increased(epoch_acc, self.max_accuracy, self.delta):
+                    self.max_accuracy = epoch_acc
                     # save model state if the model improved with
                     # respect to the previous epoch
                     _ = self.model.save(self.state_file,
@@ -224,7 +219,7 @@ class NetworkTrainer(object):
                     # save losses and accuracy
                     self._save_loss(training_state,
                                     self.checkpoint,
-                                    checkpoint_state)
+                                    self.checkpoint_state)
 
                 # whether the early stopping criterion is met
                 if es.stop(epoch_acc):
@@ -241,7 +236,7 @@ class NetworkTrainer(object):
                 # save losses and accuracy after each epoch
                 self._save_loss(training_state,
                                 self.checkpoint,
-                                checkpoint_state)
+                                self.checkpoint_state)
 
         return training_state
 
@@ -300,7 +295,7 @@ class NetworkTrainer(object):
         # format: networkname_datasetname_t(tilesize)_b(batchsize)_bands.pt
         bformat = ''.join([b[0] for b in self.bands]) if self.bands else 'all'
         self.state_file = ('{}_{}_t{}_b{}_{}.pt'
-                           .format(self.model.__class__.__name__,
+                           .format(self.model.__name__,
                                    self.dataset.__class__.__name__,
                                    self.tile_size,
                                    self.batch_size,
@@ -321,34 +316,38 @@ class NetworkTrainer(object):
 
     def _init_dataset(self):
 
-        # check whether the dataset is currently supported
-        self.dataset = None
-        for dataset in SupportedDatasets:
-            if self.dataset_name == dataset.name:
-                self.dataset = dataset.value['class'](
-                    self.dataset_path,
-                    use_bands=self.bands,
-                    tile_size=self.tile_size,
-                    sort=self.sort,
-                    transforms=self.transforms,
-                    pad=self.pad,
-                    cval=self.cval,
-                    gt_pattern=self.gt_pattern)
+        # the dataset name
+        self.dataset_name = os.path.basename(self.root_dir)
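+        # (assumes the basename of the root directory matches the name of
+        #  a dataset in SupportedDatasets)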
 
-        if self.dataset is None:
+        # check whether the dataset is currently supported
+        if self.dataset_name not in SupportedDatasets.__members__:
             raise ValueError('{} is not a valid dataset. '
                              .format(self.dataset_name) +
                              'Available datasets are: \n' +
                              '\n'.join(name for name, _ in
                                        SupportedDatasets.__members__.items()))
+        self.dataset_class = SupportedDatasets.__members__[
+            self.dataset_name].value
 
+        # instantiate the dataset
+        self.dataset = self.dataset_class(
+            self.root_dir,
+            use_bands=self.bands,
+            tile_size=self.tile_size,
+            sort=self.sort,
+            transforms=self.transforms,
+            pad=self.pad,
+            cval=self.cval,
+            gt_pattern=self.gt_pattern)
 
         # the mode to split
         if self.split_mode not in VALID_SPLIT_MODES:
             raise ValueError('{} is not supported. Valid modes are {}, see '
-                              'pysegcnn.main.config.py for a description of '
-                              'each mode.'.format(self.split_mode,
-                                                  VALID_SPLIT_MODES))
+                             'pysegcnn.main.config.py for a description of '
+                             'each mode.'.format(self.split_mode,
+                                                 VALID_SPLIT_MODES))
         if self.split_mode == 'random':
             self.subset = RandomTileSplit(self.dataset,
                                           self.ttratio,
@@ -364,12 +363,12 @@ class NetworkTrainer(object):
                                     self.date,
                                     self.dateformat)
 
-        # the training, validation and dataset
+        # the training, validation and test dataset
         self.train_ds, self.valid_ds, self.test_ds = self.subset.split()
 
         # whether to drop training samples with a fraction of pixels equal to
         # the constant padding value self.cval >= self.drop
-        if self.pad:
+        if self.pad and self.drop:
             self._drop(self.train_ds)
 
         # the shape of a single batch
@@ -400,18 +399,53 @@ class NetworkTrainer(object):
 
     def _init_model(self):
 
-        # instanciate the segmentation network
-        if self.pretrained:
+        # initial accuracy on the validation set
+        self.max_accuracy = 0
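+        # (overwritten with the accuracy stored in the checkpoint when
+        #  resuming training, see case (3) below)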
+
+        # set the model checkpoint to None, overwritten when resuming
+        # training from an existing model checkpoint
+        self.checkpoint_state = None
+
+        # case (1): build a model for the specified dataset
+        if not self.pretrained and not self.checkpoint:
+
+            # instantiate the model
+            self.model = self.model(in_channels=len(self.dataset.use_bands),
+                                    nclasses=len(self.dataset.labels),
+                                    filters=self.filters,
+                                    skip=self.skip_connection,
+                                    **self.kwargs)
+
+            # the optimizer used to update the model weights
+            self.optimizer = self.optimizer(self.model.parameters(), self.lr)
+
+        # case (2): using a pretrained model without an existing checkpoint
+        #           on a new dataset, i.e. transfer learning
+        elif self.pretrained and not self.checkpoint:
+            # load pretrained model
             self.model = self.from_pretrained()
-        else:
-            self.model = self.net(in_channels=len(self.dataset.use_bands),
-                                  nclasses=len(self.dataset.labels),
-                                  filters=self.filters,
-                                  skip=self.skip_connection,
-                                  **self.kwargs)
-
-        # the optimizer used to update the model weights
-        self.optimizer = self.optimizer(self.model.parameters(), self.lr)
+
+            # the optimizer used to update the model weights
+            self.optimizer = self.optimizer(self.model.parameters(), self.lr)
+
+        # case (3): resuming training from an existing model checkpoint on
+        #           the same dataset the model was trained on
+        elif self.checkpoint:
+
+            # instantiate the model
+            self.model = self.model(in_channels=len(self.dataset.use_bands),
+                                    nclasses=len(self.dataset.labels),
+                                    filters=self.filters,
+                                    skip=self.skip_connection,
+                                    **self.kwargs)
+
+            # the optimizer used to update the model weights
+            self.optimizer = self.optimizer(self.model.parameters(), self.lr)
+
+            # resume training from the existing model checkpoint
+            (self.checkpoint_state,
+             self.max_accuracy) = self.from_checkpoint()
 
     # function to drop samples with a fraction of pixels equal to the constant
     # padding value self.cval >= self.drop
-- 
GitLab