diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py index 39321e54d0564cfdb8259a2bd0d3d028e6838c2f..0130baca096206b13e57756a6bcfc76c0a6303d5 100644 --- a/pysegcnn/core/graphics.py +++ b/pysegcnn/core/graphics.py @@ -345,7 +345,7 @@ def plot_sample(x, use_bands, labels, def plot_confusion_matrix(cm, labels, normalize=True, figsize=(10, 10), - cmap='Blues', state_file=None, subset=None, + cmap='viridis', state_file=None, subset=None, outpath=os.path.join(HERE, '_graphics/')): """Plot the confusion matrix ``cm``. @@ -360,7 +360,7 @@ def plot_confusion_matrix(cm, labels, normalize=True, figsize=(10, 10), figsize : `tuple` [`int`], optional The figure size in centimeters. The default is `(10, 10)`. cmap : `str`, optional - A matplotlib colormap. The default is `'Blues'`. + A matplotlib colormap. The default is `'viridis'`. state_file : `str` or `None` or :py:class:`pathlib.Path`, optional Filename to save the plot to. ``state`` should be an existing model state file ending with `'.pt'`. The default is `None`, i.e. the plot is @@ -683,7 +683,8 @@ def plot_class_distribution(ds, figsize=(16, 9), alpha=0.5): return fig -def plot_classification_report(report, labels, figsize=(10, 10), **kwargs): +def plot_classification_report(report, labels, figsize=(10, 10), + cmap='viridis', **kwargs): """Plot the :py:func:`sklearn.metrics.classification_report` as heatmap. Parameters @@ -697,6 +698,8 @@ def plot_classification_report(report, labels, figsize=(10, 10), **kwargs): Names of the classes. figsize : `tuple` [`int`], optional The figure size in centimeters. The default is `(10, 10)`. + cmap : `str`, optional + A matplotlib colormap. The default is `'viridis'`. **kwargs : Additional keyword arguments passed to :py:func:`seaborn.heatmap`. @@ -732,6 +735,14 @@ def plot_classification_report(report, labels, figsize=(10, 10), **kwargs): # set figure title ax.set_title('Accuracy: {:.2f}'.format(overall_accuracy), pad=20) + # rotate x-tick labels + for label in ax.get_xticklabels(): + label.set_rotation(90) + + # rotate y-tick labels + for label in ax.get_yticklabels(): + label.set_rotation(0) + return fig