From da1ba8b9a59564ac206a25cde82891d5b1937eaa Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Tue, 16 Feb 2021 16:31:11 +0100 Subject: [PATCH] Changed default colormap. --- pysegcnn/core/graphics.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py index 0130bac..d8b4066 100644 --- a/pysegcnn/core/graphics.py +++ b/pysegcnn/core/graphics.py @@ -345,7 +345,7 @@ def plot_sample(x, use_bands, labels, def plot_confusion_matrix(cm, labels, normalize=True, figsize=(10, 10), - cmap='viridis', state_file=None, subset=None, + cmap='YlGnBu', state_file=None, subset=None, outpath=os.path.join(HERE, '_graphics/')): """Plot the confusion matrix ``cm``. @@ -360,7 +360,7 @@ def plot_confusion_matrix(cm, labels, normalize=True, figsize=(10, 10), figsize : `tuple` [`int`], optional The figure size in centimeters. The default is `(10, 10)`. cmap : `str`, optional - A matplotlib colormap. The default is `'viridis'`. + A matplotlib colormap. The default is `'YlGnBu'`. state_file : `str` or `None` or :py:class:`pathlib.Path`, optional Filename to save the plot to. ``state`` should be an existing model state file ending with `'.pt'`. The default is `None`, i.e. the plot is @@ -684,7 +684,7 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5): def plot_classification_report(report, labels, figsize=(10, 10), - cmap='viridis', **kwargs): + cmap='YlGnBu', **kwargs): """Plot the :py:func:`sklearn.metrics.classification_report` as heatmap. Parameters @@ -699,7 +699,7 @@ def plot_classification_report(report, labels, figsize=(10, 10), figsize : `tuple` [`int`], optional The figure size in centimeters. The default is `(10, 10)`. cmap : `str`, optional - A matplotlib colormap. The default is `'viridis'`. + A matplotlib colormap. The default is `'YlGnBu'`. **kwargs : Additional keyword arguments passed to :py:func:`seaborn.heatmap`. @@ -725,7 +725,8 @@ def plot_classification_report(report, labels, figsize=(10, 10), # plot class wise statistics as heatmap sns.heatmap(metrics, vmin=0, vmax=1, annot=True, fmt='.2f', ax=ax, xticklabels=[c.capitalize() for c in metrics.columns], - yticklabels=[r.capitalize() for r in metrics.index], **kwargs) + yticklabels=[r.capitalize() for r in metrics.index], cmap=cmap, + **kwargs) # add a white line separating class-wise and average statistics ax.plot(np.arange(0, len(metrics.columns) + 1), -- GitLab