From 643b49562e6afef1bac0bf65c3db27b21a8d72e2 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Thu, 4 Feb 2021 18:08:02 +0100 Subject: [PATCH] Consistent use of source and target datasets. --- pysegcnn/core/trainer.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 59037cc..cc805cc 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -2244,7 +2244,7 @@ class NetworkInference(BaseConfig): The class labels of the source domain. """ - return self.src_ds.labels + return self.src_ds.dataset.labels @property def target_labels(self): @@ -2256,7 +2256,7 @@ class NetworkInference(BaseConfig): The class labels of the target domain. """ - return self.trg_ds.labels + return self.trg_ds.dataset.labels @property def source_label_map(self): @@ -2307,7 +2307,7 @@ class NetworkInference(BaseConfig): """ # check whether the source domain labels are the same as the target # domain labels - return map_labels(self.src_ds.get_labels(), + return map_labels(self.src_ds.dataset.get_labels(), self.trg_ds.dataset.get_labels()) @property @@ -2359,7 +2359,7 @@ class NetworkInference(BaseConfig): A list of the named spectral bands used to train the model. """ - return self.src_ds.use_bands + return self.src_ds.dataset.use_bands @property def plot(self): @@ -2417,7 +2417,7 @@ class NetworkInference(BaseConfig): The original class labels of the source domain. """ - return self.src_ds._labels + return self.src_ds.dataset._labels @property def _original_target_labels(self): @@ -2607,7 +2607,7 @@ class NetworkInference(BaseConfig): domain=self.domain) # load the source dataset the model was trained on - self.src_ds = self.load_dataset(state, test=None).dataset + self.src_ds = self.load_dataset(state, test=None) # load the pretrained model model, _ = Network.load_pretrained_model(state) @@ -2618,6 +2618,8 @@ class NetworkInference(BaseConfig): # check whether to calculate confusion matrix if self.cm: + # TODO: merge predictions for all scenes in output + # calculate confusion matrix conf_mat = confusion_matrix(output['y'].numpy().flatten(), output['y_pred'].numpy().flatten()) @@ -2643,7 +2645,8 @@ class NetworkInference(BaseConfig): # chech whether to compute the aggregated confusion matrix if self.cm: # initialize the aggregated confusion matrix - cm_agg = np.zeros(shape=2 * (len(self.src_ds.labels), )) + cm_agg = np.zeros(shape=2 * (len(self.src_ds.dataset.labels), ) + ) # update aggregated confusion matrix for _, output in inference.items(): -- GitLab