diff --git a/climax/core/utils.py b/climax/core/utils.py
index 7c0ff898a3c9e6a5906a6172da37b576adf14fc7..3ec1118f2e6b6ce2dbb18c56405451ac11d8ffa4 100644
--- a/climax/core/utils.py
+++ b/climax/core/utils.py
@@ -10,12 +10,16 @@ import datetime
 from dateutil.relativedelta import relativedelta
 
 # externals
+import cdo
+import torch
 import numpy as np
 import pandas as pd
-import cdo
+import matplotlib.pyplot as plt
+import matplotlib.lines as mlines
 
 # locals
 from pysegcnn.core.utils import search_files
+from pysegcnn.core.graphics import running_mean, ceil_decimal, floor_decimal
 from climax.core.constants import CORDEX_PARAMETERS, CDO_RESAMPLING_MODES
 
 # module level logger
@@ -107,3 +111,92 @@ def split_date_range(start_date, end_date, **kwargs):
         sdate = min(sdate + relativedelta(**kwargs), edate)
 
     return dates
+
+
+def plot_loss(state_file, figsize=(10, 10), step=5):
+    """Plot the observed loss and accuracy of a model run.
+
+    Parameters
+    ----------
+    state_file : `str` or :py:class:`pathlib.Path`
+        The model state file. Model state files are stored in
+        `pysegcnn/main/_models`.
+    figsize : `tuple` [`int`], optional
+        The figure size in centimeters. The default is `(10, 10)`.
+    step : `int`, optional
+        The step to label epochs on the x-axis labels. The default is `5`, i.e.
+        label each fifth epoch.
+
+    Returns
+    -------
+    fig : :py:class:`matplotlib.figure.Figure`
+        An instance of :py:class:`matplotlib.figure.Figure`.
+
+    """
+    # load the model state
+    model_state = torch.load(state_file, map_location=torch.device('cpu'))
+
+    # get all non-zero elements, i.e. get number of epochs trained before
+    # early stop
+    loss = {k: v for k, v in model_state['state'].items() if 'loss' in k}
+
+    # compute running mean with a window equal to the number of batches in
+    # an epoch
+    rm = {k: running_mean(v.flatten('F'), v.shape[0]) for k, v in loss.items()}
+
+    # sort the keys of the dictionary alphabetically
+    rm = {k: rm[k] for k in sorted(rm)}
+
+    # number of epochs trained
+    epochs = np.arange(0, loss['train_loss'].shape[1] + 1)
+
+    # compute number of mini-batches in training and validation set
+    ntbatches = loss['train_loss'].shape[0]
+    nvbatches = loss['valid_loss'].shape[0]
+
+    # the mean loss/accuraries at each epoch
+    markers = [ntbatches, nvbatches]
+
+    # instanciate figure
+    fig, ax = plt.subplots(1, 1, figsize=figsize)
+    ax2 = ax.twiny()
+    axes = [ax, ax2]
+
+    # plot training and validation loss
+    for (k, v), c, marker, ax in zip(rm.items(), ['-', '--'], markers, axes):
+        ax.plot(v, 'o', ls=c, color='black', markevery=marker)
+
+    # x axis limits
+    axes[0].set_xticks(np.arange(0, ntbatches * epochs[-1], ntbatches * step))
+    axes[0].set_xticklabels(epochs[::step], fontsize=14)
+    axes[0].set_xlabel('Epoch', fontsize=14)
+    axes[0].set_ylabel('Loss', fontsize=14)
+    axes[1].set(xticks=[], xticklabels=[])
+
+    # y-axis limits
+    max_loss = max(rm['train_loss'].max(), rm['valid_loss'].max())
+    min_loss = min(rm['train_loss'].min(), rm['valid_loss'].min())
+    yl_max, yl_min = (ceil_decimal(max_loss, decimal=1),
+                      floor_decimal(min_loss, decimal=1))
+    axes[0].set_ylim(yl_min, yl_max)
+
+    # compute early stopping point
+    if loss['valid_loss'].any():
+        esepoch = np.argmin(loss['valid_loss'].mean(axis=0))
+        esacc = np.min(loss['valid_loss'].mean(axis=0))
+        axes[1].vlines(esepoch * nvbatches, ymin=axes[0].get_ylim()[0],
+                       ymax=axes[0].get_ylim()[1], ls='--', color='grey')
+        axes[1].text(esepoch * nvbatches - 1, ax.get_ylim()[0] + 0.005,
+                     'epoch = {}, loss = {:.2f}'.format(esepoch, esacc),
+                     ha='right', color='grey', fontsize=14)
+
+    # create a patch (proxy artist) for every color
+    ulabels = ['Training', 'Validation']
+    patches = [mlines.Line2D([], [], color='black', ls=c, label=l) for c, l in
+               zip(['-', '--'], ulabels)]
+
+    # plot patches as legend
+    ax.legend(handles=patches, loc='upper left', frameon=False, ncol=2,
+              fontsize=14)
+
+    return fig
\ No newline at end of file