diff --git a/climax/core/utils.py b/climax/core/utils.py
index cf259202f297d02f3f198077b026bafd0da4bfe9..2051779fff474d9414adcd6a57902f392a803175 100644
--- a/climax/core/utils.py
+++ b/climax/core/utils.py
@@ -199,8 +199,8 @@ def plot_loss(state_file, figsize=(10, 10), step=5, palette='mako'):
 
     # create a patch (proxy artist) for every color
     ulabels = ['Training', 'Validation']
-    patches = [mlines.Line2D([], [], color='black', ls=c, label=l) for c, l in
-               zip(['-', '--'], ulabels)]
+    patches = [mlines.Line2D([], [], color=c, ls=ls, label=l) for c, ls, l in
+               zip(colors, ['-', '--'], ulabels)]
 
     # plot patches as legend
     ax.legend(handles=patches, loc='upper left', frameon=False, ncol=2,