diff --git a/climax/core/utils.py b/climax/core/utils.py index cc99955f6ab51cd570d1c342aab53a9b08332217..cf259202f297d02f3f198077b026bafd0da4bfe9 100644 --- a/climax/core/utils.py +++ b/climax/core/utils.py @@ -16,6 +16,7 @@ import numpy as np import pandas as pd import matplotlib.pyplot as plt import matplotlib.lines as mlines +import seaborn as sns # locals from pysegcnn.core.utils import search_files @@ -113,7 +114,7 @@ def split_date_range(start_date, end_date, **kwargs): return dates -def plot_loss(state_file, figsize=(10, 10), step=5): +def plot_loss(state_file, figsize=(10, 10), step=5, palette='mako'): """Plot the observed loss and accuracy of a model run. Parameters @@ -125,6 +126,8 @@ def plot_loss(state_file, figsize=(10, 10), step=5): step : `int`, optional The step to label epochs on the x-axis labels. The default is `5`, i.e. label each fifth epoch. + palette : `str`, optional + Color palette supported by seaborn. The default is `mako`. Returns ------- @@ -162,11 +165,14 @@ def plot_loss(state_file, figsize=(10, 10), step=5): axes = [ax, ax2] # plot training and validation loss - for (k, v), c, marker, ax in zip(rm.items(), ['-', '--'], markers, axes): - ax.plot(v, 'o', ls=c, color='black', markevery=marker) + colors = sns.color_palette(palette, n_colors=2) + for (k, v), ls, c, marker, ax in zip(rm.items(), ['-', '--'], colors, + markers, axes): + ax.plot(v, 'o', ls=ls, color=c, markevery=marker) # x axis limits - axes[0].set_xticks(np.arange(0, ntbatches * epochs[-1], ntbatches * step)) + axes[0].set_xticks(np.arange(0, ntbatches * (epochs[-1] + 1), + ntbatches * step)) axes[0].set_xticklabels(epochs[::step]) axes[0].set_xlabel('Epoch', fontsize=14) axes[0].set_ylabel('Loss', fontsize=14)