Skip to content
Snippets Groups Projects
Commit 5d11d263 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Plot model state loss.

parent 2ad5d6f7
No related branches found
No related tags found
No related merge requests found
...@@ -10,12 +10,16 @@ import datetime ...@@ -10,12 +10,16 @@ import datetime
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
# externals # externals
import cdo
import torch
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import cdo import matplotlib.pyplot as plt
import matplotlib.lines as mlines
# locals # locals
from pysegcnn.core.utils import search_files from pysegcnn.core.utils import search_files
from pysegcnn.core.graphics import running_mean, ceil_decimal, floor_decimal
from climax.core.constants import CORDEX_PARAMETERS, CDO_RESAMPLING_MODES from climax.core.constants import CORDEX_PARAMETERS, CDO_RESAMPLING_MODES
# module level logger # module level logger
...@@ -107,3 +111,92 @@ def split_date_range(start_date, end_date, **kwargs): ...@@ -107,3 +111,92 @@ def split_date_range(start_date, end_date, **kwargs):
sdate = min(sdate + relativedelta(**kwargs), edate) sdate = min(sdate + relativedelta(**kwargs), edate)
return dates return dates
def plot_loss(state_file, figsize=(10, 10), step=5):
"""Plot the observed loss and accuracy of a model run.
Parameters
----------
state_file : `str` or :py:class:`pathlib.Path`
The model state file. Model state files are stored in
`pysegcnn/main/_models`.
figsize : `tuple` [`int`], optional
The figure size in centimeters. The default is `(10, 10)`.
step : `int`, optional
The step to label epochs on the x-axis labels. The default is `5`, i.e.
label each fifth epoch.
Returns
-------
fig : :py:class:`matplotlib.figure.Figure`
An instance of :py:class:`matplotlib.figure.Figure`.
"""
# load the model state
model_state = torch.load(state_file, map_location=torch.device('cpu'))
# get all non-zero elements, i.e. get number of epochs trained before
# early stop
loss = {k: v for k, v in model_state['state'].items() if 'loss' in k}
# compute running mean with a window equal to the number of batches in
# an epoch
rm = {k: running_mean(v.flatten('F'), v.shape[0]) for k, v in loss.items()}
# sort the keys of the dictionary alphabetically
rm = {k: rm[k] for k in sorted(rm)}
# number of epochs trained
epochs = np.arange(0, loss['train_loss'].shape[1] + 1)
# compute number of mini-batches in training and validation set
ntbatches = loss['train_loss'].shape[0]
nvbatches = loss['valid_loss'].shape[0]
# the mean loss/accuraries at each epoch
markers = [ntbatches, nvbatches]
# instanciate figure
fig, ax = plt.subplots(1, 1, figsize=figsize)
ax2 = ax.twiny()
axes = [ax, ax2]
# plot training and validation loss
for (k, v), c, marker, ax in zip(rm.items(), ['-', '--'], markers, axes):
ax.plot(v, 'o', ls=c, color='black', markevery=marker)
# x axis limits
axes[0].set_xticks(np.arange(0, ntbatches * epochs[-1], ntbatches * step))
axes[0].set_xticklabels(epochs[::step], fontsize=14)
axes[0].set_xlabel('Epoch', fontsize=14)
axes[0].set_ylabel('Loss', fontsize=14)
axes[1].set(xticks=[], xticklabels=[])
# y-axis limits
max_loss = max(rm['train_loss'].max(), rm['valid_loss'].max())
min_loss = min(rm['train_loss'].min(), rm['valid_loss'].min())
yl_max, yl_min = (ceil_decimal(max_loss, decimal=1),
floor_decimal(min_loss, decimal=1))
axes[0].set_ylim(yl_min, yl_max)
# compute early stopping point
if loss['valid_loss'].any():
esepoch = np.argmin(loss['valid_loss'].mean(axis=0))
esacc = np.min(loss['valid_loss'].mean(axis=0))
axes[1].vlines(esepoch * nvbatches, ymin=axes[0].get_ylim()[0],
ymax=axes[0].get_ylim()[1], ls='--', color='grey')
axes[1].text(esepoch * nvbatches - 1, ax.get_ylim()[0] + 0.005,
'epoch = {}, loss = {:.2f}'.format(esepoch, esacc),
ha='right', color='grey', fontsize=14)
# create a patch (proxy artist) for every color
ulabels = ['Training', 'Validation']
patches = [mlines.Line2D([], [], color='black', ls=c, label=l) for c, l in
zip(['-', '--'], ulabels)]
# plot patches as legend
ax.legend(handles=patches, loc='upper left', frameon=False, ncol=2,
fontsize=14)
return fig
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment