From e63a81b59f088392b2d9f793c7891c08176ca379 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 18 Feb 2021 15:26:19 +0100
Subject: [PATCH] Changed model evaluation strategy: process batch-wise rather
 than scene-wise.
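
Tiles are now fed through the model one per batch and accumulated
across batches; a scene is reconstructed and evaluated as soon as all
of its tiles have been processed. A minimal, self-contained sketch of
the accumulation pattern (TILES_PER_SCENE, stitch() and
infer_batchwise() are hypothetical stand-ins, not pysegcnn API):

    import numpy as np

    TILES_PER_SCENE = 4  # hypothetical number of tiles per scene

    def stitch(tiles):
        # hypothetical stand-in for reconstruct_scene: simply stack
        # the tiles along a new leading axis
        return np.stack(tiles)

    def infer_batchwise(batches):
        buffer, scenes = [], []
        for batch, tile in enumerate(batches):
            buffer.append(tile)
            # the last tile of a scene arrives at batch indices
            # TILES_PER_SCENE - 1, 2 * TILES_PER_SCENE - 1, ...
            if (batch + 1) % TILES_PER_SCENE == 0:
                scenes.append(stitch(buffer))
                buffer = []  # start accumulating the next scene
        return scenes

    # usage: 8 single-tile batches yield 2 reconstructed scenes
    tiles = [np.zeros((2, 2)) for _ in range(8)]
    assert len(infer_batchwise(tiles)) == 2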

---
 pysegcnn/core/trainer.py | 152 +++++++++++++++++++++++----------------
 1 file changed, 90 insertions(+), 62 deletions(-)

diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index b64aa74..7ebcd1e 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -58,6 +58,9 @@ from pysegcnn.main.train_config import HERE
 # module level logger
 LOGGER = logging.getLogger(__name__)
 
+# global variable: variable names of the model inference output
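+# ('x': model inputs, 'y': ground truth labels, 'y_pred': model predictions)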
+INFERENCE_NAMES = ['x', 'y', 'y_pred']
+
 
 @dataclasses.dataclass
 class BaseConfig:
@@ -2397,8 +2400,7 @@ class NetworkInference(BaseConfig):
         Returns
         -------
         plot : `bool`
-            Save plots for each sample or for each scene of the target dataset,
-            depending on ``self.predict_scene``.
+            Save plots for each scene of the target dataset.
 
         """
         return self.plot_scenes if self.predict_scene else False
@@ -2424,13 +2426,10 @@ class NetworkInference(BaseConfig):
         Returns
         -------
         batch_size : `int`
-            The batch size of the dataloader used for model inference. Depends
-            on whether to predict each sample of the target dataset
-            individually or whether to reconstruct each scene in the target
-            dataset.
+            The batch size of the dataloader used for model inference.
+            Always 1: scenes are reconstructed tile by tile from
+            successive batches.
 
         """
-        return self.trg_ds.dataset.tiles if self.predict_scene else 1
+        return 1
 
     @property
     def _original_source_labels(self):
@@ -2514,8 +2513,11 @@ class NetworkInference(BaseConfig):
         model.to(self.device)
         LOGGER.info('Device: {}'.format(self.device))
 
-        # iterate over the samples of the target dataset
+        # initialize dictionaries to store outputs in
         output = {}
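+        # each entry of ``scenes`` accumulates the tiles of the scene that is
+        # currently being processed; with a batch size of 1, each batch
+        # corresponds to exactly one tile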
+        scenes = {k: [] for k in INFERENCE_NAMES}
+
+        # iterate over the samples of the target dataset
         for batch, (inputs, labels) in enumerate(self.dataloader):
 
             # send inputs and labels to device
@@ -2527,68 +2529,94 @@ class NetworkInference(BaseConfig):
                 prdctn = F.softmax(
                     model(inputs), dim=1).argmax(dim=1).squeeze()
 
+            # check whether the source and target domain labels differ
+            if self.apply_label_map:
+                prdctn = self.map_to_target(prdctn)
+
             # progress string to log
-            progress = 'Sample: {:d}/{:d}'.format(batch + 1,
-                                                  len(self.dataloader))
+            progress = 'Sample: {:d}/{:d}'.format(
+                batch + 1, len(self.dataloader))
 
             # check if tensor is on gpu and convert to numpy array
             inputs = inputs.cpu().numpy()
             labels = labels.cpu().numpy()
             prdctn = prdctn.cpu().numpy()
 
-            # check whether to reconstruct the scene
-            if self.dataloader.batch_size > 1:
-
-                # tiles of the current scene
-                current_tiles = self.trg_ds.indices[
-                    np.arange(batch * self.dataloader.batch_size,
-                              (batch + 1) * self.dataloader.batch_size)]
+            # check whether to reconstruct the scenes of a dataset
+            if self.predict_scene:
+
+                # append inputs, labels and predictions of the current batch
+                # to the scene dictionary
+                for k, v in zip(INFERENCE_NAMES, [inputs, labels, prdctn]):
+                    scenes[k].append(v)
+
+                # check whether an entire scene has been processed: with a
+                # batch size of 1, the last tile of a scene arrives at batch
+                # indices tiles - 1, 2 * tiles - 1, ...
+                if (batch + 1) % self.trg_ds.dataset.tiles == 0:
+
+                    # convert the scene dictionary to numpy arrays; dicts
+                    # preserve insertion order, so the values align with
+                    # INFERENCE_NAMES
+                    inputs, labels, prdctn = [np.asarray(v) for v in
+                                              scenes.values()]
+
+                    # tiles of the current scene
+                    current_tiles = self.trg_ds.indices[
+                        np.arange(batch + 1 - self.trg_ds.dataset.tiles,
+                                  batch + 1)]
+
+                    # name of the current scene; use a separate variable to
+                    # avoid shadowing the batch counter
+                    scene_id = np.unique(
+                        [self.trg_ds.dataset.scenes[sid]['id']
+                         for sid in current_tiles]).item()
+
+                    # modify the progress string
+                    progress = progress.replace('Sample', 'Scene')
+                    progress += ' Id: {}'.format(scene_id)
+
+                    # reconstruct the entire scene
+                    inputs = reconstruct_scene(inputs)
+                    labels = reconstruct_scene(labels)
+                    prdctn = reconstruct_scene(prdctn)
+
+                    # calculate the accuracy of the prediction on the current
+                    # scene
+                    progress += ', Accuracy: {:.2f}'.format(
+                        accuracy_function(prdctn, labels))
+                    LOGGER.info(progress)
+
+                    # save current scene to output dictionary
+                    output[scene_id] = dict(
+                        zip(INFERENCE_NAMES, [inputs, labels, prdctn]))
+
+                    # re-initialize scene dictionary
+                    scenes = {k: [] for k in INFERENCE_NAMES}
+
+                    # plot current scene
+                    if self.plot:
+                        # plot inputs, ground truth and model predictions
+                        fig = plot_sample(inputs.clip(0, 1),
+                                          self.bands,
+                                          self.use_labels,
+                                          y=labels,
+                                          y_pred={'Prediction': prdctn},
+                                          accuracy=True,
+                                          **self.plot_kwargs)
+
+                        # filename for the plot of the current scene
+                        scene_name = '_'.join(
+                            [model.state_file.stem,
+                             '{}_{}.pdf'.format(self.trg_ds.name, scene_id)])
+
+                        # save figure
+                        fig.savefig(check_filename_length(
+                            self.scenes_path.joinpath(scene_name)),
+                            bbox_inches='tight')
 
-                # name of the current scene
-                batch = np.unique([self.trg_ds.dataset.scenes[sid]['id'] for
-                                   sid in current_tiles]).item()
-
-                # modify the progress string
-                progress = progress.replace('Sample', 'Scene')
-                progress += ' Id: {}'.format(batch)
-
-                # reconstruct the entire scene
-                inputs = reconstruct_scene(inputs)
-                labels = reconstruct_scene(labels)
-                prdctn = reconstruct_scene(prdctn)
-
-            # check whether the source and target domain labels differ
-            if self.apply_label_map:
-                prdctn = self.map_to_target(prdctn)
-
-            # save current batch to output dictionary
-            output[batch] = {'x': inputs, 'y_true': labels, 'y_pred': prdctn}
-
-            # filename for the plot of the current batch
-            batch_name = '_'.join([model.state_file.stem,
-                                  '{}_{}.pdf'.format(self.trg_ds.name, batch)])
-
-            # calculate the accuracy of the prediction
-            progress += ', Accuracy: {:.2f}'.format(
-                accuracy_function(prdctn, labels))
-            LOGGER.info(progress)
-
-            # plot current scene
-            if self.plot:
-
-                # plot inputs, ground truth and model predictions
-                fig = plot_sample(inputs.clip(0, 1),
-                                  self.bands,
-                                  self.use_labels,
-                                  y=labels,
-                                  y_pred={'Prediction': prdctn},
-                                  accuracy=True,
-                                  **self.plot_kwargs)
-
-                # save figure
-                fig.savefig(check_filename_length(
-                    self.scenes_path.joinpath(batch_name)),
-                    bbox_inches='tight')
+            else:
+                # save current batch to output dictionary
+                output[batch] = dict(
+                    zip(INFERENCE_NAMES, [inputs, labels, prdctn]))
+
+                # calculate the accuracy of the prediction on the current batch
+                progress += ', Accuracy: {:.2f}'.format(
+                    accuracy_function(prdctn, labels))
+                LOGGER.info(progress)
 
         return output
 
-- 
GitLab