diff --git a/climax/core/utils.py b/climax/core/utils.py index 2051779fff474d9414adcd6a57902f392a803175..b619e2cb215c8134dd62b1547a0ada6d549fce4b 100644 --- a/climax/core/utils.py +++ b/climax/core/utils.py @@ -166,9 +166,8 @@ def plot_loss(state_file, figsize=(10, 10), step=5, palette='mako'): # plot training and validation loss colors = sns.color_palette(palette, n_colors=2) - for (k, v), ls, c, marker, ax in zip(rm.items(), ['-', '--'], colors, - markers, axes): - ax.plot(v, 'o', ls=ls, color=c, markevery=marker) + for (k, v), c, marker, ax in zip(rm.items(), colors, markers, axes): + ax.plot(v, 'o', color=c, markevery=marker) # x axis limits axes[0].set_xticks(np.arange(0, ntbatches * (epochs[-1] + 1), @@ -199,8 +198,8 @@ def plot_loss(state_file, figsize=(10, 10), step=5, palette='mako'): # create a patch (proxy artist) for every color ulabels = ['Training', 'Validation'] - patches = [mlines.Line2D([], [], color=c, ls=ls, label=l) for c, ls, l in - zip(colors, ['-', '--'], ulabels)] + patches = [mlines.Line2D([], [], color=c, label=l) for c, l in + zip(colors, ulabels)] # plot patches as legend ax.legend(handles=patches, loc='upper left', frameon=False, ncol=2,