diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 59037cca3a898f68770165c89e2bb0035a87f261..cc805cc42aae2bc66807c44a923363450036b8e0 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -2244,7 +2244,7 @@ class NetworkInference(BaseConfig):
             The class labels of the source domain.
 
         """
-        return self.src_ds.labels
+        return self.src_ds.dataset.labels
 
     @property
     def target_labels(self):
@@ -2256,7 +2256,7 @@ class NetworkInference(BaseConfig):
             The class labels of the target domain.
 
         """
-        return self.trg_ds.labels
+        return self.trg_ds.dataset.labels
 
     @property
     def source_label_map(self):
@@ -2307,7 +2307,7 @@ class NetworkInference(BaseConfig):
         """
         # check whether the source domain labels are the same as the target
         # domain labels
-        return map_labels(self.src_ds.get_labels(),
+        return map_labels(self.src_ds.dataset.get_labels(),
                           self.trg_ds.dataset.get_labels())
 
     @property
@@ -2359,7 +2359,7 @@ class NetworkInference(BaseConfig):
             A list of the named spectral bands used to train the model.
 
         """
-        return self.src_ds.use_bands
+        return self.src_ds.dataset.use_bands
 
     @property
     def plot(self):
@@ -2417,7 +2417,7 @@ class NetworkInference(BaseConfig):
             The original class labels of the source domain.
 
         """
-        return self.src_ds._labels
+        return self.src_ds.dataset._labels
 
     @property
     def _original_target_labels(self):
@@ -2607,7 +2607,7 @@ class NetworkInference(BaseConfig):
                 domain=self.domain)
 
             # load the source dataset the model was trained on
-            self.src_ds = self.load_dataset(state, test=None).dataset
+            self.src_ds = self.load_dataset(state, test=None)
 
             # load the pretrained model
             model, _ = Network.load_pretrained_model(state)
@@ -2618,6 +2618,8 @@ class NetworkInference(BaseConfig):
             # check whether to calculate confusion matrix
             if self.cm:
 
+                # TODO: merge predictions for all scenes in output
+
                 # calculate confusion matrix
                 conf_mat = confusion_matrix(output['y'].numpy().flatten(),
                                             output['y_pred'].numpy().flatten())
@@ -2643,7 +2645,8 @@ class NetworkInference(BaseConfig):
             # chech whether to compute the aggregated confusion matrix
             if self.cm:
                 # initialize the aggregated confusion matrix
-                cm_agg = np.zeros(shape=2 * (len(self.src_ds.labels), ))
+                cm_agg = np.zeros(shape=2 * (len(self.src_ds.dataset.labels), )
+                                  )
 
                 # update aggregated confusion matrix
                 for _, output in inference.items():