diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 261fcb7ded55d0fb3e2762bddba7d2c13eb767c2..44b68929a82f8f56cac90bab07d91da2df3cd217 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -2736,6 +2736,7 @@ class NetworkInference(BaseConfig): df = pd.concat([df, output['report']], axis=0) # compute k-fold average estimate of each metric across all models + LOGGER.info('Calculating k-fold estimate of metrics ...') report = df.groupby(df.index, sort=False).mean() inference['report'] = report @@ -2754,8 +2755,8 @@ class NetworkInference(BaseConfig): cm_agg = np.zeros(shape=2 * (len(labels), )) # update aggregated confusion matrix - for _, output in inference.items(): - cm_agg += output['cm'] + for _, metrics in inference.items(): + cm_agg += metrics['cm'] # save aggregated confusion matrix to dictionary inference['cm'] = cm_agg