diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 39e677d1849f232d1e8fe6328d4c296189f372a6..88e0fc624c0999dd9a9ed2802a9e9f6dafa4b09d 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -2738,7 +2738,6 @@ class NetworkInference(BaseConfig): # compute k-fold average estimate of each metric across all models LOGGER.info('Calculating k-fold estimate of metrics ...') report = df.groupby(df.index, sort=False).mean() - inference['report'] = report # labels to predict labels = list(report.index.drop(['macro avg', 'weighted avg', @@ -2749,7 +2748,7 @@ class NetworkInference(BaseConfig): report_name = self.report_path.joinpath(self.report_name(kfold)) fig.savefig(report_name, dpi=300, bbox_inches='tight') - # chech whether to compute the aggregated confusion matrix + # check whether to compute the aggregated confusion matrix if self.cm: # initialize the aggregated confusion matrix cm_agg = np.zeros(shape=2 * (len(labels), )) @@ -2765,4 +2764,7 @@ class NetworkInference(BaseConfig): plot_confusion_matrix(cm_agg, labels, state_file=kfold, outpath=self.perfmc_path) + # add aggregated classification report + inference['report'] = report + return inference