diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 6f3405901de818de98b37c95fde306bfe1147829..c8020f7730c87bbece9e11f218f4025b560408dc 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -2717,7 +2717,8 @@ class NetworkInference(BaseConfig): .format(report_name)) report = classification_report( y_true, y_pred, target_names=self.class_names, - output_dict=True, zero_division=1) + output_dict=True, zero_division=1, + labels=np.asarray(list(self.use_labels.keys()))) # store report in output dictionary output['report'] = report2df(report, self.class_names)