From 8199ac71fd5e07c55791eb90ff6e70c016b9ba21 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 27 Jan 2021 15:49:35 +0100
Subject: [PATCH] Implemented the sklearn implementation of the confusion
 matrix.

---
 pysegcnn/core/trainer.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 97ce7c1..60f9362 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -35,6 +35,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 from torch.optim import Optimizer
+from sklearn.metrics import confusion_matrix
 
 # locals
 from pysegcnn.core.dataset import SupportedDatasets
@@ -2611,14 +2612,9 @@ class NetworkInference(BaseConfig):
             # check whether to calculate confusion matrix
             if self.cm:
 
-                # initialize confusion matrix
-                conf_mat = np.zeros(shape=2 * (len(self.src_ds.labels), ))
-
                 # calculate confusion matrix
-                for ytrue, ypred in zip(output['y'].flatten(),
-                                        output['y_pred'].flatten()):
-                    # update confusion matrix entries
-                    conf_mat[ytrue.long(), ypred.long()] += 1
+                conf_mat = confusion_matrix(output['y'].numpy().flatten(),
+                                            output['y_pred'].numpy().flatten())
 
                 # add confusion matrix to model output
                 output['cm'] = conf_mat
-- 
GitLab