diff --git a/climax/core/utils.py b/climax/core/utils.py index cf259202f297d02f3f198077b026bafd0da4bfe9..2051779fff474d9414adcd6a57902f392a803175 100644 --- a/climax/core/utils.py +++ b/climax/core/utils.py @@ -199,8 +199,8 @@ def plot_loss(state_file, figsize=(10, 10), step=5, palette='mako'): # create a patch (proxy artist) for every color ulabels = ['Training', 'Validation'] - patches = [mlines.Line2D([], [], color='black', ls=c, label=l) for c, l in - zip(['-', '--'], ulabels)] + patches = [mlines.Line2D([], [], color=c, ls=ls, label=l) for c, ls, l in + zip(colors, ['-', '--'], ulabels)] # plot patches as legend ax.legend(handles=patches, loc='upper left', frameon=False, ncol=2,