Skip to content
Snippets Groups Projects
Commit 598460ec authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Specify labels also for confusion matrix.

parent 6ba7096e
No related branches found
No related tags found
No related merge requests found
...@@ -2733,7 +2733,9 @@ class NetworkInference(BaseConfig): ...@@ -2733,7 +2733,9 @@ class NetworkInference(BaseConfig):
# calculate confusion matrix # calculate confusion matrix
LOGGER.info('Computing confusion matrix ...') LOGGER.info('Computing confusion matrix ...')
conf_mat = confusion_matrix(y_true, y_pred) conf_mat = confusion_matrix(
y_true, y_pred,
labels=np.asarray(list(self.use_labels.keys())))
# add confusion matrix to model output # add confusion matrix to model output
output['cm'] = conf_mat output['cm'] = conf_mat
...@@ -2779,11 +2781,15 @@ class NetworkInference(BaseConfig): ...@@ -2779,11 +2781,15 @@ class NetworkInference(BaseConfig):
# labels to predict: drop averages # labels to predict: drop averages
try: try:
# check if for some scenes micro average is calculated rather
# than accuracy due to missing class labels
labels = list(report.index.drop(['macro avg', 'weighted avg', labels = list(report.index.drop(['macro avg', 'weighted avg',
'accuracy'])) 'micro_avg', 'accuracy']))
except KeyError: except KeyError:
# micro average is not calculated for any scene: each scene
# contains all the classes
labels = list(report.index.drop(['macro avg', 'weighted avg', labels = list(report.index.drop(['macro avg', 'weighted avg',
'micro avg'])) 'accuracy']))
# plot classification report # plot classification report
fig = plot_classification_report(report, labels) fig = plot_classification_report(report, labels)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment