From 8199ac71fd5e07c55791eb90ff6e70c016b9ba21 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Wed, 27 Jan 2021 15:49:35 +0100 Subject: [PATCH] Implemented the sklearn implementation of the confusion matrix. --- pysegcnn/core/trainer.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 97ce7c1..60f9362 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -35,6 +35,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from torch.optim import Optimizer +from sklearn.metrics import confusion_matrix # locals from pysegcnn.core.dataset import SupportedDatasets @@ -2611,14 +2612,9 @@ class NetworkInference(BaseConfig): # check whether to calculate confusion matrix if self.cm: - # initialize confusion matrix - conf_mat = np.zeros(shape=2 * (len(self.src_ds.labels), )) - # calculate confusion matrix - for ytrue, ypred in zip(output['y'].flatten(), - output['y_pred'].flatten()): - # update confusion matrix entries - conf_mat[ytrue.long(), ypred.long()] += 1 + conf_mat = confusion_matrix(output['y'].numpy().flatten(), + output['y_pred'].numpy().flatten()) # add confusion matrix to model output output['cm'] = conf_mat -- GitLab