From d101b65d92ae41d1f7328373d5c78ebbd69071b6 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Fri, 5 Feb 2021 17:06:49 +0100
Subject: [PATCH] Debugged classification report logging.

---
 pysegcnn/core/trainer.py | 13 ++++++++-----
 pysegcnn/core/utils.py   |  2 +-
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 8116c8e..d2bad54 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -2661,9 +2661,9 @@ class NetworkInference(BaseConfig):
             output['y_pred'] = y_pred
 
             # calculate classification report from sklearn
-            LOGGER.info('Classification report')
-            LOGGER.info(classification_report(y_true, y_pred, target_names=[
-                        v['label'] for _, v in self.source_labels.items()]))
+            LOGGER.info('Calculating classification report: {}'
+                        .format(self.report_path.joinpath(
+                            self.report_name(state))))
 
             # export report to Latex table
             report = classification_report(y_true, y_pred, target_names=[
@@ -2710,8 +2710,11 @@ class NetworkInference(BaseConfig):
             LOGGER.info('Aggregating statistics of models:')
             LOGGER.info(('\n ' + (len(__name__) + 1) * ' ').join(
                 ['{}'.format(mstate.name) for mstate in self.state_files]))
-            LOGGER.info(classification_report(y_true, y_pred, target_names=[
-                        v['label'] for _, v in self.source_labels.items()]))
+
+            # calculate classification report from sklearn
+            LOGGER.info('Calculating classification report: {}'
+                        .format(self.report_path.joinpath(self.report_name(
+                            base_name.replace(fold_number, 'kfold')))))
 
             # export aggregated report to Latex table
             report = classification_report(y_true, y_pred, target_names=[
diff --git a/pysegcnn/core/utils.py b/pysegcnn/core/utils.py
index ebf9859..caa3959 100644
--- a/pysegcnn/core/utils.py
+++ b/pysegcnn/core/utils.py
@@ -2649,7 +2649,7 @@ def report2latex(classification_report, filename=None):
 
     """
     # convert to pandas DataFrame and export to latex
-    df = pd.DataFrame(classification_report)
+    df = pd.DataFrame.from_dict(classification_report)
 
     # check if output filename exists
     if filename is not None:
-- 
GitLab