diff --git a/climax/core/utils.py b/climax/core/utils.py
index 7c0ff898a3c9e6a5906a6172da37b576adf14fc7..3ec1118f2e6b6ce2dbb18c56405451ac11d8ffa4 100644
--- a/climax/core/utils.py
+++ b/climax/core/utils.py
@@ -10,12 +10,16 @@ import datetime
 from dateutil.relativedelta import relativedelta
 
 # externals
+import cdo
+import torch
 import numpy as np
 import pandas as pd
-import cdo
+import matplotlib.pyplot as plt
+import matplotlib.lines as mlines
 
 # locals
 from pysegcnn.core.utils import search_files
+from pysegcnn.core.graphics import running_mean, ceil_decimal, floor_decimal
 from climax.core.constants import CORDEX_PARAMETERS, CDO_RESAMPLING_MODES
 
 # module level logger
@@ -107,3 +111,92 @@ def split_date_range(start_date, end_date, **kwargs):
         sdate = min(sdate + relativedelta(**kwargs), edate)
 
     return dates
+
+
+def plot_loss(state_file, figsize=(10, 10), step=5):
+    """Plot the training and validation loss of a model run.
+
+    Parameters
+    ----------
+    state_file : `str` or :py:class:`pathlib.Path`
+        The model state file. Model state files are stored in
+        `pysegcnn/main/_models`.
+    figsize : `tuple` [`int`], optional
+        The figure size in centimeters. The default is `(10, 10)`.
+    step : `int`, optional
+        The interval at which to label epochs on the x-axis. The default is
+        `5`, i.e. label every fifth epoch.
+
+    Returns
+    -------
+    fig : :py:class:`matplotlib.figure.Figure`
+        An instance of :py:class:`matplotlib.figure.Figure`.
+
+    """
+    # load the model state
+    model_state = torch.load(state_file, map_location=torch.device('cpu'))
+
+    # extract the recorded training and validation losses
+    loss = {k: v for k, v in model_state['state'].items() if 'loss' in k}
+
+    # compute running mean with a window equal to the number of batches in
+    # an epoch
+    rm = {k: running_mean(v.flatten('F'), v.shape[0]) for k, v in loss.items()}
+
+    # sort the keys of the dictionary alphabetically
+    rm = {k: rm[k] for k in sorted(rm)}
+
+    # number of epochs trained
+    epochs = np.arange(0, loss['train_loss'].shape[1] + 1)
+
+    # number of mini-batches in the training and validation set
+    ntbatches = loss['train_loss'].shape[0]
+    nvbatches = loss['valid_loss'].shape[0]
+
+    # place a marker at the mean loss of each epoch
+    markers = [ntbatches, nvbatches]
+
+    # instantiate the figure
+    fig, ax = plt.subplots(1, 1, figsize=figsize)
+    ax2 = ax.twiny()
+    axes = [ax, ax2]
+
+    # plot the running mean of the training and validation loss
+    for v, c, marker, axis in zip(rm.values(), ['-', '--'], markers, axes):
+        axis.plot(v, 'o', ls=c, color='black', markevery=marker)
+
+    # x-axis ticks and epoch labels
+    axes[0].set_xticks(np.arange(0, ntbatches * (epochs[-1] + 1),
+                                 ntbatches * step))
+    axes[0].set_xticklabels(epochs[::step], fontsize=14)
+    axes[0].set_xlabel('Epoch', fontsize=14)
+    axes[0].set_ylabel('Loss', fontsize=14)
+    axes[1].set(xticks=[], xticklabels=[])
+
+    # y-axis limits
+    max_loss = max(rm['train_loss'].max(), rm['valid_loss'].max())
+    min_loss = min(rm['train_loss'].min(), rm['valid_loss'].min())
+    yl_max, yl_min = (ceil_decimal(max_loss, decimal=1),
+                      floor_decimal(min_loss, decimal=1))
+    axes[0].set_ylim(yl_min, yl_max)
+
+    # mark the early stopping point (epoch of minimum mean validation loss)
+    if loss['valid_loss'].any():
+        esepoch = np.argmin(loss['valid_loss'].mean(axis=0))
+        esloss = np.min(loss['valid_loss'].mean(axis=0))
+        axes[1].vlines(esepoch * nvbatches, ymin=axes[0].get_ylim()[0],
+                       ymax=axes[0].get_ylim()[1], ls='--', color='grey')
+        axes[1].text(esepoch * nvbatches - 1, axes[0].get_ylim()[0] + 0.005,
+                     'epoch = {}, loss = {:.2f}'.format(esepoch, esloss),
+                     ha='right', color='grey', fontsize=14)
+
+    # create a proxy artist for each line style
+    ulabels = ['Training', 'Validation']
+    patches = [mlines.Line2D([], [], color='black', ls=c, label=l) for c, l in
+               zip(['-', '--'], ulabels)]
+
+    # plot the proxy artists as legend
+    axes[0].legend(handles=patches, loc='upper left', frameon=False, ncol=2,
+                   fontsize=14)
+
+    return fig
\ No newline at end of file
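
For reference, a minimal usage sketch of the new plot_loss function. The state-file path and output filename below are hypothetical placeholders, not files shipped with this change; substitute a model state file produced by an actual training run.

    # hypothetical example: visualize the loss curves of a trained model
    import matplotlib.pyplot as plt
    from climax.core.utils import plot_loss

    # model state files are stored in pysegcnn/main/_models; the filename
    # here is a placeholder
    fig = plot_loss('pysegcnn/main/_models/model_state.pt', step=5)

    # save and display the figure (placeholder output name)
    fig.savefig('loss_curves.png', dpi=300, bbox_inches='tight')
    plt.show()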