From 6e3a7d7cd393177fa68ec5c8c92c3624e79a107e Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Fri, 5 Feb 2021 15:25:58 +0100
Subject: [PATCH] Fixed some bugs: dataset path replacement and confusion
 matrix initialization.

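The NetworkInference configuration gains two new options: drive_path, the
path to the datasets on the machine where the evaluation is run, and
overwrite, which forces existing model evaluations to be recomputed. A
minimal sketch of how these options might be set (assuming the evaluation is
configured through a plain dictionary, as in pysegcnn.main.eval_config; the
paths and values shown are placeholders):

    eval_config = {
        # path to the datasets on the evaluation machine, ending in 'Datasets'
        'drive_path': '/path/to/Datasets',
        # recompute and replace an existing model evaluation
        'overwrite': True,
    }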
---
 pysegcnn/core/trainer.py | 81 ++++++++++++++++++++++++++++------------
 1 file changed, 58 insertions(+), 23 deletions(-)

diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 45d6ee2..457b524 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -52,7 +52,7 @@ from pysegcnn.core.logging import log_conf
 from pysegcnn.core.graphics import (plot_loss, plot_confusion_matrix,
                                     plot_sample)
 from pysegcnn.core.constants import map_labels
-from pysegcnn.main.train_config import HERE, DRIVE_PATH
+from pysegcnn.main.train_config import HERE
 
 # module level logger
 LOGGER = logging.getLogger(__name__)
@@ -2018,6 +2018,13 @@ class NetworkInference(BaseConfig):
         :py:class:`pysegcnn.core.trainer.SplitConfig` when evaluating on
         an explicitly defined dataset, i.e. ``implicit=False``. The default is
         `{}`.
+    drive_path : `str`
+        Path to the datasets on the current machine. By default, the path to
+        the datasets is assumed to be the same as during training, i.e. model
+        training and evaluation are performed on the same machine. Otherwise,
+        ``drive_path`` should be the path to the datasets, ending with
+        `'Datasets'`, on the machine where the model is evaluated. The
+        default is `''`.
     map_labels : `bool`
         Whether to map the model labels from the model source domain to the
         defined ``domain`` in case the domain class labels differ. The default
@@ -2036,6 +2043,9 @@ class NetworkInference(BaseConfig):
         `['nir', 'red', 'green']`.
     cm : `bool`
         Whether to compute the confusion matrix. The default is `True`.
+    overwrite : `bool`
+        Whether to overwrite existing model evaluations. The default is
+        `False`.
     figsize : `tuple`
         The figure size in centimeters. The default is `(10, 10)`.
     alpha : `int`
@@ -2069,12 +2079,14 @@ class NetworkInference(BaseConfig):
     aggregate: bool = False
     ds: dict = dataclasses.field(default_factory=dict)
     ds_split: dict = dataclasses.field(default_factory=dict)
+    drive_path: str = ''
     map_labels: bool = False
     predict_scene: bool = True
     plot_scenes: bool = False
     plot_bands: list = dataclasses.field(
         default_factory=lambda: ['red', 'green', 'blue'])
     cm: bool = True
+    overwrite: bool = False
     figsize: tuple = (10, 10)
     alpha: int = 5
 
@@ -2135,8 +2147,6 @@ class NetworkInference(BaseConfig):
             This function assumes that the datasets are stored in a directory
             named "Datasets" on each machine.
 
-        See ``DRIVE_PATH`` in :py:mod:`pysegcnn.main.eval_config`.
-
         Parameters
         ----------
         ds : :py:class:`torch.utils.data.Subset`
@@ -2146,19 +2156,12 @@ class NetworkInference(BaseConfig):
             Base path to the datasets on the current machine. ``drive_path``
             should end with `'Datasets'`.
 
-        Raises
-        ------
-        TypeError
-            Raised if ``ds`` is not an instance of
-            :py:class:`torch.utils.data.Subset` build from an instance of
-            :py:class:`pysegcnn.core.dataset.ImageDataset`.
-
         """
         # iterate over the scenes of the dataset
         for scene in ds.dataset.scenes:
             for k, v in scene.items():
-                # do only look for paths
-                if isinstance(v, str) and k != 'id':
+                # only consider paths to the dataset bands and ground truth
+                if k in ds.dataset.use_bands + ['gt']:
 
                     # drive path: match path before "Datasets"
                     # dpath = re.search('^(.*)(?=(/.*Datasets))', v)
@@ -2229,7 +2232,16 @@ class NetworkInference(BaseConfig):
         ds.name = '_'.join([domain, ds_set])
 
         # check the dataset path: replace by path on current machine
-        # self.replace_dataset_path(ds, DRIVE_PATH)
+        if self.drive_path:
+            # check whether the specified path exists on the current machine
+            self.drive_path = pathlib.Path(self.drive_path)
+            if not self.drive_path.exists():
+                raise FileNotFoundError('Dataset path {} does not exist.'
+                                        .format(self.drive_path))
+
+            # replace dataset path of target dataset by path on the current
+            # machine
+            self.replace_dataset_path(ds, self.drive_path)
 
         return ds
 
@@ -2472,7 +2484,7 @@ class NetworkInference(BaseConfig):
             values are dictionaries with keys:
                 ``'x'``
                     Model input data of the sample (:py:class:`numpy.ndarray`).
-                ``'y'
+                ``'y_true'``
                     Ground truth class labels (:py:class:`numpy.ndarray`).
                 ``'y_pred'``
                     Model prediction class labels (:py:class:`numpy.ndarray`).
@@ -2527,7 +2539,7 @@ class NetworkInference(BaseConfig):
                 prdctn = self.map_to_target(prdctn)
 
             # save current batch to output dictionary
-            output[batch] = {'x': inputs, 'y': labels, 'y_pred': prdctn}
+            output[batch] = {'x': inputs, 'y_true': labels, 'y_pred': prdctn}
 
             # filename for the plot of the current batch
             batch_name = '_'.join([model.state_file.stem,
@@ -2575,7 +2587,7 @@ class NetworkInference(BaseConfig):
             values of the nested dictionaries are again dictionaries with keys:
                 ``'x'``
                     Model input data of the sample (:py:class:`numpy.ndarray`).
-                ``'y'
+                ``'y_true'``
                     Ground truth class labels (:py:class:`numpy.ndarray`).
                 ``'y_pred'``
                     Model prediction class labels (:py:class:`numpy.ndarray`).
@@ -2595,12 +2607,20 @@ class NetworkInference(BaseConfig):
 
             # check whether model was already evaluated
             if self.eval_file(state).exists():
-
-                # load existing model evaluation
                 LOGGER.info('Found existing model evaluation: {}.'
                             .format(self.eval_file(state)))
-                inference[state.stem] = torch.load(self.eval_file(state))
-                continue
+
+                # load existing model evaluation
+                if not self.overwrite:
+                    LOGGER.info('Using model evaluation: {}.'
+                                .format(self.eval_file(state)))
+                    inference[state.stem] = torch.load(self.eval_file(state))
+                    continue
+                else:
+                    # overwrite existing model evaluation
+                    LOGGER.info('Overwriting model evaluation: {}.'
+                                .format(self.eval_file(state)))
+                    self.eval_file(state).unlink()
 
             # plot loss and accuracy
             plot_loss(check_filename_length(state), outpath=self.perfmc_path)
@@ -2622,11 +2642,14 @@ class NetworkInference(BaseConfig):
             # check whether to calculate confusion matrix
             if self.cm:
 
-                # TODO: merge predictions for all scenes in output
+                # merge the predictions and labels of all samples
+                y_true = np.concatenate(
+                    [v['y_true'].numpy().flatten() for v in output.values()])
+                y_pred = np.concatenate(
+                    [v['y_pred'].numpy().flatten() for v in output.values()])
 
                 # calculate confusion matrix
-                conf_mat = confusion_matrix(output['y'].numpy().flatten(),
-                                            output['y_pred'].numpy().flatten())
+                conf_mat = confusion_matrix(y_true, y_pred)
 
                 # add confusion matrix to model output
                 output['cm'] = conf_mat
@@ -2646,6 +2669,9 @@ class NetworkInference(BaseConfig):
         # check whether to aggregate the results of the different model runs
         if self.aggregate:
 
+            # base name for all models
+            base_name = self.state_files[0].name
+
             # check whether to compute the aggregated confusion matrix
             if self.cm:
                 # initialize the aggregated confusion matrix
@@ -2659,4 +2685,13 @@ class NetworkInference(BaseConfig):
                 # save aggregated confusion matrix to dictionary
                 inference['cm_agg'] = cm_agg
 
+                # create file name for aggregated confusion matrix
+                fold_number = re.search('f[0-9]', base_name)[0]
+                cm_name = base_name.replace(fold_number, 'kfold')
+
+                # plot aggregated confusion matrix and save to file
+                plot_confusion_matrix(
+                    cm_agg, self.source_labels, state_file=cm_name,
+                    outpath=self.perfmc_path)
+
         return inference
-- 
GitLab