Skip to content
Snippets Groups Projects
Commit 8199ac71 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Implemented the sklearn implementation of the confusion matrix.

parent 4a56331e
No related branches found
No related tags found
No related merge requests found
......@@ -35,6 +35,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from sklearn.metrics import confusion_matrix
# locals
from pysegcnn.core.dataset import SupportedDatasets
......@@ -2611,14 +2612,9 @@ class NetworkInference(BaseConfig):
# check whether to calculate confusion matrix
if self.cm:
# initialize confusion matrix
conf_mat = np.zeros(shape=2 * (len(self.src_ds.labels), ))
# calculate confusion matrix
for ytrue, ypred in zip(output['y'].flatten(),
output['y_pred'].flatten()):
# update confusion matrix entries
conf_mat[ytrue.long(), ypred.long()] += 1
conf_mat = confusion_matrix(output['y'].numpy().flatten(),
output['y_pred'].numpy().flatten())
# add confusion matrix to model output
output['cm'] = conf_mat
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment