From e6c2ec841ab4b6c550b43af2e3504beb46a3b61f Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Thu, 21 Oct 2021 12:48:26 +0200 Subject: [PATCH] Improved plotting for loss. --- climax/core/utils.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/climax/core/utils.py b/climax/core/utils.py index cc99955..cf25920 100644 --- a/climax/core/utils.py +++ b/climax/core/utils.py @@ -16,6 +16,7 @@ import numpy as np import pandas as pd import matplotlib.pyplot as plt import matplotlib.lines as mlines +import seaborn as sns # locals from pysegcnn.core.utils import search_files @@ -113,7 +114,7 @@ def split_date_range(start_date, end_date, **kwargs): return dates -def plot_loss(state_file, figsize=(10, 10), step=5): +def plot_loss(state_file, figsize=(10, 10), step=5, palette='mako'): """Plot the observed loss and accuracy of a model run. Parameters @@ -125,6 +126,8 @@ def plot_loss(state_file, figsize=(10, 10), step=5): step : `int`, optional The step to label epochs on the x-axis labels. The default is `5`, i.e. label each fifth epoch. + palette : `str`, optional + Color palette supported by seaborn. The default is `mako`. Returns ------- @@ -162,11 +165,14 @@ def plot_loss(state_file, figsize=(10, 10), step=5): axes = [ax, ax2] # plot training and validation loss - for (k, v), c, marker, ax in zip(rm.items(), ['-', '--'], markers, axes): - ax.plot(v, 'o', ls=c, color='black', markevery=marker) + colors = sns.color_palette(palette, n_colors=2) + for (k, v), ls, c, marker, ax in zip(rm.items(), ['-', '--'], colors, + markers, axes): + ax.plot(v, 'o', ls=ls, color=c, markevery=marker) # x axis limits - axes[0].set_xticks(np.arange(0, ntbatches * epochs[-1], ntbatches * step)) + axes[0].set_xticks(np.arange(0, ntbatches * (epochs[-1] + 1), + ntbatches * step)) axes[0].set_xticklabels(epochs[::step]) axes[0].set_xlabel('Epoch', fontsize=14) axes[0].set_ylabel('Loss', fontsize=14) -- GitLab