From aefec6d367345b8cea3b71e5bdd67cd07d2476c3 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Mon, 25 Jan 2021 15:22:52 +0100
Subject: [PATCH] Adjusted NetworkInference to accumulate statistics accross
 different model runs.

---
 pysegcnn/core/trainer.py | 170 ++++++++++++++++++++++-----------------
 1 file changed, 94 insertions(+), 76 deletions(-)

diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 123be81..97ce7c1 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -51,7 +51,7 @@ from pysegcnn.core.logging import log_conf
 from pysegcnn.core.graphics import (plot_loss, plot_confusion_matrix,
                                     plot_sample)
 from pysegcnn.core.constants import map_labels
-from pysegcnn.main.config import HERE, DRIVE_PATH
+from pysegcnn.main.train_config import HERE, DRIVE_PATH
 
 
 # module level logger
@@ -270,11 +270,11 @@ class SplitConfig(BaseConfig):
 
         Returns
         -------
-        train_ds : :py:class:`pysegcnn.core.split.CustomSubset`.
+        train_ds : :py:class:`torch.utils.data.Subset`.
             The training set.
-        valid_ds : :py:class:`pysegcnn.core.split.CustomSubset`.
+        valid_ds : :py:class:`torch.utils.data.Subset`.
             The validation set.
-        test_ds : :py:class:`pysegcnn.core.split.CustomSubset`.
+        test_ds : :py:class:`torch.utils.data.Subset`.
             The test set.
 
         """
@@ -478,11 +478,10 @@ class ModelConfig(BaseConfig):
 
         """
         # write an initialization string to the log file
-        LogConfig.init_log('{}: Initializing model run. ')
+        LogConfig.init_log('Initializing model: {} '.format(state_file.name))
 
         # set the random seed for reproducibility
         torch.manual_seed(self.torch_seed)
-        LOGGER.info('Initializing model: {}'.format(state_file.name))
 
         # initialize checkpoint state, i.e. no model checkpoint
         checkpoint_state = {}
@@ -756,7 +755,7 @@ class StateConfig(BaseConfig):
         self.ds_state_ext = 't{}_{}.pt'
 
     def init_state(self, src_dc, src_sc, mc, trg_dc=None, trg_sc=None, tc=None,
-                   fold=0):
+                   fold=None):
         """Generate the model state filename.
 
         Parameters
@@ -767,12 +766,13 @@ class StateConfig(BaseConfig):
             The source domain dataset split configuration.
         mc : :py:class:`pysegcnn.core.trainer.ModelConfig`
             The model configuration.
-        trg_dc : :py:class:`pysegcnn.core.trainer.DatasetConfig`
-            The target domain dataset configuration.
-        trg_sc : :py:class:`pysegcnn.core.trainer.SplitConfig`
-            The target domain dataset split configuration.
-        tc : :py:class:`pysegcnn.core.trainer.TransferLearningConfig`
-            The transfer learning configuration.
+        trg_dc : :py:class:`pysegcnn.core.trainer.DatasetConfig`, optional
+            The target domain dataset configuration. The default is `None`.
+        trg_sc : :py:class:`pysegcnn.core.trainer.SplitConfig`, optional
+            The target domain dataset split configuration. The default is
+            `None`.
+        tc : :py:class:`pysegcnn.core.trainer.TransferLearningConfig`, optional
+            The transfer learning configuration. The default is `None`.
 
         Returns
         -------
@@ -798,33 +798,34 @@ class StateConfig(BaseConfig):
             if trg_dc is None or trg_sc is None:
                 raise ValueError('Target domain configurations required.')
 
-            # target domain dataset state filename
-            trg_ds_state, _ = self.format_ds_state(
-                trg_dc, trg_sc, fold)
-
-            # check whether a pretrained model is used to fine-tune to the
-            # target domain
-            if tc.supervised:
-                # state file for models fine-tuned to target domain
-                # DatasetConfig_PretrainedModel.pt
-
-                # TODO: Is this correct? Trainer is initialized with source
-                # dataloaders
-                state = '_'.join([tc.pretrained_model,
-                                  'sda_{}'.format(trg_ds_state)])
-            else:
-                # state file for models trained via unsupervised domain
-                # adaptation
-                state = '_'.join([state.replace(
-                    src_ds_ext, 'uda_{}'.format(tc.uda_pos)),
-                    trg_ds_state, src_ds_ext])
-
-                # check whether unsupervised domain adaptation is initialized
-                # from a pretrained model state
-                if tc.uda_from_pretrained:
-                    state = '_'.join(state.replace('.pt', ''),
-                                     'prt_{}'.format(
-                                         tc.pretrained_model))
+            # check whether to apply transfer learning
+            if tc.transfer:
+
+                # target domain dataset state filename
+                trg_ds_state, _ = self.format_ds_state(
+                    trg_dc, trg_sc, fold)
+
+                # check whether a pretrained model is used to fine-tune to the
+                # target domain
+                if tc.supervised:
+                    # state file for models fine-tuned to target domain
+                    # PretrainedModel_DatasetConfig.pt
+                    state = '_'.join([
+                        tc.pretrained_model, 'sda_{}'.format(src_ds_state),
+                        src_ds_ext])
+                else:
+                    # state file for models trained via unsupervised domain
+                    # adaptation
+                    state = '_'.join([
+                        state.replace(src_ds_ext, 'uda_{}'.format(tc.uda_pos)),
+                        trg_ds_state, src_ds_ext])
+
+                    # check whether unsupervised domain adaptation is
+                    # initialized from a pretrained model state
+                    if tc.uda_from_pretrained:
+                        state = '_'.join(
+                            [state.replace('.pt', ''),
+                             'prt_{}'.format(tc.pretrained_model)])
 
         # path to model state
         state = mc.state_path.joinpath(state)
@@ -954,7 +955,7 @@ class LogConfig(BaseConfig):
 
         """
         LOGGER.info(80 * '-')
-        LOGGER.info(init_str.format(LogConfig.now()))
+        LOGGER.info('{}: '.format(LogConfig.now()) + init_str)
         LOGGER.info(80 * '-')
 
 
@@ -984,15 +985,15 @@ class ClassificationNetworkTrainer(BaseConfig):
     src_train_dl : :py:class:`torch.utils.data.DataLoader`
         The source domain training :py:class:`torch.utils.data.DataLoader`
         instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`.
+        :py:class:`torch.utils.data.Subset`.
     src_valid_dl : :py:class:`torch.utils.data.DataLoader`
         The source domain validation :py:class:`torch.utils.data.DataLoader`
         instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`.
+        :py:class:`torch.utils.data.Subset`.
     src_test_dl : :py:class:`torch.utils.data.DataLoader`
         The source domain test :py:class:`torch.utils.data.DataLoader`
         instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`.
+        :py:class:`torch.utils.data.Subset`.
     epochs : `int`
         The maximum number of epochs to train. The default is `1`.
     nthreads : `int`
@@ -1419,17 +1420,17 @@ class DomainAdaptationTrainer(ClassificationNetworkTrainer):
     trg_train_dl : `None` or :py:class:`torch.utils.data.DataLoader`
         The target domain training :py:class:`torch.utils.data.DataLoader`
         instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty
+        :py:class:`torch.utils.data.Subset`. The default is an empty
         :py:class:`torch.utils.data.DataLoader`.
     trg_valid_dl : `None` or  :py:class:`torch.utils.data.DataLoader`
         The target domain validation :py:class:`torch.utils.data.DataLoader`
         instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty
+        :py:class:`torch.utils.data.Subset`. The default is an empty
         :py:class:`torch.utils.data.DataLoader`.
     trg_test_dl : :py:class:`torch.utils.data.DataLoader`
         The target domain test :py:class:`torch.utils.data.DataLoader`
         instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty
+        :py:class:`torch.utils.data.Subset`. The default is an empty
         :py:class:`torch.utils.data.DataLoader`.
     uda_loss_function : :py:class:`torch.nn.Module`
         The domain adaptation loss function. An instance of
@@ -2003,6 +2004,10 @@ class NetworkInference(BaseConfig):
         Whether to evaluate the model on the training (``test=None``), the
         validation (``test=False``) or the test set (``test=True``). The
         default is `False`.
+    aggregate : `bool`
+        Whether to aggregate the statistics of the different models in
+        ``state_files``. Useful to aggregate the results of mutliple model
+        runs in cross validation. The default is `False`.
     ds : `dict`
         The dataset configuration dictionary passed to
         :py:class:`pysegcnn.core.trainer.DatasetConfig` when evaluating on
@@ -2036,10 +2041,6 @@ class NetworkInference(BaseConfig):
     alpha : `int`
         The level of the percentiles for contrast stretching of the false color
         compsite. The default is `0`, i.e. no stretching.
-    animate : `bool`
-        Whether to create an animation of (input, ground truth, prediction) for
-        the scenes in the train/validation/test dataset. Only works if
-        ``predict_scene=True`` and ``plot_scene=True``.
     device : `str`
         The device to evaluate the model on, i.e. `cpu` or `cuda`.
     base_path : :py:class:`pathlib.Path`
@@ -2054,9 +2055,9 @@ class NetworkInference(BaseConfig):
         Path to search for model state files ``state_files``.
     plot_kwargs : `dict`
         Keyword arguments for :py:func:`pysegcnn.core.graphics.plot_sample`
-    trg_ds : :py:class:`pysegcnn.core.split.CustomSubset`
+    trg_ds : :py:class:`torch.utils.data.Subset`
         The dataset to evaluate ``model`` on.
-    src_ds : :py:class:`pysegcnn.core.split.CustomSubset`
+    src_ds : :py:class:`torch.utils.data.Subset`
         The model source domain training dataset.
 
     """
@@ -2065,6 +2066,7 @@ class NetworkInference(BaseConfig):
     implicit: bool = True
     domain: str = 'src'
     test: object = False
+    aggregate: bool = False
     ds: dict = dataclasses.field(default_factory={})
     ds_split: dict = dataclasses.field(default_factory={})
     map_labels: bool = False
@@ -2133,11 +2135,11 @@ class NetworkInference(BaseConfig):
             This function assumes that the datasets are stored in a directory
             named "Datasets" on each machine.
 
-        See ``DRIVE_PATH`` in :py:mod:`pysegcnn.main.config`.
+        See ``DRIVE_PATH`` in :py:mod:`pysegcnn.main.eval_config`.
 
         Parameters
         ----------
-        ds : :py:class:`pysegcnn.core.split.CustomSubset`
+        ds : :py:class:`torch.utils.data.Subset`
             A subset of an instance of
             :py:class:`pysegcnn.core.dataset.ImageDataset`.
         drive_path : `str`
@@ -2148,8 +2150,7 @@ class NetworkInference(BaseConfig):
         ------
         TypeError
             Raised if ``ds`` is not an instance of
-            :py:class:`pysegcnn.core.split.CustomSubset` and if ``ds`` is not
-            a subset of an instance of
+            :py:class:`torch.utils.data.Subset` build from an instance of
             :py:class:`pysegcnn.core.dataset.ImageDataset`.
 
         """
@@ -2180,23 +2181,23 @@ class NetworkInference(BaseConfig):
 
         Returns
         -------
-        ds : :py:class:`pysegcnn.core.split.CustomSubset`
+        ds : :py:class:`torch.utils.data.Subset`
             The dataset to evaluate the model on.
 
         """
         # load model state
         model_state = Network.load(state)
 
+        # check whether to evaluate the model on the training, validation
+        # or test set
+        if test is None:
+            ds_set = 'train'
+        else:
+            ds_set = 'test' if test else 'valid'
+
         # check whether to evaluate on the datasets defined at training time
         if implicit:
 
-            # check whether to evaluate the model on the training, validation
-            # or test set
-            if test is None:
-                ds_set = 'train'
-            else:
-                ds_set = 'test' if test else 'valid'
-
             # the dataset to evaluate the model on
             ds = model_state[domain + '_{}_dl'.format(ds_set)].dataset
             if ds is None:
@@ -2214,18 +2215,18 @@ class NetworkInference(BaseConfig):
 
             # split configuration
             sc = SplitConfig(**self.ds_split)
-            train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
+            folds = sc.train_val_test_split(ds)[0]
 
             # check whether to evaluate the model on the training, validation
             # or test set
-            if test is None:
-                ds = train_ds
-            else:
-                ds = test_ds if test else valid_ds
+            ds = folds[ds_set]
 
             # log dataset representation
             LOGGER.info('Evaluating on {} set of explicitly defined dataset: '
-                        '\n {}'.format(ds.name, repr(ds.dataset)))
+                        '\n {}'.format(ds_set, repr(ds.dataset)))
+
+        # name the current dataset
+        ds.name = '_'.join([domain, ds_set])
 
         # check the dataset path: replace by path on current machine
         self.replace_dataset_path(ds, DRIVE_PATH)
@@ -2464,11 +2465,11 @@ class NetworkInference(BaseConfig):
             the samples (``self.predict_scene=False``) or the name of the
             scenes of the target dataset (``self.predict_scene=True``). The
             values are dictionaries with keys:
-                ``'input'``
+                ``'x'``
                     Model input data of the sample (:py:class:`numpy.ndarray`).
-                ``'labels'
+                ``'y'
                     Ground truth class labels (:py:class:`numpy.ndarray`).
-                ``'prediction'``
+                ``'y_pred'``
                     Model prediction class labels (:py:class:`numpy.ndarray`).
 
         """
@@ -2498,7 +2499,8 @@ class NetworkInference(BaseConfig):
             if self.dataloader.batch_size > 1:
 
                 # id of the current scene
-                batch = self.trg_ds.ids[batch]
+                current_scene = np.int(batch * self.dataloader.batch_size)
+                batch = self.trg_ds.dataset.scenes['id'][current_scene]
 
                 # modify the progress string
                 progress = progress.replace('Sample', 'Scene')
@@ -2542,6 +2544,7 @@ class NetworkInference(BaseConfig):
                                 state=batch_name,
                                 plot_path=self.scenes_path,
                                 **self.kwargs)
+
         return output
 
     def eval_file(self, state_file):
@@ -2577,7 +2580,7 @@ class NetworkInference(BaseConfig):
             # initialize logging
             log = LogConfig(state)
             dictConfig(log_conf(log.log_file))
-            log.init_log('{}: ' + 'Evaluating model: {}.'.format(state.name))
+            log.init_log('Evaluating model: {}.'.format(state.name))
 
             # check whether model was already evaluated
             if self.eval_file(state).exists():
@@ -2632,4 +2635,19 @@ class NetworkInference(BaseConfig):
             # save model predictions to list
             inference[state.stem] = output
 
+        # check whether to aggregate the results of the different model runs
+        if self.aggregate:
+
+            # chech whether to compute the aggregated confusion matrix
+            if self.cm:
+                # initialize the aggregated confusion matrix
+                cm_agg = np.zeros(shape=2 * (len(self.src_ds.labels), ))
+
+                # update aggregated confusion matrix
+                for _, output in inference.items():
+                    cm_agg += output['cm']
+
+                # save aggregated confusion matrix to dictionary
+                inference['cm_agg'] = cm_agg
+
         return inference
-- 
GitLab