From b6cc773feeb57489ee846661972c74ebb85c1dd1 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Tue, 16 Feb 2021 14:58:15 +0100 Subject: [PATCH] Read the labels from the classification reports. --- pysegcnn/core/trainer.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 49349d8..261fcb7 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -2739,16 +2739,19 @@ class NetworkInference(BaseConfig): report = df.groupby(df.index, sort=False).mean() inference['report'] = report + # labels to predict + labels = list(report.index.drop(['macro avg', 'weighted avg', + 'accuracy'])) + # plot classification report - fig = plot_classification_report(report, self.class_names) + fig = plot_classification_report(report, labels) report_name = self.report_path.joinpath(self.report_name(kfold)) fig.savefig(report_name, dpi=300, bbox_inches='tight') # chech whether to compute the aggregated confusion matrix if self.cm: # initialize the aggregated confusion matrix - cm_agg = np.zeros(shape=2 * (len(self.src_ds.dataset.labels), ) - ) + cm_agg = np.zeros(shape=2 * (len(labels), )) # update aggregated confusion matrix for _, output in inference.items(): @@ -2758,8 +2761,7 @@ class NetworkInference(BaseConfig): inference['cm'] = cm_agg # plot aggregated confusion matrix and save to file - plot_confusion_matrix( - cm_agg, self.class_names, state_file=kfold, - outpath=self.perfmc_path) + plot_confusion_matrix(cm_agg, labels, state_file=kfold, + outpath=self.perfmc_path) return inference -- GitLab