diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 49349d8608be7f73e5fe6841a0c4b0cb15bb5b5a..261fcb7ded55d0fb3e2762bddba7d2c13eb767c2 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -2739,16 +2739,19 @@ class NetworkInference(BaseConfig): report = df.groupby(df.index, sort=False).mean() inference['report'] = report + # labels to predict + labels = list(report.index.drop(['macro avg', 'weighted avg', + 'accuracy'])) + # plot classification report - fig = plot_classification_report(report, self.class_names) + fig = plot_classification_report(report, labels) report_name = self.report_path.joinpath(self.report_name(kfold)) fig.savefig(report_name, dpi=300, bbox_inches='tight') # chech whether to compute the aggregated confusion matrix if self.cm: # initialize the aggregated confusion matrix - cm_agg = np.zeros(shape=2 * (len(self.src_ds.dataset.labels), ) - ) + cm_agg = np.zeros(shape=2 * (len(labels), )) # update aggregated confusion matrix for _, output in inference.items(): @@ -2758,8 +2761,7 @@ class NetworkInference(BaseConfig): inference['cm'] = cm_agg # plot aggregated confusion matrix and save to file - plot_confusion_matrix( - cm_agg, self.class_names, state_file=kfold, - outpath=self.perfmc_path) + plot_confusion_matrix(cm_agg, labels, state_file=kfold, + outpath=self.perfmc_path) return inference