Skip to content
Snippets Groups Projects
Commit da1ba8b9 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Changed default colormap.

parent 69dc2936
No related branches found
No related tags found
No related merge requests found
......@@ -345,7 +345,7 @@ def plot_sample(x, use_bands, labels,
def plot_confusion_matrix(cm, labels, normalize=True, figsize=(10, 10),
cmap='viridis', state_file=None, subset=None,
cmap='YlGnBu', state_file=None, subset=None,
outpath=os.path.join(HERE, '_graphics/')):
"""Plot the confusion matrix ``cm``.
......@@ -360,7 +360,7 @@ def plot_confusion_matrix(cm, labels, normalize=True, figsize=(10, 10),
figsize : `tuple` [`int`], optional
The figure size in centimeters. The default is `(10, 10)`.
cmap : `str`, optional
A matplotlib colormap. The default is `'viridis'`.
A matplotlib colormap. The default is `'YlGnBu'`.
state_file : `str` or `None` or :py:class:`pathlib.Path`, optional
Filename to save the plot to. ``state`` should be an existing model
state file ending with `'.pt'`. The default is `None`, i.e. the plot is
......@@ -684,7 +684,7 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5):
def plot_classification_report(report, labels, figsize=(10, 10),
cmap='viridis', **kwargs):
cmap='YlGnBu', **kwargs):
"""Plot the :py:func:`sklearn.metrics.classification_report` as heatmap.
Parameters
......@@ -699,7 +699,7 @@ def plot_classification_report(report, labels, figsize=(10, 10),
figsize : `tuple` [`int`], optional
The figure size in centimeters. The default is `(10, 10)`.
cmap : `str`, optional
A matplotlib colormap. The default is `'viridis'`.
A matplotlib colormap. The default is `'YlGnBu'`.
**kwargs :
Additional keyword arguments passed to :py:func:`seaborn.heatmap`.
......@@ -725,7 +725,8 @@ def plot_classification_report(report, labels, figsize=(10, 10),
# plot class wise statistics as heatmap
sns.heatmap(metrics, vmin=0, vmax=1, annot=True, fmt='.2f',
ax=ax, xticklabels=[c.capitalize() for c in metrics.columns],
yticklabels=[r.capitalize() for r in metrics.index], **kwargs)
yticklabels=[r.capitalize() for r in metrics.index], cmap=cmap,
**kwargs)
# add a white line separating class-wise and average statistics
ax.plot(np.arange(0, len(metrics.columns) + 1),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment