diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py
index b8f2e8a7ab313c540f21578ba100ac2549055733..675de171f8572c3e485faa625691f9f4535c6b16 100644
--- a/pysegcnn/core/graphics.py
+++ b/pysegcnn/core/graphics.py
@@ -22,11 +22,13 @@ import logging

 # externals
 import numpy as np
+import pandas as pd
 import torch
 import matplotlib
 import matplotlib.pyplot as plt
 import matplotlib.patches as mpatches
 import matplotlib.lines as mlines
+import seaborn as sns
 from matplotlib.colors import ListedColormap, BoundaryNorm
 from matplotlib.animation import ArtistAnimation
 from matplotlib import cm as colormap
@@ -193,7 +195,7 @@ def plot_sample(x, use_bands, labels,
     accuracy : `bool`, optional
         Whether to calculate the accuracy of the predictions ``y_pred`` with
         respect to the ground truth ``y``. The default is `False`
-    figsize : `tuple`, optional
+    figsize : `tuple` [`int`], optional
         The figure size in centimeters. The default is `(16, 9)`.
     bands : `list` [`str`], optional
         The bands to build the FCC. The default is `['red', 'green', 'blue']`.
@@ -351,16 +353,11 @@ def plot_confusion_matrix(cm, labels, normalize=True, figsize=(10, 10),
     ----------
     cm : :py:class:`numpy.ndarray`
         The confusion matrix.
-    labels : `dict` [`int`, `dict`]
-        The label dictionary. The keys are the values of the class labels
-        in the ground truth ``y``. Each nested `dict` should have keys:
-            ``'color'``
-                A named color (`str`).
-            ``'label'``
-                The name of the class label (`str`).
+    labels : `list` [`str`]
+        Names of the classes.
     normalize : `bool`, optional
         Whether to normalize the confusion matrix. The default is `True`.
-    figsize : `tuple`, optional
+    figsize : `tuple` [`int`], optional
         The figure size in centimeters. The default is `(10, 10)`.
     cmap : `str`, optional
         A matplotlib colormap. The default is `'Blues'`.
@@ -384,7 +381,6 @@ def plot_confusion_matrix(cm, labels, normalize=True, figsize=(10, 10),

     """
     # number of classes
-    labels = [label['label'] for label in labels.values()]
     nclasses = len(labels)

     # string format to plot values of confusion matrix
@@ -461,7 +457,7 @@ def plot_loss(state_file, figsize=(10, 10), step=5,
     state_file : `str` or :py:class:`pathlib.Path`
         The model state file. Model state files are stored in
         `pysegcnn/main/_models`.
-    figsize : `tuple`, optional
+    figsize : `tuple` [`int`], optional
         The figure size in centimeters. The default is `(10, 10)`.
     step : `int`, optional
         The step to label epochs on the x-axis labels. The default is `5`, i.e.
@@ -581,8 +577,8 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5):

     Returns
     -------
-    cls_df : :py:class:`pandas.DataFrame`
-        The class distribution DataFrame.
+    fig : :py:class:`matplotlib.figure.Figure`
+        An instance of :py:class:`matplotlib.figure.Figure`.

     """
     # compute class distribution
@@ -683,6 +679,66 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5):
     return fig


+def plot_classification_report(report, labels, figsize=(10, 10), **kwargs):
+    """Plot the :py:func:`sklearn.metrics.classification_report` as heatmap.
+
+    Parameters
+    ----------
+    report : `dict`
+        The dictionary returned by setting ``output_dict=True`` in
+        :py:func:`sklearn.metrics.classification_report`.
+    labels : `list` [`str`]
+        Names of the classes.
+    figsize : `tuple` [`int`], optional
+        The figure size in centimeters. The default is `(10, 10)`.
+    **kwargs :
+        Additional keyword arguments passed to :py:func:`seaborn.heatmap`.
+
+    Returns
+    -------
+    fig : :py:class:`matplotlib.figure.Figure`
+        An instance of :py:class:`matplotlib.figure.Figure`.
+
+    """
+    # overall accuracy
+    overall_accuracy = report['accuracy']
+
+    # convert classification report to pandas DataFrame
+    report_df = pd.DataFrame(report)
+
+    # create a DataFrame only consisting of the class-wise statistics
+    class_statistics = report_df[labels].transpose()
+
+    # create a DataFrame only consisting of the average metrics
+    avg_metrics = report_df.drop(columns=labels + ['accuracy']).transpose()
+    avg_metrics.support = 1
+
+    # convert support values to relative values
+    class_statistics.support = (class_statistics.support /
+                                class_statistics.support.sum())
+
+    # merge dataframes
+    metrics = class_statistics.append(avg_metrics)
+
+    # create a figure
+    fig, ax = plt.subplots(1, 1, figsize=figsize)
+
+    # plot class wise statistics as heatmap
+    sns.heatmap(metrics, vmin=0, vmax=1, annot=True, fmt='.2f',
+                ax=ax, xticklabels=[c.capitalize() for c in metrics.columns],
+                yticklabels=[r.capitalize() for r in metrics.index], **kwargs)
+
+    # add a white line separating class-wise and average statistics
+    ax.plot(np.arange(0, len(metrics.columns) + 1),
+            np.tile(len(labels), len(metrics.columns) + 1),
+            color='white', lw=3)
+
+    # set figure title
+    ax.set_title('Overall accuracy: {:.2f}'.format(overall_accuracy), pad=20)
+
+    return fig
+
+
 class Animate(object):
     """Easily create animations with :py:mod:`matplotlib`.

diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index a2e6170a7f11fadfc99f0fd314412505eeca8de6..a109d7a5999c86494f9e75cf407d032e0eadc7ab 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -42,7 +42,7 @@ from pysegcnn.core.dataset import SupportedDatasets
 from pysegcnn.core.transforms import Augment
 from pysegcnn.core.utils import (item_in_enum, accuracy_function,
                                  reconstruct_scene, check_filename_length,
-                                 array_replace, report2latex)
+                                 array_replace)
 from pysegcnn.core.split import SupportedSplits
 from pysegcnn.core.models import (SupportedModels, SupportedOptimizers,
                                   Network)
@@ -50,7 +50,7 @@ from pysegcnn.core.uda import SupportedUdaMethods, CoralLoss, UDA_POSITIONS
 from pysegcnn.core.layers import Conv2dSame
 from pysegcnn.core.logging import log_conf
 from pysegcnn.core.graphics import (plot_loss, plot_confusion_matrix,
-                                    plot_sample)
+                                    plot_sample, plot_classification_report)
 from pysegcnn.core.constants import map_labels
 from pysegcnn.main.train_config import HERE

@@ -2061,7 +2061,7 @@ class NetworkInference(BaseConfig):
         Path to store plots of model predictions for entire scenes.
     perfmc_path : :py:class:`pathlib.Path`
         Path to store plots of model performance, e.g. confusion matrix.
-    models_path : :py:class:`pathlib.Path`
+    report_path : :py:class:`pathlib.Path`
         Path to store the :py:func:`sklearn.metrics.classification_reports`.
     models_path : :py:class:`pathlib.Path`
         Path to search for model state files ``state_files``.
@@ -2365,6 +2365,18 @@ class NetworkInference(BaseConfig):
         return (self.target_labels if self.apply_label_map else
                 self.source_labels)

+    @property
+    def class_names(self):
+        """Class label names to be predicted.
+
+        Returns
+        -------
+        labels : `list` [`str`]
+            Names of the classes.
+
+        """
+        return [label['label'] for label in self.use_labels.values()]
+
     @property
     def bands(self):
         """Spectral bands the model was trained with.
@@ -2570,7 +2582,7 @@ class NetworkInference(BaseConfig):

             # plot inputs, ground truth and model predictions
             _ = plot_sample(inputs.clip(0, 1), self.bands,
-                            self.source_labels,
+                            self.use_labels,
                             y=labels,
                             y_pred={model.__class__.__name__: prdctn},
                             accuracy=True,
@@ -2585,7 +2597,7 @@ class NetworkInference(BaseConfig):
                                                 '_eval.pt'))

     def report_name(self, state_file):
-        return str(state_file.name).replace(state_file.suffix, '_cr.tex')
+        return str(state_file.name).replace(state_file.suffix, '_cr.png')

     def evaluate(self):
         """Evaluate the models on a defined dataset.
@@ -2608,6 +2620,9 @@ class NetworkInference(BaseConfig):
             ``'cm'``
                 The confusion matrix of the model, which is only present if
                 ``self.cm=True`` (:py:class:`numpy.ndarray`).
+            ``report``
+                The classification report dictionary as returned by
+                :py:func:`sklearn.metrics.classification_report`.

         """
         # iterate over the models to evaluate
@@ -2627,76 +2642,81 @@ class NetworkInference(BaseConfig):
                 # load existing model evaluation
                 if not self.overwrite:
                     inference[state.stem] = torch.load(self.eval_file(state))
-                    continue
                 else:
                     # overwrite existing model evaluation
                     LOGGER.info('Overwriting model evaluation: {}.'
                                 .format(self.eval_file(state)))
                     self.eval_file(state).unlink()

-            # plot loss and accuracy
-            plot_loss(check_filename_length(state), outpath=self.perfmc_path)
-
-            # load the target dataset to evaluate the model on
-            self.trg_ds = self.load_dataset(
-                state, implicit=self.implicit, test=self.test,
-                domain=self.domain)
-
-            # load the source dataset the model was trained on
-            self.src_ds = self.load_dataset(state, test=None)
-
-            # load the pretrained model
-            model, _ = Network.load_pretrained_model(state)
-
-            # evaluate the model on the target dataset
-            output = self.predict(model)
-
-            # merge predictions of the different samples
-            y_true = np.asarray([v['y_true'].flatten() for _, v
-                                 in output.items()]).flatten()
-            y_pred = np.asarray([v['y_pred'].flatten() for _, v
-                                 in output.items()]).flatten()
-
-            # predictions and ground truth of the entire target dataset
-            output['y_true'] = y_true
-            output['y_pred'] = y_pred
-
-            # classification report labels
-            cr_labels = [v['label'] for _, v in self.source_labels.items()]
-
-            # calculate classification report from sklearn
-            report_name = self.report_path.joinpath(self.report_name(state))
-            LOGGER.info('Calculating classification report: {}'
-                        .format(report_name))
-
-            # export report to Latex table
-            report = classification_report(
-                y_true, y_pred, target_names=cr_labels, output_dict=True)
-            report2latex(report, filename=report_name)
-
-            # check whether to calculate confusion matrix
-            if self.cm:
+            else:

-                # calculate confusion matrix
-                LOGGER.info('Computing confusion matrix ...')
-                conf_mat = confusion_matrix(y_true, y_pred)
+                # plot loss and accuracy
+                plot_loss(check_filename_length(state),
+                          outpath=self.perfmc_path)

-                # add confusion matrix to model output
-                output['cm'] = conf_mat
+                # load the target dataset to evaluate the model on
+                self.trg_ds = self.load_dataset(
+                    state, implicit=self.implicit, test=self.test,
+                    domain=self.domain)

-                # plot confusion matrix
-                plot_confusion_matrix(
-                    conf_mat, self.source_labels, state_file=state,
-                    subset=self.trg_ds.name,
-                    outpath=self.perfmc_path)
+                # load the source dataset the model was trained on
+                self.src_ds = self.load_dataset(state, test=None)

-            # save model predictions to file
-            LOGGER.info('Saving model evaluation: {}'
-                        .format(self.eval_file(state)))
-            torch.save(output, self.eval_file(state))
+                # load the pretrained model
+                model, _ = Network.load_pretrained_model(state)
+
+                # evaluate the model on the target dataset
+                output = self.predict(model)
+
+                # merge predictions of the different samples
+                y_true = np.asarray([v['y_true'].flatten() for _, v
+                                     in output.items()]).flatten()
+                y_pred = np.asarray([v['y_pred'].flatten() for _, v
+                                     in output.items()]).flatten()
+
+                # predictions and ground truth of the entire target dataset
+                output['y_true'] = y_true
+                output['y_pred'] = y_pred
+
+                # calculate classification report from sklearn
+                report_name = self.report_path.joinpath(
+                    self.report_name(state))
+                LOGGER.info('Calculating classification report: {}'
+                            .format(report_name))
+                report = classification_report(
+                    y_true, y_pred, target_names=self.class_names,
+                    output_dict=True)
+
+                # store report in output dictionary
+                output['report'] = report
+
+                # plot classification report
+                fig = plot_classification_report(report, self.class_names)
+                fig.savefig(report_name, dpi=300, bbox_inches='tight')
+
+                # check whether to calculate confusion matrix
+                if self.cm:
+
+                    # calculate confusion matrix
+                    LOGGER.info('Computing confusion matrix ...')
+                    conf_mat = confusion_matrix(y_true, y_pred)
+
+                    # add confusion matrix to model output
+                    output['cm'] = conf_mat
+
+                    # plot confusion matrix
+                    plot_confusion_matrix(
+                        conf_mat, self.class_names, state_file=state,
+                        subset=self.trg_ds.name,
+                        outpath=self.perfmc_path)
+
+                # save model predictions to file
+                LOGGER.info('Saving model evaluation: {}'
+                            .format(self.eval_file(state)))
+                torch.save(output, self.eval_file(state))

-            # save model predictions to list
-            inference[state.stem] = output
+                # save model predictions to list
+                inference[state.stem] = output

         # check whether to aggregate the results of the different model runs
         if self.aggregate:
@@ -2722,11 +2742,16 @@ class NetworkInference(BaseConfig):
             report_name = self.report_path.joinpath(self.report_name(kfold))
             LOGGER.info('Calculating classification report: {}'
                         .format(report_name))
+            report = classification_report(y_true, y_pred,
+                                           target_names=self.class_names,
+                                           output_dict=True)
+
+            # save aggregated classification report
+            inference['report'] = report

-            # export aggregated report to Latex table
-            report = classification_report(
-                y_true, y_pred, target_names=cr_labels, output_dict=True)
-            report2latex(report, filename=report_name)
+            # plot classification report
+            fig = plot_classification_report(report, self.class_names)
+            fig.savefig(report_name, dpi=300, bbox_inches='tight')

             # chech whether to compute the aggregated confusion matrix
             if self.cm:
@@ -2739,11 +2764,11 @@ class NetworkInference(BaseConfig):
                 cm_agg += output['cm']

             # save aggregated confusion matrix to dictionary
-            inference['cm_agg'] = cm_agg
+            inference['cm'] = cm_agg

             # plot aggregated confusion matrix and save to file
             plot_confusion_matrix(
-                cm_agg, self.source_labels, state_file=kfold,
+                cm_agg, self.class_names, state_file=kfold,
                 outpath=self.perfmc_path)

         return inference
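
A minimal usage sketch of the new plot_classification_report helper, mirroring the call sequence this patch adds to NetworkInference.evaluate. The class names and label arrays below are hypothetical placeholders, not part of the patch, and the sketch assumes a pandas version in which DataFrame.append is still available, since the patched function relies on it.

# hypothetical example: placeholder data, not part of the patch
import numpy as np
from sklearn.metrics import classification_report
from pysegcnn.core.graphics import plot_classification_report

# dummy ground truth and predictions for three classes
y_true = np.array([0, 1, 2, 2, 1, 0, 2, 1])
y_pred = np.array([0, 1, 1, 2, 1, 0, 2, 0])
class_names = ['water', 'forest', 'urban']

# classification report as dictionary, as expected by the plotting function
report = classification_report(y_true, y_pred, target_names=class_names,
                               output_dict=True)

# plot the report as a heatmap and save it to disk
fig = plot_classification_report(report, class_names)
fig.savefig('classification_report.png', dpi=300, bbox_inches='tight')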