From 643b49562e6afef1bac0bf65c3db27b21a8d72e2 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 4 Feb 2021 18:08:02 +0100
Subject: [PATCH] Consistent use of source and target datasets.

---
 pysegcnn/core/trainer.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 59037cc..cc805cc 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -2244,7 +2244,7 @@ class NetworkInference(BaseConfig):
             The class labels of the source domain.
 
         """
-        return self.src_ds.labels
+        return self.src_ds.dataset.labels
 
     @property
     def target_labels(self):
@@ -2256,7 +2256,7 @@ class NetworkInference(BaseConfig):
             The class labels of the target domain.
 
         """
-        return self.trg_ds.labels
+        return self.trg_ds.dataset.labels
 
     @property
     def source_label_map(self):
@@ -2307,7 +2307,7 @@ class NetworkInference(BaseConfig):
         """
         # check whether the source domain labels are the same as the target
         # domain labels
-        return map_labels(self.src_ds.get_labels(),
+        return map_labels(self.src_ds.dataset.get_labels(),
                           self.trg_ds.dataset.get_labels())
 
     @property
@@ -2359,7 +2359,7 @@ class NetworkInference(BaseConfig):
             A list of the named spectral bands used to train the model.
 
         """
-        return self.src_ds.use_bands
+        return self.src_ds.dataset.use_bands
 
     @property
     def plot(self):
@@ -2417,7 +2417,7 @@ class NetworkInference(BaseConfig):
             The original class labels of the source domain.
 
         """
-        return self.src_ds._labels
+        return self.src_ds.dataset._labels
 
     @property
     def _original_target_labels(self):
@@ -2607,7 +2607,7 @@ class NetworkInference(BaseConfig):
                 domain=self.domain)
 
             # load the source dataset the model was trained on
-            self.src_ds = self.load_dataset(state, test=None).dataset
+            self.src_ds = self.load_dataset(state, test=None)
 
             # load the pretrained model
             model, _ = Network.load_pretrained_model(state)
@@ -2618,6 +2618,8 @@ class NetworkInference(BaseConfig):
             # check whether to calculate confusion matrix
             if self.cm:
 
+                # TODO: merge predictions for all scenes in output
+
                 # calculate confusion matrix
                 conf_mat = confusion_matrix(output['y'].numpy().flatten(),
                                             output['y_pred'].numpy().flatten())
@@ -2643,7 +2645,8 @@ class NetworkInference(BaseConfig):
             # chech whether to compute the aggregated confusion matrix
             if self.cm:
                 # initialize the aggregated confusion matrix
-                cm_agg = np.zeros(shape=2 * (len(self.src_ds.labels), ))
+                cm_agg = np.zeros(shape=2 * (len(self.src_ds.dataset.labels), )
+                                  )
 
                 # update aggregated confusion matrix
                 for _, output in inference.items():
-- 
GitLab