diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index bc0f11c367660ec73191059492ab5f5d4a0e2d12..894ec9f0190f164c2a932b6c4786caa030aaf917 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -59,7 +59,7 @@ from pysegcnn.main.train_config import HERE LOGGER = logging.getLogger(__name__) # global variable: variable names of the model inference output -INFERENCE_NAMES = ['x', 'y', 'y_pred'] +INFERENCE_NAMES = ['x', 'y_true', 'y_pred'] @dataclasses.dataclass @@ -2696,9 +2696,9 @@ class NetworkInference(BaseConfig): output = self.predict(model) # merge predictions of the different samples - y_true = np.asarray([v['y_true'].flatten() for _, v + y_true = np.asarray([v[INFERENCE_NAMES[1]].flatten() for _, v in output.items()]).flatten() - y_pred = np.asarray([v['y_pred'].flatten() for _, v + y_pred = np.asarray([v[INFERENCE_NAMES[2]].flatten() for _, v in output.items()]).flatten() # calculate classification report from sklearn