From b6cc773feeb57489ee846661972c74ebb85c1dd1 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Tue, 16 Feb 2021 14:58:15 +0100
Subject: [PATCH] Read the labels from the classification reports.

---
 pysegcnn/core/trainer.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 49349d8..261fcb7 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -2739,16 +2739,19 @@ class NetworkInference(BaseConfig):
             report = df.groupby(df.index, sort=False).mean()
             inference['report'] = report
 
+            # labels to predict
+            labels = list(report.index.drop(['macro avg', 'weighted avg',
+                                             'accuracy']))
+
             # plot classification report
-            fig = plot_classification_report(report, self.class_names)
+            fig = plot_classification_report(report, labels)
             report_name = self.report_path.joinpath(self.report_name(kfold))
             fig.savefig(report_name, dpi=300, bbox_inches='tight')
 
             # chech whether to compute the aggregated confusion matrix
             if self.cm:
                 # initialize the aggregated confusion matrix
-                cm_agg = np.zeros(shape=2 * (len(self.src_ds.dataset.labels), )
-                                  )
+                cm_agg = np.zeros(shape=2 * (len(labels), ))
 
                 # update aggregated confusion matrix
                 for _, output in inference.items():
@@ -2758,8 +2761,7 @@ class NetworkInference(BaseConfig):
                 inference['cm'] = cm_agg
 
                 # plot aggregated confusion matrix and save to file
-                plot_confusion_matrix(
-                    cm_agg, self.class_names, state_file=kfold,
-                    outpath=self.perfmc_path)
+                plot_confusion_matrix(cm_agg, labels, state_file=kfold,
+                                      outpath=self.perfmc_path)
 
         return inference
-- 
GitLab