From e6c2ec841ab4b6c550b43af2e3504beb46a3b61f Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 21 Oct 2021 12:48:26 +0200
Subject: [PATCH] Improved plotting for loss.

---
 climax/core/utils.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/climax/core/utils.py b/climax/core/utils.py
index cc99955..cf25920 100644
--- a/climax/core/utils.py
+++ b/climax/core/utils.py
@@ -16,6 +16,7 @@ import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
 import matplotlib.lines as mlines
+import seaborn as sns
 
 # locals
 from pysegcnn.core.utils import search_files
@@ -113,7 +114,7 @@ def split_date_range(start_date, end_date, **kwargs):
     return dates
 
 
-def plot_loss(state_file, figsize=(10, 10), step=5):
+def plot_loss(state_file, figsize=(10, 10), step=5, palette='mako'):
     """Plot the observed loss and accuracy of a model run.
 
     Parameters
@@ -125,6 +126,8 @@ def plot_loss(state_file, figsize=(10, 10), step=5):
     step : `int`, optional
         The step to label epochs on the x-axis labels. The default is `5`, i.e.
         label each fifth epoch.
+    palette : `str`, optional
+        Color palette supported by seaborn. The default is `mako`.
 
     Returns
     -------
@@ -162,11 +165,14 @@ def plot_loss(state_file, figsize=(10, 10), step=5):
     axes = [ax, ax2]
 
     # plot training and validation loss
-    for (k, v), c, marker, ax in zip(rm.items(), ['-', '--'], markers, axes):
-        ax.plot(v, 'o', ls=c, color='black', markevery=marker)
+    colors = sns.color_palette(palette, n_colors=2)
+    for (k, v), ls, c, marker, ax in zip(rm.items(), ['-', '--'], colors,
+                                         markers, axes):
+        ax.plot(v, 'o', ls=ls, color=c, markevery=marker)
 
     # x axis limits
-    axes[0].set_xticks(np.arange(0, ntbatches * epochs[-1], ntbatches * step))
+    axes[0].set_xticks(np.arange(0, ntbatches * (epochs[-1] + 1),
+                                 ntbatches * step))
     axes[0].set_xticklabels(epochs[::step])
     axes[0].set_xlabel('Epoch', fontsize=14)
     axes[0].set_ylabel('Loss', fontsize=14)
-- 
GitLab