diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index a5d3506e3dc4a953daf7877aa272f09f5fe50ae1..dac20d26821b75e26e2de50955042ad1843b763d 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -2735,11 +2735,11 @@ class NetworkInference(BaseConfig): # move predictions and labels to GPU if available y_true = torch.Tensor(y_true).to(self.device) y_pred = torch.Tensor(y_pred).to(self.device) + labels_gpu = torch.Tensor( + np.asarray(list(self.use_labels.keys()))).to(self.device) # compute confusion matrix - conf_mat = confusion_matrix( - y_true, y_pred, - labels=np.asarray(list(self.use_labels.keys()))) + conf_mat = confusion_matrix(y_true, y_pred, labels=labels_gpu) # add confusion matrix to model output output['cm'] = conf_mat.cpu().numpy()