From cea0dd0174b41f4f0358d2134e64cb7bca554435 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 10 Feb 2021 14:44:02 +0100
Subject: [PATCH] Implemented computation and plotting of the classification report during model evaluation.

---
 pysegcnn/core/graphics.py |  82 ++++++++++++++++---
 pysegcnn/core/trainer.py  | 165 ++++++++++++++++++++++----------------
 2 files changed, 164 insertions(+), 83 deletions(-)

diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py
index b8f2e8a..675de17 100644
--- a/pysegcnn/core/graphics.py
+++ b/pysegcnn/core/graphics.py
@@ -22,11 +22,13 @@ import logging
 
 # externals
 import numpy as np
+import pandas as pd
 import torch
 import matplotlib
 import matplotlib.pyplot as plt
 import matplotlib.patches as mpatches
 import matplotlib.lines as mlines
+import seaborn as sns
 from matplotlib.colors import ListedColormap, BoundaryNorm
 from matplotlib.animation import ArtistAnimation
 from matplotlib import cm as colormap
@@ -193,7 +195,7 @@ def plot_sample(x, use_bands, labels,
     accuracy : `bool`, optional
         Whether to calculate the accuracy of the predictions ``y_pred`` with
         respect to the ground truth ``y``. The default is `False`.
-    figsize : `tuple`, optional
+    figsize : `tuple` [`int`], optional
         The figure size in centimeters. The default is `(16, 9)`.
     bands : `list` [`str`], optional
         The bands to build the FCC. The default is `['red', 'green', 'blue']`.
@@ -351,16 +353,11 @@ def plot_confusion_matrix(cm, labels, normalize=True, figsize=(10, 10),
     ----------
     cm : :py:class:`numpy.ndarray`
         The confusion matrix.
-    labels : `dict` [`int`, `dict`]
-        The label dictionary. The keys are the values of the class labels
-        in the ground truth ``y``. Each nested `dict` should have keys:
-            ``'color'``
-                A named color (`str`).
-            ``'label'``
-                The name of the class label (`str`).
+    labels : `list` [`str`]
+        Names of the classes.
     normalize : `bool`, optional
         Whether to normalize the confusion matrix. The default is `True`.
-    figsize : `tuple`, optional
+    figsize : `tuple` [`int`], optional
         The figure size in centimeters. The default is `(10, 10)`.
     cmap : `str`, optional
         A matplotlib colormap. The default is `'Blues'`.
@@ -384,7 +381,6 @@ def plot_confusion_matrix(cm, labels, normalize=True, figsize=(10, 10),
 
     """
     # number of classes
-    labels = [label['label'] for label in labels.values()]
     nclasses = len(labels)
 
     # string format to plot values of confusion matrix
@@ -461,7 +457,7 @@ def plot_loss(state_file, figsize=(10, 10), step=5,
     state_file : `str` or :py:class:`pathlib.Path`
         The model state file. Model state files are stored in
         `pysegcnn/main/_models`.
-    figsize : `tuple`, optional
+    figsize : `tuple` [`int`], optional
         The figure size in centimeters. The default is `(10, 10)`.
     step : `int`, optional
         The step to label epochs on the x-axis labels. The default is `5`, i.e.
@@ -581,8 +577,8 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5):
 
     Returns
     -------
-    cls_df : :py:class:`pandas.DataFrame`
-        The class distribution DataFrame.
+    fig : :py:class:`matplotlib.figure.Figure`
+        An instance of :py:class:`matplotlib.figure.Figure`.
 
     """
     # compute class distribution
@@ -683,6 +679,66 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5):
     return fig
 
 
+def plot_classification_report(report, labels, figsize=(10, 10), **kwargs):
+    """Plot the :py:func:`sklearn.metrics.classification_report` as heatmap.
+
+    Parameters
+    ----------
+    report : `dict`
+        The dictionary returned by setting ``output_dict=True`` in
+        :py:func:`sklearn.metrics.classification_report`.
+    labels : `list` [`str`]
+        Names of the classes.
+    figsize : `tuple` [`int`], optional
+        The figure size in centimeters. The default is `(10, 10)`.
+    **kwargs :
+        Additional keyword arguments passed to :py:func:`seaborn.heatmap`.
+
+    Returns
+    -------
+    fig : :py:class:`matplotlib.figure.Figure`
+        An instance of :py:class:`matplotlib.figure.Figure`.
+
+    """
+    # overall accuracy
+    overall_accuracy = report['accuracy']
+
+    # convert classification report to pandas DataFrame
+    report_df = pd.DataFrame(report)
+
+    # create a DataFrame only consisting of the class-wise statistics
+    class_statistics = report_df[labels].transpose()
+
+    # create a DataFrame only consisting of the average metrics
+    avg_metrics = report_df.drop(columns=labels + ['accuracy']).transpose()
+    # the support of the average metrics covers the entire dataset, i.e. 1
+    avg_metrics.support = 1
+
+    # convert support values to relative values
+    class_statistics.support = (class_statistics.support /
+                                class_statistics.support.sum())
+
+    # merge dataframes
+    metrics = pd.concat([class_statistics, avg_metrics])
+
+    # create a figure
+    fig, ax = plt.subplots(1, 1, figsize=figsize)
+
+    # plot class-wise and average statistics as a heatmap
+    sns.heatmap(metrics, vmin=0, vmax=1, annot=True, fmt='.2f',
+                ax=ax, xticklabels=[c.capitalize() for c in metrics.columns],
+                yticklabels=[r.capitalize() for r in metrics.index], **kwargs)
+
+    # add a white line separating class-wise and average statistics
+    ax.plot(np.arange(0, len(metrics.columns) + 1),
+            np.tile(len(labels), len(metrics.columns) + 1),
+            color='white', lw=3)
+
+    # set figure title
+    ax.set_title('Overall accuracy: {:.2f}'.format(overall_accuracy), pad=20)
+
+    return fig
+
+
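
For illustration, a minimal usage sketch of the new plot_classification_report
function. The class names, random arrays and output file name below are
hypothetical; the report dictionary is the one returned by
sklearn.metrics.classification_report with output_dict=True.

    import numpy as np
    from sklearn.metrics import classification_report

    from pysegcnn.core.graphics import plot_classification_report

    # hypothetical ground truth and predictions for three classes
    y_true = np.random.randint(0, 3, size=1000)
    y_pred = np.random.randint(0, 3, size=1000)
    class_names = ['Water', 'Snow', 'Cloud']  # hypothetical class names

    # classification report as dictionary, as expected by the plot function
    report = classification_report(
        y_true, y_pred, target_names=class_names, output_dict=True)

    # plot the report as a heatmap and save the figure (hypothetical path)
    fig = plot_classification_report(report, class_names)
    fig.savefig('classification_report.png', dpi=300, bbox_inches='tight')
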
 class Animate(object):
     """Easily create animations with :py:mod:`matplotlib`.
 
diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index a2e6170..a109d7a 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -42,7 +42,7 @@ from pysegcnn.core.dataset import SupportedDatasets
 from pysegcnn.core.transforms import Augment
 from pysegcnn.core.utils import (item_in_enum, accuracy_function,
                                  reconstruct_scene, check_filename_length,
-                                 array_replace, report2latex)
+                                 array_replace)
 from pysegcnn.core.split import SupportedSplits
 from pysegcnn.core.models import (SupportedModels, SupportedOptimizers,
                                   Network)
@@ -50,7 +50,7 @@ from pysegcnn.core.uda import SupportedUdaMethods, CoralLoss, UDA_POSITIONS
 from pysegcnn.core.layers import Conv2dSame
 from pysegcnn.core.logging import log_conf
 from pysegcnn.core.graphics import (plot_loss, plot_confusion_matrix,
-                                    plot_sample)
+                                    plot_sample, plot_classification_report)
 from pysegcnn.core.constants import map_labels
 from pysegcnn.main.train_config import HERE
 
@@ -2061,7 +2061,7 @@ class NetworkInference(BaseConfig):
         Path to store plots of model predictions for entire scenes.
     perfmc_path : :py:class:`pathlib.Path`
         Path to store plots of model performance, e.g. confusion matrix.
-    models_path : :py:class:`pathlib.Path`
+    report_path : :py:class:`pathlib.Path`
         Path to store the :py:func:`sklearn.metrics.classification_report`.
     models_path : :py:class:`pathlib.Path`
         Path to search for model state files ``state_files``.
@@ -2365,6 +2365,18 @@ class NetworkInference(BaseConfig):
         return (self.target_labels if self.apply_label_map else
                 self.source_labels)
 
+    @property
+    def class_names(self):
+        """Class label names to be predicted.
+
+        Returns
+        -------
+        labels : `list` [`str`]
+            Names of the classes.
+
+        """
+        return [label['label'] for label in self.use_labels.values()]
+
     @property
     def bands(self):
         """Spectral bands the model was trained with.
@@ -2570,7 +2582,7 @@ class NetworkInference(BaseConfig):
                 # plot inputs, ground truth and model predictions
                 _ = plot_sample(inputs.clip(0, 1),
                                 self.bands,
-                                self.source_labels,
+                                self.use_labels,
                                 y=labels,
                                 y_pred={model.__class__.__name__: prdctn},
                                 accuracy=True,
@@ -2585,7 +2597,7 @@ class NetworkInference(BaseConfig):
                                                     '_eval.pt'))
 
     def report_name(self, state_file):
-        return str(state_file.name).replace(state_file.suffix, '_cr.tex')
+        return str(state_file.name).replace(state_file.suffix, '_cr.png')
 
     def evaluate(self):
         """Evaluate the models on a defined dataset.
@@ -2608,6 +2620,9 @@ class NetworkInference(BaseConfig):
                 ``'cm'``
                     The confusion matrix of the model, which is only present if
                     ``self.cm=True`` (:py:class:`numpy.ndarray`).
+                ``'report'``
+                    The classification report `dict` as returned by
+                    :py:func:`sklearn.metrics.classification_report` with
+                    ``output_dict=True``.
 
         """
         # iterate over the models to evaluate
@@ -2627,76 +2642,81 @@ class NetworkInference(BaseConfig):
                 # load existing model evaluation
                 if not self.overwrite:
                     inference[state.stem] = torch.load(self.eval_file(state))
-                    continue
                 else:
                     # overwrite existing model evaluation
                     LOGGER.info('Overwriting model evaluation: {}.'
                                 .format(self.eval_file(state)))
                     self.eval_file(state).unlink()
 
-            # plot loss and accuracy
-            plot_loss(check_filename_length(state), outpath=self.perfmc_path)
-
-            # load the target dataset to evaluate the model on
-            self.trg_ds = self.load_dataset(
-                state, implicit=self.implicit, test=self.test,
-                domain=self.domain)
-
-            # load the source dataset the model was trained on
-            self.src_ds = self.load_dataset(state, test=None)
-
-            # load the pretrained model
-            model, _ = Network.load_pretrained_model(state)
-
-            # evaluate the model on the target dataset
-            output = self.predict(model)
-
-            # merge predictions of the different samples
-            y_true = np.asarray([v['y_true'].flatten() for _, v
-                                 in output.items()]).flatten()
-            y_pred = np.asarray([v['y_pred'].flatten() for _, v
-                                 in output.items()]).flatten()
-
-            # predictions and ground truth of the entire target dataset
-            output['y_true'] = y_true
-            output['y_pred'] = y_pred
-
-            # classification report labels
-            cr_labels = [v['label'] for _, v in self.source_labels.items()]
-
-            # calculate classification report from sklearn
-            report_name = self.report_path.joinpath(self.report_name(state))
-            LOGGER.info('Calculating classification report: {}'
-                        .format(report_name))
-
-            # export report to Latex table
-            report = classification_report(
-                y_true, y_pred, target_names=cr_labels, output_dict=True)
-            report2latex(report, filename=report_name)
-
-            # check whether to calculate confusion matrix
-            if self.cm:
+            else:
 
-                # calculate confusion matrix
-                LOGGER.info('Computing confusion matrix ...')
-                conf_mat = confusion_matrix(y_true, y_pred)
+                # plot loss and accuracy
+                plot_loss(check_filename_length(state),
+                          outpath=self.perfmc_path)
 
-                # add confusion matrix to model output
-                output['cm'] = conf_mat
+                # load the target dataset to evaluate the model on
+                self.trg_ds = self.load_dataset(
+                    state, implicit=self.implicit, test=self.test,
+                    domain=self.domain)
 
-                # plot confusion matrix
-                plot_confusion_matrix(
-                    conf_mat, self.source_labels, state_file=state,
-                    subset=self.trg_ds.name,
-                    outpath=self.perfmc_path)
+                # load the source dataset the model was trained on
+                self.src_ds = self.load_dataset(state, test=None)
 
-            # save model predictions to file
-            LOGGER.info('Saving model evaluation: {}'
-                        .format(self.eval_file(state)))
-            torch.save(output, self.eval_file(state))
+                # load the pretrained model
+                model, _ = Network.load_pretrained_model(state)
+
+                # evaluate the model on the target dataset
+                output = self.predict(model)
+
+                # merge predictions of the different samples
+                y_true = np.asarray([v['y_true'].flatten() for _, v
+                                     in output.items()]).flatten()
+                y_pred = np.asarray([v['y_pred'].flatten() for _, v
+                                     in output.items()]).flatten()
+
+                # predictions and ground truth of the entire target dataset
+                output['y_true'] = y_true
+                output['y_pred'] = y_pred
+
+                # calculate classification report from sklearn
+                report_name = self.report_path.joinpath(
+                    self.report_name(state))
+                LOGGER.info('Calculating classification report: {}'
+                            .format(report_name))
+                report = classification_report(
+                    y_true, y_pred, target_names=self.class_names,
+                    output_dict=True)
+
+                # store report in output dictionary
+                output['report'] = report
+
+                # plot classification report
+                fig = plot_classification_report(report, self.class_names)
+                fig.savefig(report_name, dpi=300, bbox_inches='tight')
+
+                # check whether to calculate confusion matrix
+                if self.cm:
+
+                    # calculate confusion matrix
+                    LOGGER.info('Computing confusion matrix ...')
+                    conf_mat = confusion_matrix(y_true, y_pred)
+
+                    # add confusion matrix to model output
+                    output['cm'] = conf_mat
+
+                    # plot confusion matrix
+                    plot_confusion_matrix(
+                        conf_mat, self.class_names, state_file=state,
+                        subset=self.trg_ds.name,
+                        outpath=self.perfmc_path)
+
+                # save model predictions to file
+                LOGGER.info('Saving model evaluation: {}'
+                            .format(self.eval_file(state)))
+                torch.save(output, self.eval_file(state))
 
-            # save model predictions to list
-            inference[state.stem] = output
+                # add model predictions to the inference dictionary
+                inference[state.stem] = output
 
         # check whether to aggregate the results of the different model runs
         if self.aggregate:
@@ -2722,11 +2742,16 @@ class NetworkInference(BaseConfig):
             report_name = self.report_path.joinpath(self.report_name(kfold))
             LOGGER.info('Calculating classification report: {}'
                         .format(report_name))
+            report = classification_report(y_true, y_pred,
+                                           target_names=self.class_names,
+                                           output_dict=True)
+
+            # save aggregated classification report
+            inference['report'] = report
 
-            # export aggregated report to Latex table
-            report = classification_report(
-                y_true, y_pred, target_names=cr_labels, output_dict=True)
-            report2latex(report, filename=report_name)
+            # plot classification report
+            fig = plot_classification_report(report, self.class_names)
+            fig.savefig(report_name, dpi=300, bbox_inches='tight')
 
             # check whether to compute the aggregated confusion matrix
             if self.cm:
@@ -2739,11 +2764,11 @@ class NetworkInference(BaseConfig):
                     cm_agg += output['cm']
 
                 # save aggregated confusion matrix to dictionary
-                inference['cm_agg'] = cm_agg
+                inference['cm'] = cm_agg
 
                 # plot aggregated confusion matrix and save to file
                 plot_confusion_matrix(
-                    cm_agg, self.source_labels, state_file=kfold,
+                    cm_agg, self.class_names, state_file=kfold,
                     outpath=self.perfmc_path)
 
         return inference
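
A short sketch of how a serialized model evaluation written by evaluate()
could be inspected afterwards. The file name is hypothetical; the available
keys follow the docstring above ('y_true', 'y_pred', 'report' and, if the
confusion matrix is enabled, 'cm').

    import torch

    # hypothetical path to a saved model evaluation (*_eval.pt)
    output = torch.load('model_state_eval.pt')

    # ground truth and predictions of the entire target dataset
    y_true, y_pred = output['y_true'], output['y_pred']

    # classification report dictionary, e.g. the macro-averaged F1 score
    macro_f1 = output['report']['macro avg']['f1-score']

    # confusion matrix, only present if the evaluation was run with cm=True
    conf_mat = output.get('cm')
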
-- 
GitLab