diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 97ce7c19daa0b1561c0857528c7b2c29553590dd..60f9362765ad3aacdc6d3dd4ccd1723ea263985f 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -35,6 +35,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from torch.optim import Optimizer +from sklearn.metrics import confusion_matrix # locals from pysegcnn.core.dataset import SupportedDatasets @@ -2611,14 +2612,9 @@ class NetworkInference(BaseConfig): # check whether to calculate confusion matrix if self.cm: - # initialize confusion matrix - conf_mat = np.zeros(shape=2 * (len(self.src_ds.labels), )) - # calculate confusion matrix - for ytrue, ypred in zip(output['y'].flatten(), - output['y_pred'].flatten()): - # update confusion matrix entries - conf_mat[ytrue.long(), ypred.long()] += 1 + conf_mat = confusion_matrix(output['y'].numpy().flatten(), + output['y_pred'].numpy().flatten()) # add confusion matrix to model output output['cm'] = conf_mat