From 420561b7f7f33f69bcdd297ae1b2c930e6d3409d Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Fri, 26 Feb 2021 09:08:33 +0100
Subject: [PATCH] Do not store scene inputs and use more descriptive figure
 title.

---
 pysegcnn/core/trainer.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index e663a33..2640ca9 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -2504,8 +2504,6 @@ class NetworkInference(BaseConfig):
             the samples (``self.predict_scene=False``) or the name of the
             scenes of the target dataset (``self.predict_scene=True``). The
             values are dictionaries with keys:
-                ``'x'``
-                    Model input data of the sample (:py:class:`numpy.ndarray`).
                 ``'y_true'
                     Ground truth class labels (:py:class:`numpy.ndarray`).
                 ``'y_pred'``
@@ -2586,19 +2584,25 @@ class NetworkInference(BaseConfig):
 
                     # save current scene to output dictionary
                     output[batch] = {k: v for k, v in zip(
-                        INFERENCE_NAMES, [inputs, labels, prdctn])}
+                        INFERENCE_NAMES[1:], [labels, prdctn])}
 
                     # re-initialize scene dictionary
                     scenes = {k: [] for k in INFERENCE_NAMES}
 
                     # plot current scene
                     if self.plot:
+
+                        # title for prediction
+                        title = ''.join([(v[0] + str(k)) for k, v in
+                                         self.src_ds.sensor.band_dict().items()
+                                         if v in self.bands])
+
                         # plot inputs, ground truth and model predictions
                         fig = plot_sample(inputs.clip(0, 1),
                                           self.bands,
                                           self.use_labels,
                                           y=labels,
-                                          y_pred={'Prediction': prdctn},
+                                          y_pred={title: prdctn},
                                           accuracy=True,
                                           **self.plot_kwargs)
 
@@ -2618,8 +2622,8 @@ class NetworkInference(BaseConfig):
 
             else:
                 # save current batch to output dictionary
-                output[batch] = {k: v for k, v in zip(INFERENCE_NAMES,
-                                                      [inputs, labels, prdctn])
+                output[batch] = {k: v for k, v in zip(INFERENCE_NAMES[1:],
+                                                      [labels, prdctn])
                                  }
 
                 # calculate the accuracy of the prediction on the current batch
-- 
GitLab