diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index d0c2a57182349868a690b755aee1dca94ecdbb3c..792ebe20cbb5bf39d92349d05cfd72c70d32bb69 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -2733,7 +2733,9 @@ class NetworkInference(BaseConfig): # calculate confusion matrix LOGGER.info('Computing confusion matrix ...') - conf_mat = confusion_matrix(y_true, y_pred) + conf_mat = confusion_matrix( + y_true, y_pred, + labels=np.asarray(list(self.use_labels.keys()))) # add confusion matrix to model output output['cm'] = conf_mat @@ -2779,11 +2781,15 @@ class NetworkInference(BaseConfig): # labels to predict: drop averages try: + # check if for some scenes micro average is calculated rather + # than accuracy due to missing class labels labels = list(report.index.drop(['macro avg', 'weighted avg', - 'accuracy'])) + 'micro_avg', 'accuracy'])) except KeyError: + # micro average is not calculated for any scene: each scene + # contains all the classes labels = list(report.index.drop(['macro avg', 'weighted avg', - 'micro avg'])) + 'accuracy'])) # plot classification report fig = plot_classification_report(report, labels)