Skip to content
Snippets Groups Projects
Commit e6c2ec84 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Improved plotting for loss.

parent fae2983d
No related branches found
No related tags found
No related merge requests found
......@@ -16,6 +16,7 @@ import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import seaborn as sns
# locals
from pysegcnn.core.utils import search_files
......@@ -113,7 +114,7 @@ def split_date_range(start_date, end_date, **kwargs):
return dates
def plot_loss(state_file, figsize=(10, 10), step=5):
def plot_loss(state_file, figsize=(10, 10), step=5, palette='mako'):
"""Plot the observed loss and accuracy of a model run.
Parameters
......@@ -125,6 +126,8 @@ def plot_loss(state_file, figsize=(10, 10), step=5):
step : `int`, optional
The step to label epochs on the x-axis labels. The default is `5`, i.e.
label each fifth epoch.
palette : `str`, optional
Color palette supported by seaborn. The default is `mako`.
Returns
-------
......@@ -162,11 +165,14 @@ def plot_loss(state_file, figsize=(10, 10), step=5):
axes = [ax, ax2]
# plot training and validation loss
for (k, v), c, marker, ax in zip(rm.items(), ['-', '--'], markers, axes):
ax.plot(v, 'o', ls=c, color='black', markevery=marker)
colors = sns.color_palette(palette, n_colors=2)
for (k, v), ls, c, marker, ax in zip(rm.items(), ['-', '--'], colors,
markers, axes):
ax.plot(v, 'o', ls=ls, color=c, markevery=marker)
# x axis limits
axes[0].set_xticks(np.arange(0, ntbatches * epochs[-1], ntbatches * step))
axes[0].set_xticks(np.arange(0, ntbatches * (epochs[-1] + 1),
ntbatches * step))
axes[0].set_xticklabels(epochs[::step])
axes[0].set_xlabel('Epoch', fontsize=14)
axes[0].set_ylabel('Loss', fontsize=14)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment