From 98e67409f2bec1430a7abdadf72f472ae31038f9 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 21 Oct 2021 12:51:40 +0200
Subject: [PATCH] Improved plotting for loss.

---
 climax/core/utils.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/climax/core/utils.py b/climax/core/utils.py
index 2051779..b619e2c 100644
--- a/climax/core/utils.py
+++ b/climax/core/utils.py
@@ -166,9 +166,8 @@ def plot_loss(state_file, figsize=(10, 10), step=5, palette='mako'):
 
     # plot training and validation loss
     colors = sns.color_palette(palette, n_colors=2)
-    for (k, v), ls, c, marker, ax in zip(rm.items(), ['-', '--'], colors,
-                                         markers, axes):
-        ax.plot(v, 'o', ls=ls, color=c, markevery=marker)
+    for (k, v), c, marker, ax in zip(rm.items(), colors, markers, axes):
+        ax.plot(v, 'o', color=c, markevery=marker)
 
     # x axis limits
     axes[0].set_xticks(np.arange(0, ntbatches * (epochs[-1] + 1),
@@ -199,8 +198,8 @@ def plot_loss(state_file, figsize=(10, 10), step=5, palette='mako'):
 
     # create a patch (proxy artist) for every color
     ulabels = ['Training', 'Validation']
-    patches = [mlines.Line2D([], [], color=c, ls=ls, label=l) for c, ls, l in
-               zip(colors, ['-', '--'], ulabels)]
+    patches = [mlines.Line2D([], [], color=c, label=l) for c, l in
+               zip(colors, ulabels)]
 
     # plot patches as legend
     ax.legend(handles=patches, loc='upper left', frameon=False, ncol=2,
-- 
GitLab