diff --git a/.gitignore b/.gitignore
index 51026ddbf4c7110b8d769f14cc075a126cc11318..91f1eff04fa35e0a39b8631358a5f97dc4cb493a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,5 +8,5 @@ __pycache__/
 
 # jupyter
 .ipynb_checkpoints/
-Figures/
+Notebooks/Figures/
 *.nc
diff --git a/Notebooks/Figures/pr_ROC.png b/Notebooks/Figures/pr_ROC.png
index 3d794b96feeac9ff9b61ffd5d1293be0b9eab27a..05e4ac64ffaae4b95bc5cd1ffd79e3984de104b8 100644
Binary files a/Notebooks/Figures/pr_ROC.png and b/Notebooks/Figures/pr_ROC.png differ
diff --git a/Notebooks/Figures/pr_average_bias.png b/Notebooks/Figures/pr_average_bias.png
index 2c4e8edc26db176fae1b874825e38a3d154551c8..dece9cee6b82b6b1895b99f18e0d37ea91ade836 100644
Binary files a/Notebooks/Figures/pr_average_bias.png and b/Notebooks/Figures/pr_average_bias.png differ
diff --git a/Notebooks/Figures/pr_average_bias_p98.png b/Notebooks/Figures/pr_average_bias_p98.png
index 75aedf9f96ebaf90868ae438e75c32f9b40755bf..389092dafa3e94c3b69719e2c3712319d8e0cb6b 100644
Binary files a/Notebooks/Figures/pr_average_bias_p98.png and b/Notebooks/Figures/pr_average_bias_p98.png differ
diff --git a/Notebooks/Figures/pr_average_bias_seasonal.png b/Notebooks/Figures/pr_average_bias_seasonal.png
index 6f9bccbaeab009ad09e0ac9676dbdf6e75726843..c0301ffb0e292ce7cea0cbe82f192577b8b8cbdc 100644
Binary files a/Notebooks/Figures/pr_average_bias_seasonal.png and b/Notebooks/Figures/pr_average_bias_seasonal.png differ
diff --git a/Notebooks/Figures/pr_average_bias_seasonal_ex.png b/Notebooks/Figures/pr_average_bias_seasonal_ex.png
index 623758a81477f6db21c14191e137c800c7453901..33274b62674ffddac36e802982e6c6e6b33b484e 100644
Binary files a/Notebooks/Figures/pr_average_bias_seasonal_ex.png and b/Notebooks/Figures/pr_average_bias_seasonal_ex.png differ
diff --git a/Notebooks/Figures/pr_bias_wet_days.png b/Notebooks/Figures/pr_bias_wet_days.png
index d440737d9e35f3c053986e9160ed1b899ca47e19..cce0d4a8f423e431d075d074409f578889122206 100644
Binary files a/Notebooks/Figures/pr_bias_wet_days.png and b/Notebooks/Figures/pr_bias_wet_days.png differ
diff --git a/Notebooks/Figures/pr_bias_wet_days_p.png b/Notebooks/Figures/pr_bias_wet_days_p.png
index d9d92768028899bcf93a911ee797e5de2c43a7b3..49b3b8455022912535b8df697273c94ec64ebb43 100644
Binary files a/Notebooks/Figures/pr_bias_wet_days_p.png and b/Notebooks/Figures/pr_bias_wet_days_p.png differ
diff --git a/Notebooks/Figures/pr_distribution.png b/Notebooks/Figures/pr_distribution.png
index 16bebc879e7c905559c876f36c14aef4ab79b4d9..c0e82110ebf9744fc5e4b3f56010d3000039c143 100644
Binary files a/Notebooks/Figures/pr_distribution.png and b/Notebooks/Figures/pr_distribution.png differ
diff --git a/Notebooks/Figures/pr_r2.png b/Notebooks/Figures/pr_r2.png
index c53560b67b8f4b23393cb629ecc5ad700b90a29a..afe8e83daab47b571515f3a66129f67bfb41e6ef 100644
Binary files a/Notebooks/Figures/pr_r2.png and b/Notebooks/Figures/pr_r2.png differ
diff --git a/Notebooks/Figures/pr_rbias_ERA_vs_model.png b/Notebooks/Figures/pr_rbias_ERA_vs_model.png
index 1e1d862f30f996697ce015d90b749eced48ac021..846e3ee7c4c9526b5a91914b298994a46072e555 100644
Binary files a/Notebooks/Figures/pr_rbias_ERA_vs_model.png and b/Notebooks/Figures/pr_rbias_ERA_vs_model.png differ
diff --git a/Notebooks/Figures/tasmin_average_bias.png b/Notebooks/Figures/tasmin_average_bias.png
index 5992128610af6a48020688efda2db8b3de06dc1b..457180f301641a28503c3351ed83f0a1127bd556 100644
Binary files a/Notebooks/Figures/tasmin_average_bias.png and b/Notebooks/Figures/tasmin_average_bias.png differ
diff --git a/Notebooks/Figures/tasmin_average_bias_seasonal.png b/Notebooks/Figures/tasmin_average_bias_seasonal.png
index 409f8619b787f46f453e09a84793c000f308719b..6b149b124fab7ab8923f78c9b28311d933c88063 100644
Binary files a/Notebooks/Figures/tasmin_average_bias_seasonal.png and b/Notebooks/Figures/tasmin_average_bias_seasonal.png differ
diff --git a/Notebooks/Figures/tasmin_r2.png b/Notebooks/Figures/tasmin_r2.png
index 5e6d822120b7de648334bb422cbcd2cdaec64105..397cf48d0de2fcd17f9a12aa5ea54c4c2c93936f 100644
Binary files a/Notebooks/Figures/tasmin_r2.png and b/Notebooks/Figures/tasmin_r2.png differ
diff --git a/Notebooks/check_predictors.ipynb b/Notebooks/check_predictors.ipynb
index dd442c4a0a7aefb114e7c77240247f0a6c1f6cdb..1c23caa16b5cd7e359f2e0e62bca699f8498aec1 100644
--- a/Notebooks/check_predictors.ipynb
+++ b/Notebooks/check_predictors.ipynb
@@ -17,8 +17,10 @@
    "source": [
     "# builtins\n",
     "import pathlib\n",
+    "import warnings\n",
     "\n",
     "# externals\n",
+    "import torch\n",
     "import numpy as np\n",
     "import pandas as pd\n",
     "import xarray as xr\n",
@@ -63,11 +65,13 @@
    "outputs": [],
    "source": [
     "# define the predictor variables you want to use\n",
-    "ERA5_PREDICTORS = ['geopotential', 'temperature', 'mean_sea_level_pressure']  # use geopotential, temperature and pressure\n",
+    "# ERA5_PREDICTORS = ['geopotential', 'temperature', 'mean_sea_level_pressure']  # use geopotential, temperature and pressure\n",
     "\n",
     "# you can change this list as you wish, e.g.:\n",
     "# ERA5_PREDICTORS = ['geopotential', 'temperature']  # use only geopotential and temperature\n",
     "# ERA5_PREDICTORS = ERA5_VARIABLES  # use all ERA5 variables as predictors\n",
+    "ERA5_PREDICTORS = ['geopotential', 'temperature', 'surface_pressure',\n",
+    "                   'u_component_of_wind', 'v_component_of_wind', 'specific_humidity']\n",
     "\n",
     "# this checks if the variable names are correct\n",
     "assert all([p in ERA5_VARIABLES for p in ERA5_PREDICTORS]) "
@@ -81,6 +85,16 @@
     "### Use the climax package to load ERA5 predictors"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3432b46f-29a9-4333-af13-05dc0643ce63",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "CHUNKS = {'time': 16, 'x': 16, 'y': 16}"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -101,20 +115,48 @@
    "source": [
     "# create the xarray.Dataset of the specified predictor variables\n",
     "predictors = ERA5Dataset(ERA5_PATH, ERA5_PREDICTORS, plevels=PLEVELS)\n",
-    "predictors = predictors.merge(chunks=-1)"
+    "predictors = predictors.merge(chunks=CHUNKS)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ea5c1c43-fbbf-4db1-926b-1102e8079348",
+   "id": "f4c6a816-a099-4c53-8d9f-79c07ca6626a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# add the day of the year\n",
+    "predictors = predictors.assign(ERA5Dataset.encode_doys(predictors, chunks=predictors.chunks))"
+   ]
+  },
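+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1a7d2f3e-5b6c-4d8e-9f0a-2b3c4d5e6f7a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# a minimal sketch of a cyclical day-of-year encoding, assuming encode_doys\n",
+    "# follows the common sin/cos convention; the climax implementation may\n",
+    "# differ, e.g. in chunking or leap-year handling\n",
+    "doy = predictors.time.dt.dayofyear\n",
+    "doy_sin, doy_cos = np.sin(2 * np.pi * doy / 365.25), np.cos(2 * np.pi * doy / 365.25)"
+   ]
+  },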
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3d453b92-818f-4d6c-bc61-bdc5869b49ea",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# check out the xarray.Dataset: you will see all the variables you specified\n",
     "predictors"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "9e804201-1fd4-44a7-b63b-6e4d55b77e65",
+   "metadata": {},
+   "source": [
+    "## Normalize predictors"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "662e2639-2553-4711-b103-e6a94d4ab07c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "predictors = ERA5Dataset.normalize(predictors, period=CALIB_PERIOD)"
+   ]
+  },
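+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8c1f4b2a-6d7e-4f90-a1b2-3c4d5e6f7a8b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# for illustration only: a rough sketch of a manual z-score normalization\n",
+    "# over the calibration period; ERA5Dataset.normalize is the package's own\n",
+    "# implementation and may differ (e.g. scaling method or NaN handling)\n",
+    "calib_ref = predictors.sel(time=CALIB_PERIOD)\n",
+    "predictors_zscore = (predictors - calib_ref.mean(dim='time')) / calib_ref.std(dim='time')  # lazy with dask"
+   ]
+  },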
   {
    "cell_type": "markdown",
    "id": "14668e55-3b47-44ea-b3f7-596b0e62eec6",
@@ -144,7 +186,8 @@
     "dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()\n",
     "\n",
     "# read elevation and compute slope and aspect\n",
-    "dem = ERA5Dataset.dem_features(dem, {'y': predictors.y, 'x': predictors.x}, add_coord={'time': predictors.time})"
+    "dem = ERA5Dataset.dem_features(dem, {'y': predictors.y, 'x': predictors.x}, add_coord={'time': predictors.time})\n",
+    "dem = dem.drop_vars(['slope', 'aspect']).chunk(predictors.chunks)"
    ]
   },
   {
@@ -157,6 +200,16 @@
     "dem"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0a545ff3-ab4f-4e6f-8990-af2c4e63be30",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dem = ERA5Dataset.normalize(dem, dim=('y', 'x'))"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "d55ca9b1-46a9-462a-ba42-fb7a38c823fa",
@@ -175,6 +228,62 @@
     "Era5_ds = xr.merge([predictors, dem])"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "01557918-05e0-4a24-991a-daf6d6bf34de",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "Era5_ds"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4df9a0ec-8689-4efb-a753-938f70cf548a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# select calibration period and load data\n",
+    "%time calib = Era5_ds.sel(time=CALIB_PERIOD).compute()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "86f24b93-07df-4f8c-8c9f-e6cbb8213d0d",
+   "metadata": {},
+   "source": [
+    "## Compute standardized anomalies"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4da51abd-8786-47da-9b14-1a6a036172d3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "groups = predictors.groupby('time.dayofyear').groups\n",
+    "anomalies = {}\n",
+    "for doy, days in groups.items():\n",
+    "    with warnings.catch_warnings():\n",
+    "        warnings.simplefilter('ignore', category=RuntimeWarning)\n",
+    "        anomalies[doy] = (predictors.isel(time=days) - predictors.isel(time=days).mean(dim='time')) / predictors.isel(time=days).std(dim='time')\n",
+    "anomalies = xr.concat(anomalies.values(), dim='time')\n",
+    "anomalies = anomalies.sortby(anomalies.time)"
+   ]
+  },
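+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2e9a7c5d-1f3b-4a6c-8d0e-4f5a6b7c8d9e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# an equivalent, more compact formulation using xarray's groupby.map, shown\n",
+    "# for comparison; it should reproduce the loop above up to floating-point\n",
+    "# error\n",
+    "anomalies_alt = predictors.groupby('time.dayofyear').map(\n",
+    "    lambda g: (g - g.mean(dim='time')) / g.std(dim='time')).sortby('time')"
+   ]
+  },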
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c500c934-25dd-4dcc-bac3-1aef4ac27181",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "anomalies"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "dd584313-8913-45b5-a379-2724d42bc99f",
@@ -190,7 +299,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "PREDICTAND = 'pr'"
+    "PREDICTAND = 'tasmin'"
    ]
   },
   {
diff --git a/Notebooks/debug_loss.ipynb b/Notebooks/debug_loss.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..6e51b4574b55dd4afd570f343f1bd14ec13a5900
--- /dev/null
+++ b/Notebooks/debug_loss.ipynb
@@ -0,0 +1,298 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5b9c57b3-4748-4016-8990-840475c44392",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\"\"\"Dynamical climate downscaling using deep convolutional neural networks.\"\"\"\n",
+    "\n",
+    "# !/usr/bin/env python\n",
+    "# -*- coding: utf-8 -*-\n",
+    "\n",
+    "# builtins\n",
+    "import sys\n",
+    "import time\n",
+    "import logging\n",
+    "from datetime import timedelta\n",
+    "from logging.config import dictConfig\n",
+    "\n",
+    "# externals\n",
+    "import torch\n",
+    "import xarray as xr\n",
+    "from sklearn.model_selection import train_test_split\n",
+    "from torch.utils.data import DataLoader\n",
+    "\n",
+    "# locals\n",
+    "from pysegcnn.core.utils import search_files\n",
+    "from pysegcnn.core.trainer import NetworkTrainer, LogConfig\n",
+    "from pysegcnn.core.models import Network\n",
+    "from pysegcnn.core.logging import log_conf\n",
+    "from climax.core.dataset import ERA5Dataset, NetCDFDataset\n",
+    "from climax.core.loss import MSELoss, L1Loss, BernoulliWeibullLoss\n",
+    "from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,\n",
+    "                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,\n",
+    "                                LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,\n",
+    "                                OVERWRITE, DEM, DEM_FEATURES, STRATIFY,\n",
+    "                                WET_DAY_THRESHOLD, VALID_SIZE)\n",
+    "from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH\n",
+    "\n",
+    "# module level logger\n",
+    "LOGGER = logging.getLogger(__name__)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "40c70d2f-2e31-4ea8-adfa-5009b950e1e0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# initialize timing\n",
+    "start_time = time.monotonic()\n",
+    "\n",
+    "# initialize network filename\n",
+    "state_file = ERA5Dataset.state_file(\n",
+    "    NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,\n",
+    "    dem_features=DEM_FEATURES, doy=DOY, loss=LOSS)\n",
+    "\n",
+    "# path to model state\n",
+    "state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)\n",
+    "\n",
+    "# initialize logging\n",
+    "log_file = MODEL_PATH.joinpath(PREDICTAND,\n",
+    "                               state_file.name.replace('.pt', '_log.txt'))\n",
+    "if log_file.exists():\n",
+    "    log_file.unlink()\n",
+    "dictConfig(log_conf(log_file))\n",
+    "\n",
+    "# initialize downscaling\n",
+    "LogConfig.init_log('Initializing downscaling for period: {}'.format(\n",
+    "    ' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))\n",
+    "\n",
+    "# check if model exists\n",
+    "if state_file.exists() and not OVERWRITE:\n",
+    "    # load pretrained network\n",
+    "    net, _ = Network.load_pretrained_model(state_file, NET)\n",
+    "    sys.exit()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "582a53c1-78ff-4c56-9caf-3b1dd40ce1b3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# initialize ERA5 predictor dataset\n",
+    "LogConfig.init_log('Initializing ERA5 predictors.')\n",
+    "Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,\n",
+    "                   plevels=ERA5_PLEVELS)\n",
+    "Era5_ds = Era5.merge(chunks=-1)\n",
+    "\n",
+    "# initialize OBS predictand dataset\n",
+    "LogConfig.init_log('Initializing observations for predictand: {}'\n",
+    "                   .format(PREDICTAND))\n",
+    "\n",
+    "# check whether to joinlty train tasmin and tasmax\n",
+    "if PREDICTAND == 'tas':\n",
+    "    # read both tasmax and tasmin\n",
+    "    tasmax = xr.open_dataset(\n",
+    "        search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())\n",
+    "    tasmin = xr.open_dataset(\n",
+    "        search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())\n",
+    "    Obs_ds = xr.merge([tasmax, tasmin])\n",
+    "else:\n",
+    "    # read in-situ gridded observations\n",
+    "    Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), 'OBS_pr(.*).nc$').pop()\n",
+    "    Obs_ds = xr.open_dataset(Obs_ds)\n",
+    "\n",
+    "# whether to use digital elevation model\n",
+    "if DEM:\n",
+    "    # digital elevation model: Copernicus EU-Dem v1.1\n",
+    "    dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()\n",
+    "\n",
+    "    # read elevation and compute slope and aspect\n",
+    "    dem = ERA5Dataset.dem_features(\n",
+    "        dem, {'y': Era5_ds.y, 'x': Era5_ds.x},\n",
+    "        add_coord={'time': Era5_ds.time})\n",
+    "\n",
+    "    # check whether to use slope and aspect\n",
+    "    if not DEM_FEATURES:\n",
+    "        dem = dem.drop_vars(['slope', 'aspect'])\n",
+    "\n",
+    "    # add dem to set of predictor variables\n",
+    "    Era5_ds = xr.merge([Era5_ds, dem])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "54afbbab-728a-449a-9856-51e55352fafc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# initialize training data\n",
+    "LogConfig.init_log('Initializing training data.')\n",
+    "\n",
+    "# split calibration period into training and validation period\n",
+    "if PREDICTAND == 'pr' and STRATIFY:\n",
+    "    # stratify training and validation dataset by number of\n",
+    "    # observed wet days for precipitation\n",
+    "    wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))\n",
+    "                >= WET_DAY_THRESHOLD).to_array().values.squeeze()\n",
+    "    train, valid = train_test_split(\n",
+    "        CALIB_PERIOD, stratify=wet_days, test_size=VALID_SIZE)\n",
+    "\n",
+    "    # sort chronologically\n",
+    "    train, valid = sorted(train), sorted(valid)\n",
+    "else:\n",
+    "    train, valid = train_test_split(CALIB_PERIOD, shuffle=False,\n",
+    "                                    test_size=VALID_SIZE)\n",
+    "\n",
+    "# training and validation dataset\n",
+    "Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)\n",
+    "Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e6b861a1-3b2a-49b8-8019-48fce4a42664",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# create PyTorch compliant dataset and dataloader instances for model\n",
+    "# training\n",
+    "train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, doy=DOY)\n",
+    "# valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, doy=DOY)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "79537498-bed8-4d2e-a3d4-363f943c57cf",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,\n",
+    "                      drop_last=False)\n",
+    "# valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,\n",
+    "#                       drop_last=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a4e87b23-9a12-492c-9113-d44b404ac9d7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "PREDICTAND = 'pr'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "192e1128-e02d-4ca1-8df4-bc4be57d3acc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# initialize network and optimizer\n",
+    "LogConfig.init_log('Initializing network and optimizer.')\n",
+    "\n",
+    "# define number of output fields\n",
+    "# check whether modelling pr with probabilistic approach\n",
+    "outputs = len(Obs_ds.data_vars)\n",
+    "if PREDICTAND == 'pr':\n",
+    "    outputs = (1 if (isinstance(LOSS, MSELoss) or isinstance(LOSS, L1Loss))\n",
+    "               else 3)\n",
+    "\n",
+    "# instanciate network\n",
+    "inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)\n",
+    "net = NET(state_file, inputs, outputs, filters=FILTERS)\n",
+    "\n",
+    "# initialize optimizer\n",
+    "optimizer = torch.optim.Adam(net.parameters(), lr=LR,\n",
+    "                             weight_decay=LAMBDA)\n",
+    "# optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9,\n",
+    "#                             weight_decay=LAMBDA)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "be030225-7d90-484b-81c8-325578c8648b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "LOSS = BernoulliWeibullLoss(min_amount=1, reduction='mean')"
+   ]
+  },
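+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3f8b6d4c-2a1e-4b7d-9c0f-5a6b7c8d9e0f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# for reference while debugging: a minimal sketch of the Bernoulli-Gamma\n",
+    "# negative log-likelihood after Cannon (2008); the climax BernoulliGammaLoss\n",
+    "# and BernoulliWeibullLoss may differ in details such as masking of missing\n",
+    "# values, clipping, and parameterization\n",
+    "def bernoulli_gamma_nll(p, alpha, beta, y, eps=1e-8):\n",
+    "    # wet-day indicator P(y > 0)\n",
+    "    wet = (y > 0).float()\n",
+    "    loglik = ((1 - wet) * torch.log(1 - p + eps) +\n",
+    "              wet * (torch.log(p + eps) +\n",
+    "                     (alpha - 1) * torch.log(y + eps) -\n",
+    "                     y / (beta + eps) -\n",
+    "                     alpha * torch.log(beta + eps) -\n",
+    "                     torch.lgamma(alpha)))\n",
+    "    # minimize the negative log-likelihood\n",
+    "    return -loglik.mean()"
+   ]
+  },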
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4c84b9ef-158f-421a-8abd-d241b5dca35b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# forward pass\n",
+    "torch.autograd.set_detect_anomaly(True)\n",
+    "for inputs, labels in train_dl:\n",
+    "    outputs = net(inputs)\n",
+    "    loss = LOSS(outputs, labels)\n",
+    "    print(loss)\n",
+    "    loss.backward()\n",
+    "    optimizer.step()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d42ffa8d-e172-4fce-950c-2df7a371bc4c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# initialize network trainer\n",
+    "trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,\n",
+    "                         valid_dl, loss_function=LOSS, **TRAIN_CONFIG)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "aea217fb-b9fb-4a17-8135-2fb052101f39",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# train model\n",
+    "# check for anomalies when training\n",
+    "torch.autograd.set_detect_anomaly(True)\n",
+    "state = trainer.train()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/Notebooks/eval_precipitation.ipynb b/Notebooks/eval_precipitation.ipynb
index 40f103a120f92c99c415b1e1a0799083eb83e4b8..e1f85e50510435e9d9f838888664025d2dc7682e 100644
--- a/Notebooks/eval_precipitation.ipynb
+++ b/Notebooks/eval_precipitation.ipynb
@@ -29,7 +29,8 @@
     "- Specific humidity (q)\n",
     "\n",
     "**Predictors on surface**:\n",
-    "- Mean sea level pressure (msl)\n",
+    "- Surface pressure (p)\n",
+    "- Total precipitation (pr)\n",
     "\n",
     "**Auxiliary predictors**:\n",
     "- Elevation from Copernicus EU-DEM v1.1 (dem)\n",
@@ -59,13 +60,52 @@
     "PLEVELS = ['500', '850']\n",
     "# PLEVELS = []\n",
     "SPREDICTORS = 'p'\n",
+    "# SPREDICTORS = 'ppr'\n",
     "DEM = 'dem'\n",
     "DEM_FEATURES = ''\n",
-    "DOY = ''\n",
+    "DOY = 'doy'\n",
     "WET_DAY_THRESHOLD = '1'\n",
     "# LOSS = 'MSELoss'\n",
     "LOSS = 'BernoulliGammaLoss'\n",
-    "SEASON = 'season'"
+    "# LOSS = 'BernoulliWeibullLoss'\n",
+    "OPTIM = 'SGD'\n",
+    "# OPTIM = 'Adam'\n",
+    "SEASON = ''\n",
+    "DECAY = 1e-02\n",
+    "LR = 5e-03"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9f4b5c0e-670f-40ba-9e38-00c38e7e7855",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# construct file pattern to match\n",
+    "PATTERN = '_'.join([MODEL, PREDICTAND, PPREDICTORS, *PLEVELS])\n",
+    "PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, '{}mm'.format(str(WET_DAY_THRESHOLD).replace('.', ''))]) if WET_DAY_THRESHOLD else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, LOSS])\n",
+    "PATTERN = '_'.join([PATTERN, OPTIM])\n",
+    "PATTERN = '_'.join([PATTERN, SEASON]) if SEASON else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, 'd{:.0e}'.format(DECAY)])\n",
+    "PATTERN = '_'.join([PATTERN, 'lr{:.0e}'.format(LR)])\n",
+    "PATTERN"
+   ]
+  },
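+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4a0c8e6d-3b2f-4c8e-a1d0-6b7c8d9e0f1a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# sanity check: with the settings above (and assuming MODEL='USegNet',\n",
+    "# PREDICTAND='pr', PPREDICTORS='ztuvq' in the configuration cell), the\n",
+    "# pattern should read\n",
+    "# 'USegNet_pr_ztuvq_500_850_p_dem_doy_1mm_BernoulliGammaLoss_SGD_d1e-02_lr5e-03'\n",
+    "assert PATTERN.endswith('d{:.0e}_lr{:.0e}'.format(DECAY, LR))"
+   ]
+  },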
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e38fabfa-60ca-4eb3-80da-dc646404fb0f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# whether to search for the defined pattern\n",
+    "SEARCH = False"
    ]
   },
   {
@@ -73,7 +113,7 @@
    "id": "dd188df0-69ee-44b0-82b2-d212994dc271",
    "metadata": {},
    "source": [
-    "### Imports"
+    "## Imports"
    ]
   },
   {
@@ -89,18 +129,22 @@
     "import calendar\n",
     "\n",
     "# externals\n",
+    "import torch\n",
     "import xarray as xr\n",
     "import numpy as np\n",
     "import matplotlib.pyplot as plt\n",
+    "from matplotlib import cm\n",
+    "from matplotlib.colors import ListedColormap\n",
     "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
     "import scipy.stats as stats\n",
     "from IPython.display import Image\n",
-    "from sklearn.metrics import r2_score, roc_curve, auc, classification_report\n",
+    "from sklearn.metrics import r2_score, roc_curve, auc\n",
     "\n",
     "# locals\n",
-    "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH\n",
-    "from pysegcnn.core.utils import search_files\n",
-    "from pysegcnn.core.graphics import plot_classification_report"
+    "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH, MODEL_PATH\n",
+    "from climax.core.utils import plot_loss\n",
+    "from climax.core.dataset import ERA5Dataset\n",
+    "from pysegcnn.core.utils import search_files"
    ]
   },
   {
@@ -116,7 +160,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "a3d92474-2bc1-4035-8938-c5fbb07ae891",
+   "id": "9ca58029-ed0a-4aab-ba67-b0d5d96a12c2",
    "metadata": {},
    "source": [
     "### Model architecture"
@@ -125,61 +169,13 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "c68e022b-41c2-438b-bfb0-e27eddee89bf",
+   "id": "9b083287-34d0-4e31-aeb0-4c0005cf3e83",
    "metadata": {},
    "outputs": [],
    "source": [
     "Image(\"./Figures/architecture.png\", width=900, height=400)"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "id": "c8833efe-c715-4872-9aee-a0b5766f5c67",
-   "metadata": {},
-   "source": [
-    "### Loss function"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "bb801367-6872-4a0b-bff5-70cb6746e057",
-   "metadata": {},
-   "source": [
-    "For precipitation, the network is optimizing the negative log-likelihood of a Bernoulli-Gamma distribution after [Cannon (2008)](http://journals.ametsoc.org/doi/10.1175/2008JHM960.1)."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "a8775d5e-5ad4-47e4-8fef-6dc230e15dee",
-   "metadata": {},
-   "source": [
-    "Bernoulli-Gamma distribution:"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "ab10f8de-d8d2-4427-b9c8-5d68803543c3",
-   "metadata": {},
-   "source": [
-    "$$P(y \\mid, p, \\alpha, \\beta) = \\begin{cases} 1 - p, & \\text{for } y = 0\\\\ p \\cdot \\frac{y^{\\alpha -1} \\exp(-y/\\beta)}{\\beta^{\\alpha} \\tau(\\alpha)}, & \\text{for } y > 0\\end{cases}$$"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "6b6dbd06-1f0e-4c52-84f2-b7ff31c75726",
-   "metadata": {},
-   "source": [
-    "Log-likelihood function:"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "e41c7b39-f352-4a98-820f-9a7345b3283c",
-   "metadata": {},
-   "source": [
-    "$$\\mathcal{J}(p, \\alpha, \\beta \\mid y) = \\underbrace{(1 - P(y > 0)) \\log(1 - p)}_{\\text{Bernoulli}} + \\underbrace{P(y > 0) \\cdot \\left(\\log(p) + (\\alpha - 1) \\log(y) - \\frac{y}{\\beta} - \\alpha \\log(\\beta) - \\log(\\tau(\\alpha))\\right)}_{\\text{Gamma}}$$"
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "5a0c55f0-79fb-4501-b3cf-b5414399a3d9",
@@ -191,32 +187,33 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "efa76f8e-c089-47ff-a001-d4c2a11c4d6d",
+   "id": "ecaba394-f802-481f-8274-c44e2f5fdf1a",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# construct file pattern to match\n",
-    "PATTERN = '_'.join([MODEL, PREDICTAND, PPREDICTORS, *PLEVELS])\n",
-    "PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN\n",
-    "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n",
-    "PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN\n",
-    "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN\n",
-    "PATTERN = '_'.join([PATTERN, '{}mm'.format(str(WET_DAY_THRESHOLD).replace('.', ''))]) if WET_DAY_THRESHOLD else PATTERN\n",
-    "PATTERN = '_'.join([PATTERN, LOSS])\n",
-    "PATTERN = '_'.join([PATTERN, SEASON]) if SEASON else PATTERN\n",
-    "PATTERN"
+    "# digital elevation model\n",
+    "dem = xr.open_dataset(search_files(DEM_PATH, 'eu_dem_v11_stt.nc').pop())\n",
+    "dem = dem.Band1.to_dataset().rename({'Band1': 'elevation'})"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ecaba394-f802-481f-8274-c44e2f5fdf1a",
+   "id": "859ec647-6010-4526-bd96-edb22e336c65",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# digital elevation model\n",
-    "dem = xr.open_dataset(search_files(DEM_PATH, 'eu_dem_v11_stt.nc').pop())\n",
-    "dem = dem.Band1.to_dataset().rename({'Band1': 'elevation'})"
+    "# model predictions\n",
+    "if SEARCH:\n",
+    "    y_pred = xr.open_dataset(search_files(TARGET_PATH.joinpath(PREDICTAND), '.'.join([PATTERN, 'nc$'])).pop())\n",
+    "else:\n",
+    "    filename = 'USegNet_tasmin_ztuvq_500_850_p_dem_doy_L1Loss_Adam_d1e-03_lr3e-04.nc'\n",
+    "    file = search_files(TARGET_PATH, filename).pop()\n",
+    "    y_pred = xr.open_dataset(file)\n",
+    "try:\n",
+    "    y_pred = y_pred.rename({'pr': 'precipitation'})\n",
+    "except ValueError:\n",
+    "    pass"
    ]
   },
   {
@@ -228,13 +225,8 @@
    },
    "outputs": [],
    "source": [
-    "# model predictions and observations NetCDF \n",
-    "y_pred = xr.open_dataset(search_files(TARGET_PATH.joinpath(PREDICTAND), '.'.join([PATTERN, 'nc$'])).pop())\n",
-    "y_true = xr.open_dataset(search_files(OBS_PATH.joinpath(PREDICTAND), 'OBS_pr(.*).nc$').pop())\n",
-    "try:\n",
-    "    y_pred = y_pred.rename({'pr': 'precipitation'})\n",
-    "except ValueError:\n",
-    "    pass"
+    "# observations\n",
+    "y_true = xr.open_dataset(search_files(OBS_PATH.joinpath(PREDICTAND), 'OBS_pr(.*).nc$').pop())"
    ]
   },
   {
@@ -268,7 +260,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# align datasets\n",
+    "# aligndata_varsasets\n",
     "if len(y_pred.data_vars) > 1:\n",
     "    y_true, y_refe, y_pred_pr, y_pred_prob = xr.align(y_true.precipitation, y_refe.precipitation, y_pred.precipitation, y_pred.prob, join='override')\n",
     "else:\n",
@@ -304,67 +296,112 @@
   },
   {
    "cell_type": "markdown",
-   "id": "b269a131-cf5b-4c6c-9f8e-a5408250aa83",
+   "id": "6232ec70-f595-401e-bddd-015eeb606ab0",
    "metadata": {},
    "source": [
-    "## Model validation: precipitation amount"
+    "## Model convergence"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0e3ee22c-7ec1-44c2-9152-c10726621848",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# plot model state\n",
+    "model_state = MODEL_PATH.joinpath(PREDICTAND, '.'.join([PATTERN, 'pt']))\n",
+    "try:\n",
+    "    fig = plot_loss(model_state, step=5)\n",
+    "except FileNotFoundError:\n",
+    "    pass"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "0fa1fe82-0d6e-4676-b5b8-9eac2fd28ffb",
+   "id": "b269a131-cf5b-4c6c-9f8e-a5408250aa83",
    "metadata": {},
    "source": [
-    "### Coefficient of determination: monthly mean"
+    "## Model validation: precipitation amount"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "4d11cee6-1ebd-4424-8ec6-92e5a196bac4",
+   "id": "b73658af-d994-40a9-a41c-5b434e45504a",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# calculate monthly mean precipitation (mm / month)\n",
-    "y_pred_values = y_pred_pr.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').values\n",
-    "y_true_values = y_true.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').values"
+    "# calculate monthly precipitation (mm / month)\n",
+    "y_pred_m = y_pred_pr.resample(time='1M').sum(skipna=False)\n",
+    "y_true_m = y_true.resample(time='1M').sum(skipna=False)\n",
+    "y_refe_m = y_refe.resample(time='1M').sum(skipna=False)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "6e55cdd5-8adb-40c5-9742-46f4fc3d4be9",
+   "id": "93bd8fd9-77c7-40ed-bd31-0d9a3ce86a12",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# apply mask of valid pixels\n",
-    "mask = (~np.isnan(y_pred_values) & ~np.isnan(y_true_values))\n",
-    "y_pred_values = y_pred_values[mask]\n",
-    "y_true_values = y_true_values[mask]"
+    "# calculate mean annual cycle\n",
+    "y_pred_ac = y_pred_m.groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True).values.squeeze()\n",
+    "y_true_ac = y_true_m.groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True).values.squeeze()"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "13a9ff21-34ea-4db1-9c0a-bcb7ea1001f7",
+   "id": "8ed529fb-ec3f-4c40-958a-4a7eaec8b25c",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# calculate coefficient of determination\n",
-    "r2 = r2_score(y_true_values, y_pred_values)\n",
-    "r2"
+    "# compute daily anomalies\n",
+    "y_pred_anom = ERA5Dataset.anomalies(y_pred_pr, timescale='time.month')\n",
+    "y_true_anom = ERA5Dataset.anomalies(y_true, timescale='time.month')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0fa1fe82-0d6e-4676-b5b8-9eac2fd28ffb",
+   "metadata": {},
+   "source": [
+    "### Coefficient of determination"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "a1b33431-9f2b-4e42-bbb0-f4fa258edd98",
+   "id": "ea73bb99-886c-40e9-913e-7ec25418013e",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# group timeseries by month and calculate mean over time and space\n",
-    "y_pred_ac = y_pred_pr.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True).values.squeeze()\n",
-    "y_true_ac = y_true.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True).values.squeeze()"
+    "# get predicted and observed daily anomalies\n",
+    "y_pred_av = y_pred_anom.values.flatten()\n",
+    "y_true_av = y_true_anom.values.flatten()\n",
+    "\n",
+    "# apply mask of valid pixels\n",
+    "mask = (~np.isnan(y_pred_av) & ~np.isnan(y_true_av))\n",
+    "y_pred_av = y_pred_av[mask]\n",
+    "y_true_av = y_true_av[mask]\n",
+    "\n",
+    "# get predicted and observed monthly means\n",
+    "y_pred_mv = y_pred_m.values.flatten()\n",
+    "y_true_mv = y_true_m.values.flatten()\n",
+    "\n",
+    "# apply mask of valid pixels\n",
+    "mask = (~np.isnan(y_pred_mv) & ~np.isnan(y_true_mv))\n",
+    "y_pred_mv = y_pred_mv[mask]\n",
+    "y_true_mv = y_true_mv[mask]\n",
+    "\n",
+    "# calculate coefficient of determination on monthly means\n",
+    "r2_mm = r2_score(y_true_mv, y_pred_mv)\n",
+    "print('R2 on monthly means: {:.2f}'.format(r2_mm))\n",
+    "\n",
+    "# calculate coefficient of determination on daily anomalies\n",
+    "r2_anom = r2_score(y_true_av, y_pred_av)\n",
+    "print('R2 on daily anomalies: {:.2f}'.format(r2_anom))"
    ]
   },
   {
@@ -382,14 +419,15 @@
     "# ax.plot(y_true_values[subset], y_pred_values[subset], 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n",
     "\n",
     "# plot entire dataset\n",
-    "ax.plot(y_true_values, y_pred_values, 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n",
+    "ax.plot(y_true_mv, y_pred_mv, 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n",
     "\n",
     "# plot 1:1 mapping line\n",
-    "interval = np.arange(0, 300, 50)\n",
+    "interval = np.arange(0, 750, 50)\n",
     "ax.plot(interval, interval, color='k', lw=2, ls='--')\n",
     "\n",
-    "# add coefficient of determination: calculated on entire dataset!\n",
-    "ax.text(interval[-1] - 2, interval[0] + 2, s='Coefficient of determination R$^2$ = {:.2f}'.format(r2), ha='right', fontsize=18)\n",
+    "# add coefficients of determination\n",
+    "ax.text(interval[-1] - 0.5, interval[0] + 2, s='R$^2$ (monthly means)= {:.2f}'.format(r2_mm), ha='right', fontsize=18)\n",
+    "ax.text(interval[-1] - 0.5, interval[0] + 27.5, s='R$^2$ (daily anomalies) = {:.2f}'.format(r2_anom), ha='right', fontsize=18)\n",
     "\n",
     "# format axes\n",
     "ax.set_ylim(interval[0], interval[-1])\n",
@@ -403,7 +441,7 @@
     "ax.set_title('Monthly mean {} (mm / month)'.format(NAMES[PREDICTAND]), fontsize=20, pad=10);\n",
     "\n",
     "# add axis for annual cycle\n",
-    "axins = inset_axes(ax, width=\"30%\", height=\"40%\", loc=2, borderpad=0.25)\n",
+    "axins = inset_axes(ax, width=\"30%\", height=\"40%\", loc=2, borderpad=0.5)\n",
     "axins.plot(y_pred_ac, ls='--', color='k', label='Predicted')\n",
     "axins.plot(y_true_ac, ls='-', color='k', label='Observed')\n",
     "axins.legend(frameon=False, fontsize=12, loc='lower center');\n",
@@ -414,7 +452,91 @@
     "axins.set_xticklabels([calendar.month_name[i + 1] for i in np.arange(0, 12)], rotation=90, fontsize=12)\n",
     "\n",
     "# save figure\n",
-    "fig.savefig('../Notebooks/Figures/{}_r2.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')"
+    "fig.savefig('../Notebooks/Figures/{}_r2_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0f0b28e8-742d-4342-a01b-04fc5f736026",
+   "metadata": {},
+   "source": [
+    "### Coefficient of determination: Spatially"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "db5a5180-e1fb-4be6-8970-532064fc9298",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# iterate over the grid points\n",
+    "r2 = np.ones((2, len(y_pred_m.x), len(y_pred_m.y)), dtype=np.float32) * np.nan\n",
+    "for i, _ in enumerate(y_pred_m.x):\n",
+    "    for j, _ in enumerate(y_pred_m.y):\n",
+    "        # get observed and predicted monthly precipitation for current grid point\n",
+    "        point_true = y_true_m.isel(x=i, y=j)\n",
+    "        point_pred = y_pred_m.isel(x=i, y=j)\n",
+    "        \n",
+    "        # remove missing values\n",
+    "        mask = ((~np.isnan(point_true)) & (~np.isnan(point_pred)))\n",
+    "        point_true = point_true[mask].values\n",
+    "        point_pred = point_pred[mask].values\n",
+    "        if point_true.size < 1:\n",
+    "            continue\n",
+    "        \n",
+    "        # get anomalies for current grid point\n",
+    "        point_anom_true = y_true_anom.isel(x=i, y=j)\n",
+    "        point_anom_pred = y_pred_anom.isel(x=i, y=j)\n",
+    "        \n",
+    "        # remove missing values\n",
+    "        mask_anom = ((~np.isnan(point_anom_true)) & (~np.isnan(point_anom_pred)))\n",
+    "        point_anom_true = point_anom_true[mask_anom].values\n",
+    "        point_anom_pred = point_anom_pred[mask_anom].values\n",
+    "\n",
+    "        # compute coefficient of determination\n",
+    "        r2[0, j, i] = r2_score(point_true, point_pred)\n",
+    "        r2[1, j, i] = r2_score(point_anom_true, point_anom_pred)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c6a23548-77b7-4acb-b9ea-e2404c1c2256",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# define color map: red to green\n",
+    "grn = cm.get_cmap('Greens', 128)\n",
+    "red = cm.get_cmap('Reds_r', 128)\n",
+    "red2green = ListedColormap(np.vstack((red(np.linspace(0, 1, 128)),\n",
+    "                                      grn(np.linspace(0, 1, 128)))))\n",
+    "\n",
+    "# plot coefficients of determination\n",
+    "vmin, vmax = -1, 1\n",
+    "fig, ax = plt.subplots(1, 2, figsize=(10, 10))\n",
+    "\n",
+    "# monthly means\n",
+    "im0 = ax[0].imshow(r2[0, :], origin='lower', cmap=red2green, vmin=vmin, vmax=vmax)\n",
+    "ax[0].text(x=r2.shape[1] - 2, y=2, s='Average R$^2$: {:.2f}'.format(np.nanmean(r2[0, :])), fontsize=14, ha='right');\n",
+    "ax[0].set_axis_off()\n",
+    "ax[0].set_title('Monthly mean {} (mm / month)'.format(NAMES[PREDICTAND]), fontsize=14, pad=10);\n",
+    "\n",
+    "# daily anomalies\n",
+    "im1 = ax[1].imshow(r2[1, :], origin='lower', cmap=red2green, vmin=vmin, vmax=vmax)\n",
+    "ax[1].text(x=r2.shape[1] - 2, y=2, s='Average R$^2$: {:.2f}'.format(np.nanmean(r2[1, :])), fontsize=14, ha='right');\n",
+    "ax[1].set_axis_off()\n",
+    "ax[1].set_title('Daily {} anomaly (mm / day)'.format(NAMES[PREDICTAND]), fontsize=14, pad=10);\n",
+    "\n",
+    "# add colorbar \n",
+    "cbar_ax_bias = fig.add_axes([ax[1].get_position().x1 + 0.05, ax[1].get_position().y0,\n",
+    "                             0.03, ax[1].get_position().y1 - ax[1].get_position().y0])\n",
+    "cbar_bias = fig.colorbar(im0, cax=cbar_ax_bias)\n",
+    "cbar_bias.set_label(label='Coefficient of determination R$^2$', fontsize=14)\n",
+    "cbar_bias.ax.tick_params(labelsize=14, pad=10)\n",
+    "\n",
+    "# save figure\n",
+    "fig.savefig('../Notebooks/Figures/{}_r2_spatial_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')"
    ]
   },
   {
@@ -440,7 +562,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# yearly average bias over reference period\n",
+    "# average bias of daily precipitation over reference period\n",
     "y_pred_yearly_avg = y_pred_pr.groupby('time.year').mean(dim='time')\n",
     "y_refe_yearly_avg = y_refe.groupby('time.year').mean(dim='time')\n",
     "y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time')\n",
@@ -472,8 +594,8 @@
    "outputs": [],
    "source": [
     "# root mean squared error over reference period\n",
-    "rmse_avg = ((y_pred_yearly_avg - y_true_yearly_avg) ** 2).mean()\n",
-    "rmse_avg_ref = ((y_refe_yearly_avg - y_true_yearly_avg) **2).mean()\n",
+    "rmse_avg = np.sqrt(((y_pred_yearly_avg - y_true_yearly_avg) ** 2).mean())\n",
+    "rmse_avg_ref = np.sqrt(((y_refe_yearly_avg - y_true_yearly_avg) **2).mean())\n",
     "print('(Model) Yearly average RMSE: {:.2f} mm / day'.format(rmse_avg.item()))\n",
     "print('(ERA-5) Yearly average RMSE: {:.2f} mm / day'.format(rmse_avg_ref.item()))"
    ]
@@ -490,7 +612,10 @@
     "    y_p = y_pred_yearly_avg.sel(year=year).values        \n",
     "    y_t = y_true_yearly_avg.sel(year=year).values\n",
     "    r, _ = stats.pearsonr(y_p[~np.isnan(y_p)], y_t[~np.isnan(y_t)])\n",
-    "    print('({:0d}) Pearson correlation: {:.2f}'.format(year.item(), np.asarray(r).mean()))"
+    "    print('({:0d}) Pearson correlation: {:.2f}'.format(year.item(), np.asarray(r).mean()))\n",
+    "r, _ = stats.pearsonr(y_pred_yearly_avg.values[~np.isnan(y_pred_yearly_avg.values)],\n",
+    "                      y_true_yearly_avg.values[~np.isnan(y_true_yearly_avg.values)])\n",
+    "print('Total: {:.2f}'.format(r))"
    ]
   },
   {
@@ -500,18 +625,18 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# plot yearly average MAE of reference vs. prediction\n",
-    "vmin, vmax = 0, 5\n",
+    "# plot yearly average bias of reference vs. prediction\n",
+    "vmin, vmax = -40, 40\n",
     "fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)\n",
     "\n",
     "# plot bias of ERA-5 reference\n",
     "reference = bias_yearly_avg_ref.mean(dim='year')\n",
-    "im1 = axes[0].imshow(reference.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n",
+    "im1 = axes[0].imshow(reference.values, origin='lower', cmap='RdBu_r', vmin=vmin, vmax=vmax)\n",
     "axes[0].text(x=reference.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(reference.mean().item()), fontsize=14, ha='right')\n",
     "\n",
     "# plot MAE of model\n",
     "prediction = bias_yearly_avg.mean(dim='year')\n",
-    "im2 = axes[1].imshow(prediction.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n",
+    "im2 = axes[1].imshow(prediction.values, origin='lower', cmap='RdBu_r', vmin=vmin, vmax=vmax)\n",
     "axes[1].text(x=reference.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(prediction.mean().item()), fontsize=14, ha='right')\n",
     "\n",
     "# plot topography\n",
@@ -558,7 +683,7 @@
     "#axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.1f}'.format(rmse_avg[NAMES[PREDICTAND]].item()) + 'mm$^2$ day$^{-2}$', fontsize=14, ha='right')\n",
     "\n",
     "# save figure\n",
-    "fig.savefig('../Notebooks/Figures/{}_rbias_ERA_vs_model.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')"
+    "fig.savefig('../Notebooks/Figures/{}_rbias_ERA_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')"
    ]
   },
   {
@@ -620,7 +745,7 @@
     "axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.1f}'.format(rmse_avg.mean().item()) + 'mm day$^{-1}$', fontsize=14, ha='right')\n",
     "\n",
     "# save figure\n",
-    "fig.savefig('../Notebooks/Figures/{}_average_bias.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')"
+    "fig.savefig('../Notebooks/Figures/{}_bias_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')"
    ]
   },
   {
@@ -736,7 +861,7 @@
     "cbar.ax.tick_params(labelsize=14)\n",
     "\n",
     "# save figure\n",
-    "fig.savefig('../Notebooks/Figures/{}_average_bias_seasonal.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')"
+    "fig.savefig('../Notebooks/Figures/{}_bias_seasonal_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')"
    ]
   },
   {
@@ -849,8 +974,8 @@
    "outputs": [],
    "source": [
     "# root mean squared error in extreme quantile\n",
-    "rmse_ex = ((y_pred_ex - y_true_ex) ** 2).mean()\n",
-    "rmse_ex_ref = ((y_refe_ex - y_true_ex) ** 2).mean()"
+    "rmse_ex = np.sqrt(((y_pred_ex - y_true_ex) ** 2).mean())\n",
+    "rmse_ex_ref = np.sqrt(((y_refe_ex - y_true_ex) ** 2).mean())"
    ]
   },
   {
@@ -884,9 +1009,9 @@
    "source": [
     "# plot extremes of observation, prediction, and bias\n",
     "vmin, vmax = 10, 40\n",
-    "fig, axes = plt.subplots(1, 3, figsize=(24, 6), sharex=True, sharey=True)\n",
-    "axes = axes.reshape(1, -1)\n",
-    "for ds, ax in zip([y_true_ex, y_pred_ex, bias_ex], axes[i, ...]):\n",
+    "fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)\n",
+    "axes = axes.flatten()\n",
+    "for ds, ax in zip([y_true_ex, y_pred_ex, bias_ex], axes):\n",
     "    if ds is bias_ex:\n",
     "        ds = ds.mean(dim='year')\n",
     "        im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n",
@@ -895,9 +1020,9 @@
     "        im1 = ax.imshow(ds.mean(dim='year').values, origin='lower', cmap='BuPu', vmin=vmin, vmax=vmax)\n",
     "        \n",
     "# set titles\n",
-    "axes[0, 0].set_title('Observed', fontsize=16, pad=10);\n",
-    "axes[0, 1].set_title('Predicted', fontsize=16, pad=10);\n",
-    "axes[0, 2].set_title('Bias', fontsize=16, pad=10);\n",
+    "axes[0].set_title('Observed', fontsize=16, pad=10);\n",
+    "axes[1].set_title('Predicted', fontsize=16, pad=10);\n",
+    "axes[2].set_title('Bias', fontsize=16, pad=10);\n",
     "\n",
     "# adjust axes\n",
     "for ax in axes.flat:\n",
@@ -934,7 +1059,7 @@
     "axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.1f}'.format(rmse_ex.item())  + 'mm$^2$ day$^{-2}$', fontsize=14, ha='right')\n",
     "\n",
     "# save figure\n",
-    "fig.savefig('../Notebooks/Figures/{}_average_bias_p{:.0f}.png'.format(PREDICTAND, quantile * 100), dpi=300, bbox_inches='tight')"
+    "fig.savefig('../Notebooks/Figures/{}_bias_p{:.0f}_{}.png'.format(PREDICTAND, quantile * 100, LOSS), dpi=300, bbox_inches='tight')"
    ]
   },
   {
@@ -1044,7 +1169,7 @@
     "cbar.ax.tick_params(labelsize=14)\n",
     "\n",
     "# save figure\n",
-    "fig.savefig('../Notebooks/Figures/{}_average_bias_seasonal_ex.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')"
+    "fig.savefig('../Notebooks/Figures/{}_bias_p{:.0f}_seasonal_{}.png'.format(PREDICTAND, quantile * 100, LOSS), dpi=300, bbox_inches='tight')"
    ]
   },
   {
@@ -1137,12 +1262,12 @@
    "outputs": [],
    "source": [
     "# plot average of observation, prediction, and bias\n",
-    "fig, axes = plt.subplots(2, 3, figsize=(24, 12), sharex=True, sharey=True)\n",
+    "fig, axes = plt.subplots(2, 3, figsize=(24, 16), sharex=True, sharey=True)\n",
     "axes = axes.flatten()\n",
     "\n",
     "# plot annual average bias of extreme\n",
     "ds = bias_wet\n",
-    "im = axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n",
+    "im = axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-100, vmax=100)\n",
     "axes[0].set_title('Annual', fontsize=16);\n",
     "axes[0].text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')\n",
     "\n",
@@ -1178,7 +1303,7 @@
     "cbar_predictand.ax.tick_params(labelsize=14)\n",
     "\n",
     "# save figure\n",
-    "fig.savefig('../Notebooks/Figures/{}_bias_wet_days.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')"
+    "fig.savefig('../Notebooks/Figures/{}_bias_wd_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')"
    ]
   },
   {
@@ -1220,7 +1345,7 @@
    "outputs": [],
    "source": [
     "# plot average of observation, prediction, and bias\n",
-    "fig, axes = plt.subplots(1, 3, figsize=(24, 6), sharex=True, sharey=True)\n",
+    "fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)\n",
     "for ds, ax in zip([dii_true, dii_pred, bias_dii], axes):\n",
     "    if ds is bias_dii:\n",
     "        im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n",
@@ -1264,7 +1389,7 @@
     "cbar_predictand.ax.tick_params(labelsize=14)\n",
     "\n",
     "# save figure\n",
-    "fig.savefig('../Notebooks/Figures/{}_bias_wet_days_p.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')"
+    "fig.savefig('../Notebooks/Figures/{}_bias_wdp_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')"
    ]
   },
   {
@@ -1291,7 +1416,7 @@
    "outputs": [],
    "source": [
     "# true and predicted probability of precipitation\n",
-    "p_true = (y_true > float(WET_DAY_THRESHOLD)).values.flatten()\n",
+    "p_true = (y_true >= float(WET_DAY_THRESHOLD)).values.flatten()\n",
     "p_pred = y_pred_prob.values.flatten()"
    ]
   },
@@ -1360,13 +1485,13 @@
     "ax.legend(frameon=False, loc='lower right', fontsize=14);\n",
     "\n",
     "# save figure\n",
-    "fig.savefig('../Notebooks/Figures/{}_ROC.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')"
+    "fig.savefig('../Notebooks/Figures/{}_ROC_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')"
    ]
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
diff --git a/Notebooks/eval_sensitivity.ipynb b/Notebooks/eval_sensitivity.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..6ab07d1003938f6c1ea152669d57de2387d28fc2
--- /dev/null
+++ b/Notebooks/eval_sensitivity.ipynb
@@ -0,0 +1,558 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "f2bfabd5-c8d6-46aa-b340-c169cc03bee7",
+   "metadata": {},
+   "source": [
+    "# Precipitation: Sensitivity analysis of hyperparameters"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9f024696-3cb3-4637-ab30-c57df733aeed",
+   "metadata": {},
+   "source": [
+    "We used **1981-1991 as training** period and **1991-2010 as reference** period. The results shown in this notebook are based on the model predictions on the reference period."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "17ca3513-5757-470c-9d3c-7494b5bade56",
+   "metadata": {},
+   "source": [
+    "**Predictors on pressure levels (500, 850)**:\n",
+    "- Geopotential (z)\n",
+    "- Temperature (t)\n",
+    "- Zonal wind (u)\n",
+    "- Meridional wind (v)\n",
+    "- Specific humidity (q)\n",
+    "\n",
+    "**Predictors on surface**:\n",
+    "- Surface pressure (p)\n",
+    "- Totat precipitation (pr)\n",
+    "\n",
+    "**Auxiliary predictors**:\n",
+    "- Elevation from Copernicus EU-DEM v1.1 (dem)\n",
+    "- Day of the year (doy)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a7e732f3-a219-4317-aa42-638cae86b4af",
+   "metadata": {},
+   "source": [
+    "Define the predictand and the model to evaluate:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d75e3627-13c4-46b8-b749-d582e4b802ac",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# define the model parameters\n",
+    "PREDICTAND = 'pr'  # precipitation, tasmin, tasmax\n",
+    "MODEL = 'USegNet'\n",
+    "# PPREDICTORS = 'ztuvq'\n",
+    "PPREDICTORS = ''\n",
+    "PLEVELS = ['500', '850']\n",
+    "# SPREDICTORS = 'p'\n",
+    "SPREDICTORS = 'pr'\n",
+    "# DEM = 'dem'\n",
+    "DEM = ''\n",
+    "DEM_FEATURES = ''\n",
+    "DOY = ''\n",
+    "# DOY = 'doy'\n",
+    "OPTIMS = ['Adam', 'SGD', 'SGDCLR']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "29602f30-6b7f-4be8-aefa-eff01997bd1e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# parameters for sensivity analysis\n",
+    "WET_DAY_THRESHOLD = [0, 0.5, 1, 2, 3, 5]\n",
+    "LAMBDA = [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1]\n",
+    "LOSS = ['BernoulliGammaLoss', 'L1Loss' ,'MSELoss'] if PREDICTAND == 'pr' else ['L1Loss', 'MSELoss']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5351191a-7cd6-48e6-a6c2-c7cd695db277",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# maximum learning rates for optimizers and loss functions\n",
+    "LR = {'pr': {'L1Loss': {'SGD': 0.001, 'Adam': 0.001},\n",
+    "             'MSELoss': {'SGD': 0.0004, 'Adam': 0.0004},\n",
+    "             'BernoulliGammaLoss': {'SGD': 0.001, 'Adam': 0.0005}},\n",
+    "      'tasmin': {'L1Loss': {'SGD': 0.004, 'Adam': 0.001},\n",
+    "                 'MSELoss': {'SGD': 0.002, 'Adam': 0.001}},\n",
+    "      'tasmax': {'L1Loss': {'SGD': 0.001, 'Adam': 0.001},\n",
+    "                 'MSELoss': {'SGD': 0.004, 'Adam': 0.001}}\n",
+    "     }"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bda0b985-f5f2-40b3-8b36-94bba633ea51",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# constant wet day threshold not evaluating wet days\n",
+    "CWDT = 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5a47524c-4d70-4556-88a5-79d25038c9c0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# parameter to investigate\n",
+    "PARAMETER = {'regularization': LAMBDA}\n",
+    "# PARAMETER = {'Wet day threshold': WET_DAY_THRESHOLD}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b7a61547-9887-4068-8fbf-9a799922ff8d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# extreme quantile of interest\n",
+    "quantile = 0.98"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "cda0cf8b-adce-4513-af69-31df6eb5542f",
+   "metadata": {},
+   "source": [
+    "### Imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "01ed9fbc-49c2-4028-a3f3-692a014dcf89",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# builtins\n",
+    "import datetime\n",
+    "import warnings\n",
+    "import calendar\n",
+    "from itertools import product\n",
+    "\n",
+    "# externals\n",
+    "import xarray as xr\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import matplotlib.pyplot as plt\n",
+    "import matplotlib.ticker as mticker\n",
+    "import seaborn as sns\n",
+    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
+    "from sklearn.metrics import r2_score, roc_curve, auc, classification_report\n",
+    "\n",
+    "# locals\n",
+    "from climax.core.dataset import ERA5Dataset\n",
+    "from climax.main.config import VALID_PERIOD\n",
+    "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6b6f5703-d247-40f5-a859-2d362592b1d8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# path of predictions for hyperparameter sensitivity analysis\n",
+    "TARGET_PATH = TARGET_PATH.joinpath('sensitivity')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "073dfc6a-8b33-4aca-ab71-c06ef92bbb80",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# mapping from predictands to variable names\n",
+    "NAMES = {'tasmin': 'minimum temperature', 'tasmax': 'maximum temperature', 'pr': 'precipitation'}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9410ed09-50a9-4d19-9e90-8551ab701509",
+   "metadata": {},
+   "source": [
+    "### Load datasets"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "01509828-e94c-4ad9-b2a0-67d8f77b4f84",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# load observations\n",
+    "y_true = xr.open_dataset(OBS_PATH.joinpath(PREDICTAND, 'OBS_{}_1980_2018.nc'.format(PREDICTAND)), chunks={'time': 365})\n",
+    "y_true = y_true.sel(time=VALID_PERIOD)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4ca10fb8-bc89-4dea-9346-339e14db9c98",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# mask of missing values\n",
+    "missing = np.isnan(y_true[NAMES[PREDICTAND]]) if PREDICTAND == 'pr' else np.isnan(y_true[PREDICTAND])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "030c221d-e0ad-4e67-b742-db34de549983",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# construct file pattern to match\n",
+    "PATTERN = '_'.join([MODEL, PREDICTAND])\n",
+    "PATTERN = '_'.join([PATTERN, PPREDICTORS]) if PPREDICTORS else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, *PLEVELS]) if PPREDICTORS else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7106fed2-9f63-453b-a9af-b5d490efbbb4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# file pattern for different parametersa\n",
+    "if list(*PARAMETER.values()) == WET_DAY_THRESHOLD:\n",
+    "    PATTERNS = ['_'.join([PATTERN, '{}mm_{}'.format(str(t).replace('.', ''), LOSS), OPTIM]) for t in WET_DAY_THRESHOLD]\n",
+    "    \n",
+    "if list(*PARAMETER.values()) == LAMBDA:\n",
+    "    p = '_'.join([PATTERN, 'loss'])\n",
+    "    p = '_'.join([p, 'optim'])\n",
+    "    PATTERNS = ['_'.join([p, 'd{:.0e}'.format(decay)]) for decay in LAMBDA]\n",
+    "\n",
+    "# iterate over loss functions and parameter space\n",
+    "STATE_FILES = {l: {o: {p: '_'.join([v.replace('loss', '{}mm_{}'.format(CWDT, l) if l in ['BernoulliGammaLoss', 'BernoulliWeibullLoss'] else l).replace(\n",
+    "    'optim', o if o != 'SGDCLR' else 'SGD'), 'lr{:.0e}'.format(LR[PREDICTAND][l][o if o != 'SGDCLR' else 'SGD'] / 4), 'CyclicLR' if o == 'SGDCLR' else ''])\n",
+    "                   for p, v in zip(*PARAMETER.values(), PATTERNS)} for o in OPTIMS} for l in LOSS}\n",
+    "STATE_FILES"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "58177982-7f86-4c9e-9980-563cf5d48852",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# align datasets and mask missing values\n",
+    "y_pred = {}\n",
+    "y_prob = {}\n",
+    "for l in LOSS:\n",
+    "    y_pred[l], y_prob[l] = {}, {}\n",
+    "    for o in OPTIMS:\n",
+    "        y_pred[l][o], y_prob[l][o] = {}, {}\n",
+    "        for k, v in STATE_FILES[l][o].items():\n",
+    "            ds = xr.open_dataset(TARGET_PATH.joinpath(PREDICTAND, '.'.join([v.rstrip('_'), 'nc'])), chunks={'time': 365})\n",
+    "            if PREDICTAND == 'pr':\n",
+    "                # align datasets and mask missing values\n",
+    "                if l == 'BernoulliGammaLoss':\n",
+    "                    _, y, p = xr.align(y_true[NAMES[PREDICTAND]], ds[NAMES[PREDICTAND]], ds.prob, join='override')\n",
+    "                    y = y.where(~missing, other=np.nan)\n",
+    "                    p = p.where(~missing, other=np.nan)\n",
+    "                else:\n",
+    "                    _, y = xr.align(y_true, ds[PREDICTAND], join='override')\n",
+    "                    y = y.where(~missing, other=np.nan)\n",
+    "                    p = None\n",
+    "                try:           \n",
+    "                    y = y.rename({PREDICTAND: NAMES[PREDICTAND]})\n",
+    "                except ValueError:\n",
+    "                    pass              \n",
+    "                y_pred[l][o][k] = y\n",
+    "                y_prob[l][o][k] = p             \n",
+    "            else:     \n",
+    "                # align datasets and mask missing values\n",
+    "                _, y = xr.align(y_true[PREDICTAND], ds[PREDICTAND], join='override')\n",
+    "                y = y.where(~missing, other=np.nan)\n",
+    "                y_pred[l][o][k] = y\n",
+    "\n",
+    "# rename predictand\n",
+    "if PREDICTAND == 'pr':\n",
+    "    PREDICTAND = NAMES[PREDICTAND]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5a72117c-2e1c-462a-968e-91db18b4aaaf",
+   "metadata": {},
+   "source": [
+    "## Model sensitivity"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0164fd81-aff4-4fa0-b156-5ba0d817ed49",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# whether to overwrite calculated metrics\n",
+    "OVERWRITE = False"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "419006dd-9cfa-4a49-a5d2-9c250054aa50",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# initialize DataFrame filename\n",
+    "filename = '_'.join([PREDICTAND, PPREDICTORS]) if PPREDICTORS else PREDICTAND\n",
+    "filename = '_'.join([filename, SPREDICTORS]) if SPREDICTORS else filename\n",
+    "filename = '_'.join([filename, DEM]) if DEM else filename\n",
+    "filename = '_'.join([filename, DOY]) if DOY else filename\n",
+    "filename = '_'.join([filename, 'sensitivity_{}.csv'.format(*PARAMETER.keys())])\n",
+    "df_name = TARGET_PATH.joinpath(PREDICTAND if PREDICTAND != 'precipitation' else 'pr', filename)\n",
+    "df_name"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9758d58b-82e9-47b2-9cf7-1c93d5ed04a7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# yearly average of observations\n",
+    "y_true_yearly_avg = y_true[PREDICTAND].groupby('time.year').mean(dim='time')\n",
+    "\n",
+    "# daily anomalies\n",
+    "y_true_anom = ERA5Dataset.anomalies(y_true[PREDICTAND], timescale='time.month').values\n",
+    "mask = np.isnan(y_true_anom)\n",
+    "y_true_values = y_true_anom[~mask]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "57ca80c6-2c02-44d1-87eb-35c8a78ec71f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if df_name.exists() and not OVERWRITE:\n",
+    "    df = pd.read_csv(df_name, sep=',')\n",
+    "else:\n",
+    "    # initialize model hyperparameter DataFrame\n",
+    "    df = pd.DataFrame(columns=['R2', 'Bias', 'MAE', 'RMSE', 'AUC', 'ROCSS', 'Loss', 'Optim', *PARAMETER.keys()])\n",
+    "    df\n",
+    "    \n",
+    "    # define string to log\n",
+    "    if PREDICTAND == 'precipitation':\n",
+    "        logstring = '({}), ({}), ({}={:.0e}): r2: {:.2f}, bias: {:.2f}%, MAE: {:.2f}mm, RMSE: {:.2f}mm, AUC: {:.2f}, ROCSS: {:.2f}'\n",
+    "    else:\n",
+    "        logstring = '({}), ({}), ({}={:.0e}): r2: {:.2f}, bias: {:.2f}°C, MAE: {:.2f}°C, RMSE: {:.2f}°C, AUC: {:.2f}, ROCSS: {:.2f}'\n",
+    "\n",
+    "    # calculate metrics for each hyperparameter combination\n",
+    "    for loss, prediction in y_pred.items():\n",
+    "        for optim, ds in prediction.items():\n",
+    "            for param, value in ds.items():        \n",
+    "                with warnings.catch_warnings():\n",
+    "                    warnings.simplefilter('ignore', category=RuntimeWarning)\n",
+    "\n",
+    "                    # coefficient of determination\n",
+    "                    try:\n",
+    "                        # compute predicted daily anomalies\n",
+    "                        y_pred_anom = ERA5Dataset.anomalies(value, timescale='time.month').values\n",
+    "\n",
+    "                        # apply mask of valid pixels\n",
+    "                        y_pred_values = y_pred_anom[~mask]\n",
+    "\n",
+    "                        # coefficient of determination\n",
+    "                        r2 = r2_score(y_true_values, y_pred_values)\n",
+    "                    except ValueError:\n",
+    "                        r2 = np.nan\n",
+    "\n",
+    "                    # calculate metrics on yearly average of daily values\n",
+    "                    y_pred_yearly_avg = value.groupby('time.year').mean(dim='time')\n",
+    "\n",
+    "                    # relative or absolute bias\n",
+    "                    if PREDICTAND == 'precipitation':\n",
+    "                        # calculate relative bias for precipitation\n",
+    "                        bias = ((y_pred_yearly_avg - y_true_yearly_avg) / y_true_yearly_avg).mean().values.item() * 100\n",
+    "                    else:\n",
+    "                        # calculate absolute bias for temperatures\n",
+    "                        bias = (y_pred_yearly_avg - y_true_yearly_avg).mean().values.item()\n",
+    "\n",
+    "                    # mean absolute error and root mean squared error\n",
+    "                    mae = np.abs(y_pred_yearly_avg - y_true_yearly_avg).mean().values.item()\n",
+    "                    rmse = np.sqrt(((y_pred_yearly_avg - y_true_yearly_avg) ** 2).mean()).values.item()\n",
+    "\n",
+    "                    # receiver operating characteristics\n",
+    "                    if loss in ['BernoulliGammaLoss', 'BernoulliWeibullLoss']:\n",
+    "                        p_true = (y_true[PREDICTAND] >= param if list(*PARAMETER.values()) == WET_DAY_THRESHOLD\n",
+    "                                  else y_true[PREDICTAND] >= CWDT).values\n",
+    "                        p_pred = y_prob[loss][optim][param].values\n",
+    "\n",
+    "                        # apply mask of valid pixels\n",
+    "                        p_pred = p_pred[~mask]\n",
+    "                        p_t = p_true[~mask].astype(float)\n",
+    "\n",
+    "                        # calculate ROC: false positive rate vs. true positive rate\n",
+    "                        fpr, tpr, _ = roc_curve(p_t, p_pred)\n",
+    "                        area = auc(fpr, tpr) # area under ROC curve\n",
+    "                        rocss = 2 * area - 1 # ROC skill score (cf. https://journals.ametsoc.org/view/journals/clim/16/24/1520-0442_2003_016_4145_otrsop_2.0.co_2.xml)\n",
+    "                    else:\n",
+    "                        area, rocss = np.nan, np.nan\n",
+    "\n",
+    "                    # log metrics to console\n",
+    "                    print(logstring.format(loss, optim, *PARAMETER.keys(), param, r2, bias, mae, rmse, area, rocss))\n",
+    "\n",
+    "                # add simulation to dataframe\n",
+    "                df = df.append(pd.DataFrame([[r2, bias, mae, rmse, area, rocss, loss, optim, param]], columns=df.columns))\n",
+    "\n",
+    "    # save DataFrame of calculated metrics to csv\n",
+    "    df.to_csv(df_name, header=True, index=False, sep=',')\n",
+    "    df"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "975d77a7-e4b1-46ac-877d-f1a4464e4050",
+   "metadata": {},
+   "source": [
+    "## Visualize results"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "026167b6-1737-495c-b33c-dd70779b351c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# color palette\n",
+    "PALETTE = 'mako_r'\n",
+    "sns.color_palette(PALETTE, n_colors=len(LOSS))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8c66b1cb-318c-48b5-a27c-3e6394455e69",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# metrics to plot\n",
+    "METRICS = ['Bias', 'MAE', 'ROCSS' if PREDICTAND == 'precipitation' else 'RMSE', 'R2']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a6a8d8b1-6c66-488a-b961-86e60593e7c7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# instanciate figure\n",
+    "fig, axes = plt.subplots(4, len(OPTIMS), sharex=True, figsize=(16, 10))\n",
+    "for axis, m in zip(axes, METRICS):\n",
+    "    for ax, optim in zip(axis, OPTIMS):\n",
+    "        # subset to current optimizer\n",
+    "        data = df[df['Optim'] == optim]\n",
+    "        sns.barplot(x=list(PARAMETER.keys()).pop(), y=m, hue='Loss', data=data, ax=ax, palette=PALETTE)\n",
+    "        \n",
+    "        # axis properties\n",
+    "        ax.set_ylabel('')\n",
+    "        ax.set_xlabel('')\n",
+    "        ax.get_legend().remove()\n",
+    "        ax.tick_params(axis='both', which='major', labelsize=14)\n",
+    "        if m == 'Bias':  \n",
+    "            ax.set_ylim((-20, 20) if PREDICTAND == 'precipitation' else (-1, 1))\n",
+    "        else:\n",
+    "            ax.set_ylim(0, 1)\n",
+    "            \n",
+    "    # adjust y-ticklabels\n",
+    "    for ax in axis[1:]:\n",
+    "        ax.set_yticklabels('')\n",
+    "\n",
+    "    # adjust y-label\n",
+    "    if m == 'Bias':\n",
+    "        axis[0].set_ylabel('Relative Bias (%)' if PREDICTAND == 'precipitation' else 'Bias (°C)', fontsize=16)\n",
+    "    elif m in ['MAE', 'RMSE']:\n",
+    "        axis[0].set_ylabel('{} ({})'.format(m, 'mm' if PREDICTAND == 'precipitation' else '°C'), fontsize=16)\n",
+    "    else:\n",
+    "        axis[0].set_ylabel(m, fontsize=16)\n",
+    "\n",
+    "# adjust x-tick labels\n",
+    "for ax in axes[-1, :]:\n",
+    "    ax.set_xticklabels(['$10^{{ {:.0f} }}$'.format(np.log10(v)) if v > 0 else '{:.0f}'.format(v) for v in list(*PARAMETER.values())])\n",
+    "    \n",
+    "# add optimizer to axes\n",
+    "for ax, optim in zip(axes[0, :], OPTIMS):\n",
+    "    ax.set_title('SGD + CLR' if optim == 'SGDCLR' else optim, fontsize=16)\n",
+    "\n",
+    "# add legend\n",
+    "axes[-1, 0].legend(loc='upper left', bbox_to_anchor=(-0.05, -0.15), frameon=False, ncol=len(LOSS), fontsize=16);\n",
+    "\n",
+    "# adjust subplots\n",
+    "fig.subplots_adjust(wspace=0.05, hspace=0.15)\n",
+    "fig.savefig('./Figures/{}'.format(df_name.name.replace('csv', 'png')), bbox_inches='tight', dpi=300)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "18b21e6e-711c-4afd-91de-9d7e4403c6ce",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/Notebooks/eval_temperature.ipynb b/Notebooks/eval_temperature.ipynb
index f315da9240883ed208b9144b88e58c8865caca2e..6422cb5a47f3d8a82346cec838131bb6793e681b 100644
--- a/Notebooks/eval_temperature.ipynb
+++ b/Notebooks/eval_temperature.ipynb
@@ -29,7 +29,7 @@
     "- Specific humidity (q)\n",
     "\n",
     "**Predictors on surface**:\n",
-    "- Mean sea level pressure (msl)\n",
+    "- Surface pressure (p)\n",
     "\n",
     "**Auxiliary predictors**:\n",
     "- Elevation from Copernicus EU-DEM v1.1 (dem)\n",
@@ -47,19 +47,50 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "bb24b6ed-2d0a-44e0-b9a9-abdcb2a8294d",
+   "id": "bd50716f-eeba-431a-97ce-24ea5d625090",
    "metadata": {},
    "outputs": [],
    "source": [
     "# define the model parameters\n",
-    "PREDICTAND = 'tasmax'\n",
+    "PREDICTAND = 'tasmin'\n",
     "MODEL = 'USegNet'\n",
     "PPREDICTORS = 'ztuvq'\n",
+    "# PPREDICTORS = ''\n",
     "PLEVELS = ['500', '850']\n",
+    "# PLEVELS = []\n",
     "SPREDICTORS = 'p'\n",
+    "# SPREDICTORS = 'ppr'\n",
     "DEM = 'dem'\n",
     "DEM_FEATURES = ''\n",
-    "DOY = 'doy'"
+    "DOY = 'doy'\n",
+    "LOSS = 'L1Loss'\n",
+    "# LOSS = 'MSELoss'\n",
+    "OPTIM = 'SGD'\n",
+    "# OPTIM = 'Adam'\n",
+    "SEASON = ''\n",
+    "DECAY = '1e-06'\n",
+    "LR = '1e-03'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d43149c7-ed56-4c91-9ce9-e8b226b48426",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# construct file pattern to match\n",
+    "PATTERN = '_'.join([MODEL, PREDICTAND, PPREDICTORS, *PLEVELS])\n",
+    "PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, LOSS])\n",
+    "PATTERN = '_'.join([PATTERN, OPTIM])\n",
+    "PATTERN = '_'.join([PATTERN, SEASON]) if SEASON else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, 'd{}'.format(DECAY)])\n",
+    "PATTERN = '_'.join([PATTERN, 'lr{}'.format(LR)])\n",
+    "PATTERN"
    ]
   },
   {
@@ -73,6 +104,17 @@
     "NAMES = {'tasmin': 'minimum temperature', 'tasmax': 'maximum temperature', 'pr': 'precipitation'}"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f1ae6427-67e3-4df3-b84a-5aad32504d79",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# whether to search for the defined pattern\n",
+    "SEARCH = False    "
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "4d5d12c5-50fd-4c5c-9240-c3df78e49b44",
@@ -97,13 +139,17 @@
     "import xarray as xr\n",
     "import numpy as np\n",
     "import matplotlib.pyplot as plt\n",
+    "from matplotlib import cm\n",
+    "from matplotlib.colors import ListedColormap\n",
     "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
     "import scipy.stats as stats\n",
     "from IPython.display import Image\n",
     "from sklearn.metrics import r2_score\n",
     "\n",
     "# locals\n",
-    "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH\n",
+    "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH, MODEL_PATH\n",
+    "from climax.core.utils import plot_loss\n",
+    "from climax.core.dataset import ERA5Dataset\n",
     "from pysegcnn.core.utils import search_files"
    ]
   },
@@ -160,16 +206,13 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "e8cb1011-32dd-4e32-9692-4a4aa009869f",
+   "id": "ad84738f-3600-47e4-8178-4ee4ad5e3fc3",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# construct file pattern to match\n",
-    "PATTERN = '_'.join([MODEL, PREDICTAND, PPREDICTORS, *PLEVELS])\n",
-    "PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN\n",
-    "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n",
-    "PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN\n",
-    "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN"
+    "# digital elevation model\n",
+    "dem = xr.open_dataset(search_files(DEM_PATH, 'eu_dem_v11_stt.nc').pop())\n",
+    "dem = dem.Band1.to_dataset().rename({'Band1': 'elevation'})"
    ]
   },
   {
@@ -179,8 +222,23 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# model predictions and observations NetCDF\n",
-    "y_pred = xr.open_dataset(search_files(TARGET_PATH.joinpath(PREDICTAND), '.'.join([PATTERN, 'nc$'])).pop())\n",
+    "# model predictions\n",
+    "if SEARCH:\n",
+    "    y_pred = xr.open_dataset(search_files(TARGET_PATH.joinpath(PREDICTAND), '.'.join([PATTERN, 'nc$'])).pop())\n",
+    "else:\n",
+    "    filename = 'USegNet_tasmin_ztuvq_500_850_p_dem_doy_L1Loss_Adam_d1e-03_lr3e-04.nc'\n",
+    "    file = search_files(TARGET_PATH, filename).pop()\n",
+    "    y_pred = xr.open_dataset(file)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8bb55ddf-1d19-4f1a-8b31-9da778b287fd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# observations\n",
     "if PREDICTAND == 'tas':\n",
     "    # read both tasmax and tasmin\n",
     "    tasmax = xr.open_dataset(search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())\n",
@@ -251,76 +309,130 @@
    "outputs": [],
    "source": [
     "# align datasets and mask missing values in model predictions\n",
-    "y_true, y_pred, y_refe = xr.align(y_true, y_pred, y_refe, join='override')\n",
+    "y_true, y_refe, y_pred = xr.align(y_true[PREDICTAND], y_refe[PREDICTAND], y_pred[PREDICTAND], join='override')\n",
     "y_pred = y_pred.where(~np.isnan(y_true), other=np.nan)\n",
     "y_refe = y_refe.where(~np.isnan(y_true), other=np.nan)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dee1b549-c5a6-44cd-9fe2-563ccfea658d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# align digital elevation model\n",
+    "_, dem = xr.align(y_true.isel(time=0), dem, join='override')\n",
+    "dem = dem.where(~np.isnan(y_true.isel(time=0)), other=np.nan)"
+   ]
+  },
   {
    "cell_type": "markdown",
-   "id": "ddebdf9f-862c-461e-aa57-cd344d54eee9",
+   "id": "9a696a5c-1de8-4027-b3f8-64045b7333fb",
    "metadata": {},
    "source": [
-    "## Model validation: temperature"
+    "### Model convergence"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "98349528-6a4d-469f-86f4-60ce4d18e719",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# plot model state\n",
+    "model_state = MODEL_PATH.joinpath(PREDICTAND, '.'.join([PATTERN, 'pt']))\n",
+    "try:\n",
+    "    fig = plot_loss(model_state, step=5)\n",
+    "except FileNotFoundError:\n",
+    "    pass"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "ab15d557-c7ea-40c0-9977-a3d410fea784",
+   "id": "ddebdf9f-862c-461e-aa57-cd344d54eee9",
    "metadata": {},
    "source": [
-    "### Coefficient of determination"
+    "## Model validation: temperature"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "bb8223e0-c5cb-477c-8c31-3dfb94b20ce1",
+   "id": "dda74f76-2b3e-484c-b56d-d1d3d23784c2",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# group timeseries by month and calculate mean over time and space\n",
-    "y_pred_ac = y_pred.groupby('time.month').mean(dim=('time', 'y', 'x'))\n",
-    "y_true_ac = y_true.groupby('time.month').mean(dim=('time', 'y', 'x'))"
+    "# calculate monthly means\n",
+    "y_pred_mm = y_pred.groupby('time.month').mean(dim=('time'))\n",
+    "y_true_mm = y_true.groupby('time.month').mean(dim=('time'))"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "619d5dc9-4d36-43a3-b23c-a4ea51229c78",
+   "id": "0993676b-19de-4426-9c8c-c3897b4f9bf7",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# get predicted and observed values over entire time series and grid points\n",
-    "y_pred_values = y_pred[PREDICTAND].groupby('time.month').mean(dim='time').values.flatten()\n",
-    "y_true_values = y_true[PREDICTAND].groupby('time.month').mean(dim='time').values.flatten()"
+    "# calculate mean annual cycle\n",
+    "y_pred_ac = y_pred_mm.mean(dim=('x', 'y'))\n",
+    "y_true_ac = y_true_mm.mean(dim=('x', 'y'))"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "49dff6ce-a629-460b-a43b-d1a0ef447351",
-   "metadata": {
-    "tags": []
-   },
+   "id": "c3e5f725-9d6a-4837-88a6-7ad1baa06027",
+   "metadata": {},
    "outputs": [],
    "source": [
-    "# apply mask of valid pixels\n",
-    "mask = (~np.isnan(y_pred_values) & ~np.isnan(y_true_values))\n",
-    "y_pred_values = y_pred_values[mask]\n",
-    "y_true_values = y_true_values[mask]"
+    "# compute daily anomalies\n",
+    "y_pred_anom = ERA5Dataset.anomalies(y_pred, timescale='time.month')\n",
+    "y_true_anom = ERA5Dataset.anomalies(y_true, timescale='time.month')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ab15d557-c7ea-40c0-9977-a3d410fea784",
+   "metadata": {},
+   "source": [
+    "### Coefficient of determination"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "e9e46770-4176-4257-8ea2-7050d3325e98",
+   "id": "619d5dc9-4d36-43a3-b23c-a4ea51229c78",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# calculate coefficient of determination\n",
-    "r2 = r2_score(y_true_values, y_pred_values)\n",
-    "r2"
+    "# get predicted and observed daily anomalies\n",
+    "y_pred_av = y_pred_anom.values.flatten()\n",
+    "y_true_av = y_true_anom.values.flatten()\n",
+    "\n",
+    "# apply mask of valid pixels\n",
+    "mask = (~np.isnan(y_pred_av) & ~np.isnan(y_true_av))\n",
+    "y_pred_av = y_pred_av[mask]\n",
+    "y_true_av = y_true_av[mask]\n",
+    "\n",
+    "# get predicted and observed monthly means\n",
+    "y_pred_mv = y_pred_mm.values.flatten()\n",
+    "y_true_mv = y_true_mm.values.flatten()\n",
+    "\n",
+    "# apply mask of valid pixels\n",
+    "mask = (~np.isnan(y_pred_mv) & ~np.isnan(y_true_mv))\n",
+    "y_pred_mv = y_pred_mv[mask]\n",
+    "y_true_mv = y_true_mv[mask]\n",
+    "\n",
+    "# calculate coefficient of determination on monthly means\n",
+    "r2_mm = r2_score(y_true_mv, y_pred_mv)\n",
+    "print('R2 on monthly means: {:.2f}'.format(r2_mm))\n",
+    "\n",
+    "# calculate coefficient of determination on daily anomalies\n",
+    "r2_anom = r2_score(y_true_av, y_pred_av)\n",
+    "print('R2 on daily anomalies: {:.2f}'.format(r2_anom))"
    ]
   },
   {
@@ -338,15 +450,18 @@
     "# ax.plot(y_true_values[subset], y_pred_values[subset], 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n",
     "\n",
     "# plot entire dataset\n",
-    "ax.plot(y_true_values, y_pred_values, 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n",
+    "ax.plot(y_true_mv, y_pred_mv, 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n",
     "\n",
     "# plot 1:1 mapping line\n",
-    "interval = np.arange(-15, 45, 5)\n",
-    "#interval = np.arange(-30, 35, 5)\n",
+    "if PREDICTAND == 'tasmin':\n",
+    "    interval = np.arange(-25, 30, 5)\n",
+    "else:\n",
+    "    interval = np.arange(-15, 45, 5)\n",
     "ax.plot(interval, interval, color='k', lw=2, ls='--')\n",
     "\n",
-    "# add coefficient of determination: calculated on entire dataset!\n",
-    "ax.text(interval[-1] - 0.5, interval[0] + 0.5, s='Coefficient of determination R$^2$ = {:.2f}'.format(r2), ha='right', fontsize=18)\n",
+    "# add coefficients of determination\n",
+    "ax.text(interval[-1] - 0.5, interval[0] + 0.5, s='R$^2$ (monthly means)= {:.2f}'.format(r2_mm), ha='right', fontsize=18)\n",
+    "ax.text(interval[-1] - 0.5, interval[0] + 2.5, s='R$^2$ (daily anomalies) = {:.2f}'.format(r2_anom), ha='right', fontsize=18)\n",
     "\n",
     "# format axes\n",
     "ax.set_ylim(interval[0], interval[-1])\n",
@@ -361,21 +476,102 @@
     "\n",
     "# add axis for annual cycle\n",
     "axins = inset_axes(ax, width=\"30%\", height=\"40%\", loc=2, borderpad=1)\n",
-    "axins.plot(y_pred_ac[PREDICTAND].values, ls='--', color='k', label='Predicted')\n",
-    "axins.plot(y_true_ac[PREDICTAND].values, ls='-', color='k', label='Observed')\n",
+    "axins.plot(y_pred_ac.values, ls='--', color='k', label='Predicted')\n",
+    "axins.plot(y_true_ac.values, ls='-', color='k', label='Observed')\n",
     "axins.legend(frameon=False, fontsize=12, loc='lower center');\n",
     "axins.yaxis.tick_right()\n",
-    "#axins.set_yticks(np.arange(-10, 11, 2))\n",
-    "#axins.set_yticklabels(np.arange(-10, 11, 2), fontsize=12)\n",
-    "axins.set_yticks(np.arange(-0, 22, 2))\n",
-    "axins.set_yticklabels(np.arange(0, 22, 2), fontsize=12)\n",
+    "axins.set_yticks(np.arange(-10, 11, 2) if PREDICTAND == 'tasmin' else np.arange(0, 20, 2))\n",
+    "axins.set_yticklabels(np.arange(-10, 11, 2) if PREDICTAND == 'tasmin' else np.arange(0, 20, 2), fontsize=12)\n",
     "axins.set_xticks(np.arange(0, 12))\n",
     "axins.set_xticklabels([calendar.month_name[i + 1] for i in np.arange(0, 12)], rotation=90, fontsize=12)\n",
-    "# axins.set_ylabel('{} / (°C)'.format(NAMES[PREDICTAND].capitalize()), fontsize=14)\n",
-    "# axins.set_xlabel('Month', fontsize=14);\n",
+    "axins.set_title('Mean annual cycle', fontsize=14, pad=5)\n",
+    "\n",
+    "# save figure\n",
+    "fig.savefig('../Notebooks/Figures/{}_r2_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4e50ff44-5598-4694-9ea3-b06ee50ee21f",
+   "metadata": {},
+   "source": [
+    "### Coefficient of determination: Spatially"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "816fad18-b365-4c62-b669-e13e7dc8d322",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# iterate over the grid points\n",
+    "r2 = np.ones((2, len(y_pred.x), len(y_pred.y)), dtype=np.float32) * np.nan\n",
+    "for i, _ in enumerate(y_pred.x):\n",
+    "    for j, _ in enumerate(y_pred.y):\n",
+    "        # get observed and predicted monthly mean temperature for current grid point\n",
+    "        point_true = y_true_mm.isel(x=i, y=j)\n",
+    "        point_pred = y_pred_mm.isel(x=i, y=j)\n",
+    "\n",
+    "        # remove missing values\n",
+    "        mask = ((~np.isnan(point_true)) & (~np.isnan(point_pred)))\n",
+    "        point_true = point_true[mask].values\n",
+    "        point_pred = point_pred[mask].values\n",
+    "        if point_true.size < 1:\n",
+    "            continue\n",
+    "        \n",
+    "        # get anomalies for current grid point\n",
+    "        point_anom_true = y_true_anom.isel(x=i, y=j)\n",
+    "        point_anom_pred = y_pred_anom.isel(x=i, y=j)\n",
+    "        \n",
+    "        # remove missing values\n",
+    "        mask_anom = ((~np.isnan(point_anom_true)) & (~np.isnan(point_anom_pred)))\n",
+    "        point_anom_true = point_anom_true[mask_anom].values\n",
+    "        point_anom_pred = point_anom_pred[mask_anom].values\n",
+    "\n",
+    "        # compute coefficient of determination\n",
+    "        r2[0, j, i] = r2_score(point_true, point_pred)\n",
+    "        r2[1, j, i] = r2_score(point_anom_true, point_anom_pred)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b0f0e83c-3dfd-46dd-bd02-7ff1ba8f8573",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# define color map: red to green\n",
+    "grn = cm.get_cmap('Greens', 128)\n",
+    "red = cm.get_cmap('Reds_r', 128)\n",
+    "red2green = ListedColormap(np.vstack((red(np.linspace(0, 1, 128)),\n",
+    "                                      grn(np.linspace(0, 1, 128)))))\n",
+    "\n",
+    "# plot coefficient of determination\n",
+    "vmin, vmax = -1, 1\n",
+    "fig, ax = plt.subplots(1, 2, figsize=(10, 10))\n",
+    "\n",
+    "# monthly means\n",
+    "im0 = ax[0].imshow(r2[0, :], origin='lower', cmap=red2green, vmin=vmin, vmax=vmax)\n",
+    "ax[0].text(x=r2.shape[1] - 2, y=2, s='Average R$^2$: {:.2f}'.format(np.nanmean(r2[0, :])), fontsize=14, ha='right');\n",
+    "ax[0].set_axis_off()\n",
+    "ax[0].set_title('Monthly mean {} (°C)'.format(NAMES[PREDICTAND]), fontsize=14, pad=10);\n",
+    "\n",
+    "# daily anomalies\n",
+    "im1 = ax[1].imshow(r2[1, :], origin='lower', cmap=red2green, vmin=vmin, vmax=vmax)\n",
+    "ax[1].text(x=r2.shape[1] - 2, y=2, s='Average R$^2$: {:.2f}'.format(np.nanmean(r2[1, :])), fontsize=14, ha='right');\n",
+    "ax[1].set_axis_off()\n",
+    "ax[1].set_title('Daily {} anomaly (°C)'.format(NAMES[PREDICTAND]), fontsize=14, pad=10);\n",
+    "\n",
+    "# add colorbar \n",
+    "cbar_ax_bias = fig.add_axes([ax[1].get_position().x1 + 0.05, ax[1].get_position().y0,\n",
+    "                             0.03, ax[1].get_position().y1 - ax[1].get_position().y0])\n",
+    "cbar_bias = fig.colorbar(im0, cax=cbar_ax_bias)\n",
+    "cbar_bias.set_label(label='Coefficient of determination R$^2$', fontsize=14)\n",
+    "cbar_bias.ax.tick_params(labelsize=14, pad=10)\n",
     "\n",
     "# save figure\n",
-    "fig.savefig('../Notebooks/Figures/{}_r2.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')"
+    "fig.savefig('../Notebooks/Figures/{}_r2_spatial_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')"
    ]
   },
   {
@@ -407,9 +603,8 @@
     "y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time')\n",
     "bias_yearly_avg = y_pred_yearly_avg - y_true_yearly_avg\n",
     "bias_yearly_avg_ref = y_refe_yearly_avg - y_true_yearly_avg\n",
-    "for var in bias_yearly_avg:\n",
-    "    print('(Model) Yearly average bias of {}: {:.2f}°C'.format(var, bias_yearly_avg[var].mean().item()))\n",
-    "    print('(ERA-5) Yearly average bias of {}: {:.2f}°C'.format(var, bias_yearly_avg_ref.mean().to_array().values.item()))"
+    "print('(Model) Yearly average bias of {}: {:.2f}°C'.format(PREDICTAND, bias_yearly_avg.mean().item()))\n",
+    "print('(ERA-5) Yearly average bias of {}: {:.2f}°C'.format(PREDICTAND, bias_yearly_avg_ref.mean().item()))"
    ]
   },
   {
@@ -422,9 +617,8 @@
     "# mean absolute error over reference period\n",
     "mae_avg = np.abs(y_pred_yearly_avg - y_true_yearly_avg).mean()\n",
     "mae_avg_ref = np.abs(y_refe_yearly_avg - y_true_yearly_avg).mean()\n",
-    "for var in mae_avg:\n",
-    "    print('(Model) Yearly average MAE of {}: {:.2f}°C'.format(var, mae_avg[var].item()))\n",
-    "    print('(ERA-5) Yearly average MAE of {}: {:.2f}°C'.format(var, mae_avg_ref.mean().to_array().values.item()))"
+    "print('(Model) Yearly average MAE of {}: {:.2f}°C'.format(PREDICTAND, mae_avg.mean().item()))\n",
+    "print('(ERA-5) Yearly average MAE of {}: {:.2f}°C'.format(PREDICTAND, mae_avg_ref.mean().item()))"
    ]
   },
   {
@@ -435,11 +629,10 @@
    "outputs": [],
    "source": [
     "# root mean squared error over reference period\n",
-    "rmse_avg = ((y_pred_yearly_avg - y_true_yearly_avg) ** 2).mean()\n",
-    "rmse_avg_ref = ((y_refe_yearly_avg - y_true_yearly_avg) **2).mean()\n",
-    "for var in rmse_avg:\n",
-    "    print('(Model) Yearly average RMSE of {}: {:.2f}°C'.format(var, rmse_avg[var].item()))\n",
-    "    print('(ERA-5) Yearly average RMSE of {}: {:.2f}°C'.format(var, rmse_avg_ref.mean().to_array().values.item()))"
+    "rmse_avg = np.sqrt(((y_pred_yearly_avg - y_true_yearly_avg) ** 2)).mean()\n",
+    "rmse_avg_ref = np.sqrt(((y_refe_yearly_avg - y_true_yearly_avg) **2)).mean()\n",
+    "print('(Model) Yearly average RMSE of {}: {:.2f}°C'.format(PREDICTAND, rmse_avg.mean().item()))\n",
+    "print('(ERA-5) Yearly average RMSE of {}: {:.2f}°C'.format(PREDICTAND, rmse_avg_ref.mean().item()))"
    ]
   },
   {
@@ -450,14 +643,82 @@
    "outputs": [],
    "source": [
     "# Pearson's correlation coefficient over reference period\n",
-    "for var in y_pred_yearly_avg:\n",
-    "    correlations = []\n",
-    "    for year in y_pred_yearly_avg.year:\n",
-    "        y_p = y_pred_yearly_avg[var].sel(year=year).values        \n",
-    "        y_t = y_true_yearly_avg[var].sel(year=year).values\n",
-    "        r, _ = stats.pearsonr(y_p[~np.isnan(y_p)], y_t[~np.isnan(y_t)])\n",
-    "        correlations.append(r)\n",
-    "print('Yearly average Pearson correlation coefficient for {}: {:.2f}'.format(var, np.asarray(r).mean()))"
+    "correlations = []\n",
+    "for year in y_pred_yearly_avg.year:\n",
+    "    y_p = y_pred_yearly_avg.sel(year=year).values        \n",
+    "    y_t = y_true_yearly_avg.sel(year=year).values\n",
+    "    r, _ = stats.pearsonr(y_p[~np.isnan(y_p)], y_t[~np.isnan(y_t)])\n",
+    "    print('({:d}): {:.2f}'.format(year.item(), r))\n",
+    "    correlations.append(r)\n",
+    "print('Yearly average Pearson correlation coefficient for {}: {:.2f}'.format(PREDICTAND, np.asarray(r).mean()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4ec9b998-bc6c-4ef4-a602-ed87ae3260ae",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# plot yearly average bias of reference vs. prediction\n",
+    "vmin, vmax = -2, 2\n",
+    "fig, axes = plt.subplots(1, 2, figsize=(16, 8), sharex=True, sharey=True)\n",
+    "\n",
+    "# plot bias of ERA-5 reference\n",
+    "reference = bias_yearly_avg_ref.mean(dim='year')\n",
+    "im1 = axes[0].imshow(reference.values, origin='lower', cmap='RdBu_r', vmin=vmin, vmax=vmax)\n",
+    "axes[0].text(x=reference.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(reference.mean().item()), fontsize=14, ha='right')\n",
+    "\n",
+    "# plot MAE of model\n",
+    "prediction = bias_yearly_avg.mean(dim='year')\n",
+    "im2 = axes[1].imshow(prediction.values, origin='lower', cmap='RdBu_r', vmin=vmin, vmax=vmax)\n",
+    "axes[1].text(x=reference.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(prediction.mean().item()), fontsize=14, ha='right')\n",
+    "\n",
+    "# plot topography\n",
+    "# im_dem = axes[2].imshow(dem['elevation'].values, origin='lower', cmap='terrain', vmin=0, vmax=4000)\n",
+    "\n",
+    "# set titles\n",
+    "axes[0].set_title('ERA-5', fontsize=14, pad=10);\n",
+    "axes[1].set_title('DCEDN', fontsize=14, pad=10);\n",
+    "# axes[2].set_title('Copernicus EU-DEM v1.1', fontsize=14, pad=10)\n",
+    "\n",
+    "# adjust axes\n",
+    "for ax in axes.flat:\n",
+    "    ax.axes.get_xaxis().set_ticklabels([])\n",
+    "    ax.axes.get_xaxis().set_ticks([])\n",
+    "    ax.axes.get_yaxis().set_ticklabels([])\n",
+    "    ax.axes.get_yaxis().set_ticks([])\n",
+    "    ax.axes.axis('tight')\n",
+    "    ax.set_xlabel('')\n",
+    "    ax.set_ylabel('')\n",
+    "    ax.set_axis_off()\n",
+    "\n",
+    "# adjust figure\n",
+    "# fig.suptitle('Average yearly mean absolute error: 1991 - 2010', fontsize=20);\n",
+    "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)\n",
+    "\n",
+    "# add colorbar for dem\n",
+    "axes = axes.flatten()\n",
+    "# cbar_ax_bias = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n",
+    "#                              0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])\n",
+    "# cbar_bias = fig.colorbar(im_dem, cax=cbar_ax_bias)\n",
+    "# cbar_bias.set_label(label='Elevation (m)', fontsize=14)\n",
+    "# cbar_bias.ax.tick_params(labelsize=14, pad=10)\n",
+    "\n",
+    "# add colorbar for predictand\n",
+    "cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1,\n",
+    "                                   axes[-1].get_position().x0 - axes[0].get_position().x0,\n",
+    "                                   0.03])\n",
+    "cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal')\n",
+    "cbar_predictand.set_label(label='Mean error (°C)', fontsize=14)\n",
+    "cbar_predictand.ax.tick_params(labelsize=14, pad=10)\n",
+    "\n",
+    "# add metrics: MAE and RMSE\n",
+    "#axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.1f}'.format(mae_avg[NAMES[PREDICTAND]].item()) + 'mm day$^{-1}$', fontsize=14, ha='right')\n",
+    "#axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.1f}'.format(rmse_avg[NAMES[PREDICTAND]].item()) + 'mm$^2$ day$^{-2}$', fontsize=14, ha='right')\n",
+    "\n",
+    "# save figure\n",
+    "fig.savefig('../Notebooks/Figures/{}_rbias_ERA_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')"
    ]
   },
   {
@@ -469,22 +730,20 @@
    "source": [
     "# plot average of observation, prediction, and bias\n",
     "vmin, vmax = (-15, 15) if PREDICTAND == 'tasmin' else (0, 25)\n",
-    "fig, axes = plt.subplots(len(y_pred_yearly_avg.data_vars), 3, figsize=(24, len(y_pred_yearly_avg.data_vars) * 6),\n",
-    "                         sharex=True, sharey=True)\n",
-    "axes = axes.reshape(len(y_pred_yearly_avg.data_vars), -1)\n",
-    "for i, var in enumerate(y_pred_yearly_avg):\n",
-    "    for ds, ax in zip([y_true_yearly_avg, y_pred_yearly_avg, bias_yearly_avg], axes[i, ...]):\n",
-    "        if ds is bias_yearly_avg:\n",
-    "            ds = ds[var].mean(dim='year')\n",
-    "            im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n",
-    "            ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n",
-    "        else:\n",
-    "            im1 = ax.imshow(ds[var].mean(dim='year').values, origin='lower', cmap='RdYlBu_r', vmin=vmin, vmax=vmax)\n",
+    "fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)\n",
+    "axes = axes.flatten()\n",
+    "for ds, ax in zip([y_true_yearly_avg, y_pred_yearly_avg, bias_yearly_avg], axes):\n",
+    "    if ds is bias_yearly_avg:\n",
+    "        ds = ds.mean(dim='year')\n",
+    "        im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n",
+    "        ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n",
+    "    else:\n",
+    "        im1 = ax.imshow(ds.mean(dim='year').values, origin='lower', cmap='RdYlBu_r', vmin=vmin, vmax=vmax)\n",
     "        \n",
     "# set titles\n",
-    "axes[0, 0].set_title('Observed', fontsize=16, pad=10);\n",
-    "axes[0, 1].set_title('Predicted', fontsize=16, pad=10);\n",
-    "axes[0, 2].set_title('Bias', fontsize=16, pad=10);\n",
+    "axes[0].set_title('Observed', fontsize=16, pad=10);\n",
+    "axes[1].set_title('Predicted', fontsize=16, pad=10);\n",
+    "axes[2].set_title('Bias', fontsize=16, pad=10);\n",
     "\n",
     "# adjust axes\n",
     "for ax in axes.flat:\n",
@@ -498,7 +757,7 @@
     "\n",
     "# adjust figure\n",
     "fig.suptitle('Average {}: 1991 - 2010'.format(NAMES[PREDICTAND]), fontsize=20);\n",
-    "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)c\n",
+    "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)\n",
     "\n",
     "# add colorbar for bias\n",
     "axes = axes.flatten()\n",
@@ -517,8 +776,8 @@
     "cbar_predictand.ax.tick_params(labelsize=14)\n",
     "\n",
     "# add metrics: MAE and RMSE\n",
-    "axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.2f}°C'.format(mae_avg[PREDICTAND].item()), fontsize=14, ha='right')\n",
-    "axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.2f}°C$^2$'.format(rmse_avg[PREDICTAND].item()), fontsize=14, ha='right')\n",
+    "axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.2f}°C'.format(mae_avg.mean().item()), fontsize=14, ha='right')\n",
+    "axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.2f}°C'.format(rmse_avg.mean().item()), fontsize=14, ha='right')\n",
     "\n",
     "# save figure\n",
     "fig.savefig('../Notebooks/Figures/{}_average_bias.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')"
@@ -563,9 +822,8 @@
    "outputs": [],
    "source": [
     "# print average bias per season: ERA-5\n",
-    "for var in bias_snl_ref.data_vars:\n",
-    "    for season in bias_snl_ref[PREDICTAND].season:\n",
-    "        print('(ERA-5) Average bias of mean {} for season {}: {:.1f}°C'.format(var, season.values.item(), bias_snl_ref[var].sel(season=season).mean().item()))"
+    "for season in bias_snl_ref.season:\n",
+    "    print('(ERA-5) Average bias of mean {} for season {}: {:.1f}°C'.format(PREDICTAND, season.item(), bias_snl_ref.sel(season=season).mean().item()))"
    ]
   },
   {
@@ -576,9 +834,8 @@
    "outputs": [],
    "source": [
     "# print average bias per season: model\n",
-    "for var in bias_snl.data_vars:\n",
-    "    for season in bias_snl[PREDICTAND].season:\n",
-    "        print('(Model) Average bias of mean {} for season {}: {:.1f}°C'.format(var, season.values.item(), bias_snl[var].sel(season=season).mean().item()))"
+    "for season in bias_snl.season:\n",
+    "    print('(Model) Average bias of mean {} for season {}: {:.1f}°C'.format(PREDICTAND, season.item(), bias_snl.sel(season=season).mean().item()))"
    ]
   },
   {
@@ -597,20 +854,21 @@
    "outputs": [],
    "source": [
     "# plot seasonal differences\n",
-    "fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(24, 12), sharex=True, sharey=True)\n",
+    "seasons = ('DJF', 'JJA')\n",
+    "fig, axes = plt.subplots(nrows=1, ncols=len(seasons) + 1, figsize=(24, 8), sharex=True, sharey=True)\n",
     "axes = axes.flatten()\n",
     "\n",
     "# plot annual average bias\n",
-    "ds = bias_yearly_avg[PREDICTAND].mean(dim='year')\n",
+    "ds = bias_yearly_avg.mean(dim='year')\n",
     "im = axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n",
     "axes[0].set_title('Annual', fontsize=16);\n",
     "axes[0].text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n",
     "\n",
     "# plot seasonal average bias\n",
-    "for ax, season in zip(axes[1:], bias_snl.season):\n",
-    "    ds = bias_snl[PREDICTAND].sel(season=season)\n",
+    "for ax, season in zip(axes[1:], seasons):\n",
+    "    ds = bias_snl.sel(season=season)\n",
     "    ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n",
-    "    ax.set_title(season.item(), fontsize=16);\n",
+    "    ax.set_title(season, fontsize=16);\n",
     "    ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n",
     "\n",
     "# adjust axes\n",
@@ -622,13 +880,10 @@
     "    ax.axes.axis('tight')\n",
     "    ax.set_xlabel('')\n",
     "    ax.set_ylabel('')\n",
-    "    \n",
-    "# turn off last axis\n",
-    "axes[-1].set_visible(False)\n",
     "\n",
     "# adjust figure\n",
     "fig.suptitle('Average bias of {}: 1991 - 2010'.format(NAMES[PREDICTAND]), fontsize=20);\n",
-    "fig.subplots_adjust(hspace=0.1, wspace=0, top=0.925)\n",
+    "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)\n",
     "\n",
     "# add colorbar\n",
     "cbar_ax_predictand = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n",
@@ -638,51 +893,7 @@
     "cbar_predictand.ax.tick_params(labelsize=14)\n",
     "\n",
     "# save figure\n",
-    "fig.savefig('../Notebooks/Figures/{}_average_bias_seasonal.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "41600269-2f8c-4717-8f74-b3dfaef60359",
-   "metadata": {},
-   "source": [
-    "Calculate the mean annual cycle:"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "c9f27c01-4dfc-4d16-8d29-00e69b7794cd",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# group timeseries by month and calculate mean over time and space\n",
-    "y_pred_ac = y_pred.groupby('time.month').mean(dim=('time', 'y', 'x'))\n",
-    "y_true_ac = y_true.groupby('time.month').mean(dim=('time', 'y', 'x'))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "bcc63e73-2636-49c1-a66b-241eb5407e2e",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# plot mean annual cycle\n",
-    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
-    "ax.plot(y_pred_ac[PREDICTAND].values, ls='--', color='k', label='Predicted')\n",
-    "ax.plot(y_true_ac[PREDICTAND].values, ls='-', color='k', label='Observed')\n",
-    "ax.legend(frameon=False, fontsize=14);\n",
-    "ax.set_yticks(np.arange(np.floor(y_true_ac[PREDICTAND].min().item()), np.ceil(y_true_ac[PREDICTAND].max().item()) + 1, 1))\n",
-    "ax.set_yticklabels(np.arange(np.floor(y_true_ac[PREDICTAND].min().item()), np.ceil(y_true_ac[PREDICTAND].max().item()) + 1, 1), fontsize=12)\n",
-    "ax.set_xticks(np.arange(0, 12))\n",
-    "ax.set_xticklabels([calendar.month_name[i + 1] for i in np.arange(0, 12)], rotation=90, fontsize=12)\n",
-    "ax.set_title('Mean annual cycle of {}: 1991 - 2010'.format(NAMES[PREDICTAND]), pad=20, fontsize=16);\n",
-    "ax.set_ylabel('{} / (°C)'.format(NAMES[PREDICTAND].capitalize()), fontsize=14)\n",
-    "ax.set_xlabel('Month', fontsize=14);\n",
-    "\n",
-    "# save figure\n",
-    "fig.savefig('../Notebooks/Figures/{}_mean_annual_cycle.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')"
+    "fig.savefig('../Notebooks/Figures/{}_bias_seasonal_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')"
    ]
   },
   {
@@ -739,8 +950,7 @@
    "outputs": [],
    "source": [
     "# bias of extreme quantile: ERA-5\n",
-    "for var in bias_ex_ref:\n",
-    "    print('(ERA-5) Yearly average bias for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, var, bias_ex_ref[var].mean().item()))"
+    "print('(ERA-5) Yearly average bias for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, PREDICTAND, bias_ex_ref.mean().item()))"
    ]
   },
   {
@@ -751,8 +961,7 @@
    "outputs": [],
    "source": [
     "# bias of extreme quantile: Model\n",
-    "for var in bias_ex:\n",
-    "    print('(Model) Yearly average bias for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, var, bias_ex[var].mean().item()))"
+    "print('(Model) Yearly average bias for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, PREDICTAND, bias_ex.mean().item()))"
    ]
   },
   {
@@ -775,8 +984,7 @@
    "outputs": [],
    "source": [
     "# mae of extreme quantile: ERA-5\n",
-    "for var in mae_ex_ref:\n",
-    "    print('(ERA-5) Yearly average MAE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, var, mae_ex_ref[var].item()))"
+    "print('(ERA-5) Yearly average MAE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, PREDICTAND, mae_ex_ref.item()))"
    ]
   },
   {
@@ -787,8 +995,7 @@
    "outputs": [],
    "source": [
     "# mae of extreme quantile: Model\n",
-    "for var in mae_ex:\n",
-    "    print('(Model) Yearly average MAE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, var, mae_ex[var].item()))"
+    "print('(Model) Yearly average MAE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, PREDICTAND, mae_ex.item()))"
    ]
   },
   {
@@ -799,8 +1006,8 @@
    "outputs": [],
    "source": [
     "# root mean squared error in extreme quantile\n",
-    "rmse_ex = ((y_pred_ex - y_true_ex) ** 2).mean()\n",
-    "rmse_ex_ref = ((y_refe_ex - y_true_ex) ** 2).mean()"
+    "rmse_ex = np.sqrt(((y_pred_ex - y_true_ex) ** 2).mean())\n",
+    "rmse_ex_ref = np.sqrt(((y_refe_ex - y_true_ex) ** 2).mean())"
    ]
   },
   {
@@ -811,8 +1018,7 @@
    "outputs": [],
    "source": [
     "# rmse of extreme quantile: ERA-5\n",
-    "for var in rmse_ex_ref:\n",
-    "    print('(ERA-5) Yearly average RMSE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, var, rmse_ex_ref[var].item()))"
+    "print('(ERA-5) Yearly average RMSE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, PREDICTAND, rmse_ex_ref.item()))"
    ]
   },
   {
@@ -823,8 +1029,7 @@
    "outputs": [],
    "source": [
     "# rmse of extreme quantile: Model\n",
-    "for var in rmse_ex:\n",
-    "    print('(Model) Yearly average RMSE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, var, rmse_ex[var].item()))"
+    "print('(Model) Yearly average RMSE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, PREDICTAND, rmse_ex.item()))"
    ]
   },
   {
@@ -838,23 +1043,21 @@
    "source": [
     "# plot extremes of observation, prediction, and bias\n",
     "vmin, vmax = (-20, 0) if PREDICTAND == 'tasmin' else (10, 40)\n",
-    "fig, axes = plt.subplots(len(y_pred_ex.data_vars), 3, figsize=(24, len(y_pred_ex.data_vars) * 6),\n",
-    "                         sharex=True, sharey=True)\n",
-    "axes = axes.reshape(len(y_pred_ex.data_vars), -1)\n",
-    "for i, var in enumerate(y_pred_ex):\n",
-    "    for ds, ax in zip([y_true_ex, y_pred_ex, bias_ex], axes[i, ...]):\n",
-    "        if ds is bias_ex:\n",
-    "            ds = ds[var].mean(dim='year')\n",
-    "            im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n",
-    "            ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n",
-    "        else:\n",
-    "            im1 = ax.imshow(ds[var].mean(dim='year').values, origin='lower', cmap='Blues_r' if PREDICTAND == 'tasmin' else 'Reds',\n",
-    "                            vmin=vmin, vmax=vmax)\n",
-    "        \n",
+    "fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)\n",
+    "axes = axes.flatten()\n",
+    "for ds, ax in zip([y_true_ex, y_pred_ex, bias_ex], axes):\n",
+    "    if ds is bias_ex:\n",
+    "        ds = ds.mean(dim='year')\n",
+    "        im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n",
+    "        ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n",
+    "    else:\n",
+    "        im1 = ax.imshow(ds.mean(dim='year').values, origin='lower', cmap='Blues_r' if PREDICTAND == 'tasmin' else 'Reds',\n",
+    "                        vmin=vmin, vmax=vmax)\n",
+    "\n",
     "# set titles\n",
-    "axes[0, 0].set_title('Observed', fontsize=16, pad=10);\n",
-    "axes[0, 1].set_title('Predicted', fontsize=16, pad=10);\n",
-    "axes[0, 2].set_title('Bias', fontsize=16, pad=10);\n",
+    "axes[0].set_title('Observed', fontsize=16, pad=10);\n",
+    "axes[1].set_title('Predicted', fontsize=16, pad=10);\n",
+    "axes[2].set_title('Bias', fontsize=16, pad=10);\n",
     "\n",
     "# adjust axes\n",
     "for ax in axes.flat:\n",
@@ -887,11 +1090,11 @@
     "cbar_predictand.ax.tick_params(labelsize=14)\n",
     "\n",
     "# add metrics: MAE and RMSE\n",
-    "axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.2f}°C'.format(mae_ex[PREDICTAND].item()), fontsize=14, ha='right')\n",
-    "axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.2f}°C$^2$'.format(rmse_ex[PREDICTAND].item()), fontsize=14, ha='right')\n",
+    "axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.2f}°C'.format(mae_ex.item()), fontsize=14, ha='right')\n",
+    "axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.2f}°C$^2$'.format(rmse_ex.item()), fontsize=14, ha='right')\n",
     "\n",
     "# save figure\n",
-    "fig.savefig('../Notebooks/Figures/{}_average_bias_p{:.0f}.png'.format(PREDICTAND, quantile * 100), dpi=300, bbox_inches='tight')"
+    "fig.savefig('../Notebooks/Figures/{}_bias_p{:.0f}_{}.png'.format(PREDICTAND, quantile * 100, LOSS), dpi=300, bbox_inches='tight')"
    ]
   },
   {
@@ -1002,7 +1205,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
diff --git a/Notebooks/lr_range_test.ipynb b/Notebooks/lr_range_test.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..f12faf6d50687fee13302ae02360e18564b3fb4c
--- /dev/null
+++ b/Notebooks/lr_range_test.ipynb
@@ -0,0 +1,290 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "4d010ae8-4172-417f-a457-c14e46adbf85",
+   "metadata": {},
+   "source": [
+    "# Learning rate range test"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c3200572-47da-48a6-a143-0ab26a3a3e40",
+   "metadata": {},
+   "source": [
+    "Define the predictand and the model to evaluate:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dd249096-9e57-425f-a0e9-29f31ccf4f6c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# define the model parameters\n",
+    "PREDICTANDS = ['pr', 'tasmax', 'tasmin']\n",
+    "MODEL = 'USegNet'\n",
+    "PPREDICTORS = 'ztuvq'\n",
+    "# PPREDICTORS = ''\n",
+    "PLEVELS = ['500', '850']\n",
+    "# PLEVELS = []\n",
+    "SPREDICTORS = 'p'\n",
+    "# SPREDICTORS = 'pr'\n",
+    "DEM = 'dem'\n",
+    "DEM_FEATURES = ''\n",
+    "DOY = 'doy'\n",
+    "LOSS = ['BernoulliGammaLoss', 'L1Loss', 'MSELoss']\n",
+    "OPTIM = ['SGD', 'Adam']\n",
+    "WET_DAY_THRESHOLD = '1'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5c4f6f06-6bf2-453c-a2a8-88386358c36f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# mapping from predictands to variable names\n",
+    "NAMES = {'tasmin': 'minimum temperature', 'tasmax': 'maximum temperature', 'pr': 'precipitation'}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0f47fdee-484b-457b-bdbf-a3b2c67132d9",
+   "metadata": {},
+   "source": [
+    "### Imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "21b9927a-ebbe-48ff-9c64-fbb27a069f4e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# builtins\n",
+    "import datetime\n",
+    "import warnings\n",
+    "import calendar\n",
+    "\n",
+    "# externals\n",
+    "import torch\n",
+    "import xarray as xr\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "import matplotlib.lines as mlines\n",
+    "import seaborn as sns\n",
+    "\n",
+    "# locals\n",
+    "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH, MODEL_PATH\n",
+    "from climax.core.utils import plot_loss\n",
+    "from pysegcnn.core.utils import search_files\n",
+    "from pysegcnn.core.graphics import running_mean"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ff1b059d-71b6-46bb-b2c3-66b56120546f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# path to models for learning rate range test\n",
+    "MODEL_PATH = MODEL_PATH.joinpath('lr-range')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "40c94ce9-ed05-4434-908a-91683c1400cb",
+   "metadata": {},
+   "source": [
+    "### Load datasets"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "49cf0a29-9564-4c78-9fcd-57af945855df",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# construct file pattern to match\n",
+    "PATTERN = '_'.join([MODEL, 'predictand',])\n",
+    "PATTERN = '_'.join([PATTERN, PPREDICTORS]) if PPREDICTORS else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, *PLEVELS]) if PPREDICTORS else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, 'loss'])\n",
+    "PATTERN = '_'.join([PATTERN, 'optim'])\n",
+    "PATTERN = '_'.join([PATTERN, 'lr_test.pt'])\n",
+    "PATTERNS = {k: PATTERN.replace('predictand', k) for k in PREDICTANDS}\n",
+    "PATTERNS = {k1: {v2: v1.replace('loss', '{}mm_{}'.format(WET_DAY_THRESHOLD, v2)) if v2 in\n",
+    "                 ['BernoulliGammaLoss', 'BernoulliWeibullLoss'] else v1.replace('loss', v2) for\n",
+    "                 v2 in LOSS} for k1, v1 in PATTERNS.items()}\n",
+    "PATTERNS = {optim: {k: {l: v.replace('optim', optim) for l, v in PATTERNS[k].items()} for k in PREDICTANDS} for optim in OPTIM}\n",
+    "PATTERNS"
+   ]
+  },
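+  {
+   "cell_type": "markdown",
+   "id": "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d",
+   "metadata": {},
+   "source": [
+    "For example, with the settings above, the pattern for precipitation with the Bernoulli-Gamma loss and the SGD optimizer resolves to `USegNet_pr_ztuvq_500_850_p_dem_doy_1mm_BernoulliGammaLoss_SGD_lr_test.pt`."
+   ]
+  },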
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fee242d6-cd1a-48b4-ac3d-6d35d428247d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# load datasets based on loss function\n",
+    "y_pred = {}\n",
+    "for o in OPTIM:\n",
+    "    y_pred[o] = {}\n",
+    "    for k in PREDICTANDS:\n",
+    "        y_pred[o][k] = {}\n",
+    "        for l, v in PATTERNS[o][k].items():\n",
+    "            try:\n",
+    "                y_pred[o][k][l] = torch.load(MODEL_PATH.joinpath(k, v), map_location=torch.device('cpu'))['state']\n",
+    "            except FileNotFoundError:\n",
+    "                y_pred[o][k][l] = np.nan"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "14493ac2-e2bd-40f8-a52b-57d2d9df5942",
+   "metadata": {},
+   "source": [
+    "## Loss as function of learning rate"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b19e74f7-50c8-4445-8190-520ec7f9c2c2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# minimum learning rate\n",
+    "MIN_LR = 1e-4\n",
+    "\n",
+    "# factor of learning rate increase\n",
+    "GAMMA = 1.6"
+   ]
+  },
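+  {
+   "cell_type": "markdown",
+   "id": "b2c3d4e5-f6a7-4b8c-9d0e-1f2a3b4c5d6e",
+   "metadata": {},
+   "source": [
+    "The learning rate hence starts at `MIN_LR` and grows by a factor of `GAMMA` after each epoch. The cell below is a minimal sketch of such a schedule using `torch.optim.lr_scheduler.ExponentialLR`; the stand-in model and the number of epochs are illustrative assumptions, not the actual training setup."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c3d4e5f6-a7b8-4c9d-8e0f-2a3b4c5d6e7f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# sketch: exponentially increasing learning rate, as used for the range test\n",
+    "# NOTE: the model below is a placeholder, not the actual network\n",
+    "import torch\n",
+    "\n",
+    "model = torch.nn.Linear(10, 1)\n",
+    "optimizer = torch.optim.SGD(model.parameters(), lr=MIN_LR)\n",
+    "scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=GAMMA)\n",
+    "\n",
+    "# after each epoch, the learning rate grows by a factor of GAMMA\n",
+    "lrs = []\n",
+    "for epoch in range(20):\n",
+    "    lrs.append(optimizer.param_groups[0]['lr'])\n",
+    "    scheduler.step()\n",
+    "print(['{:.2e}'.format(lr) for lr in lrs[:5]])"
+   ]
+  },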
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d3b806dc-f729-4f4f-ba63-a9e1595d5666",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# color palette\n",
+    "PALETTE = 'mako'\n",
+    "COLORS = sns.color_palette(PALETTE, n_colors=len(LOSS))\n",
+    "COLORS"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7916e76e-806c-43c6-8dd6-c5d0296d982a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# plot loss as function of learning rate\n",
+    "fig, axes = plt.subplots(1, len(PREDICTANDS), figsize=(16, 6), sharex=True, sharey=True)\n",
+    "axes = axes.flatten()\n",
+    "max_lrs = {}\n",
+    "for ax, p in zip(axes, PREDICTANDS):\n",
+    "    max_lrs[p] = {}\n",
+    "    for l, c in zip(LOSS, COLORS):\n",
+    "        max_lrs[p][l] = {}\n",
+    "        # plot observed loss as function of learning rate\n",
+    "        for o, ls, fs in zip(OPTIM, ['-', ':'], ['full', 'none']):\n",
+    "            try:\n",
+    "                loss = y_pred[o][p][l]['train_loss']\n",
+    "                batches, epochs = loss.shape    \n",
+    "                lr = running_mean(np.repeat(np.asarray([MIN_LR * (GAMMA ** e) for e in range(0, epochs)]), batches), batches)\n",
+    "                loss = running_mean(loss.flatten('F').clip(max=100), batches)\n",
+    "\n",
+    "                # learning rate at minimum loss\n",
+    "                mask = ~np.isnan(loss)\n",
+    "                try:\n",
+    "                    max_lr = lr[mask][np.argmin(loss[mask])]\n",
+    "                    min_loss = np.min(loss[mask])\n",
+    "                except ValueError:\n",
+    "                    max_lr = np.nan\n",
+    "                max_lrs[p][l][o] = max_lr\n",
+    "\n",
+    "                # plot running mean\n",
+    "                ax.plot(lr, loss, color=c, ls=ls)\n",
+    "\n",
+    "                # plot minimum loss\n",
+    "                ax.plot(max_lr, min_loss, 'o', color=c, fillstyle=fs)\n",
+    "                # ax.legend(frameon=False, fontsize=14);\n",
+    "            except TypeError:\n",
+    "                ax.plot(np.nan, np.nan)\n",
+    "\n",
+    "    # axes properties\n",
+    "    ax.set_title('{}'.format(NAMES[p].capitalize()), fontsize=16, pad=10)\n",
+    "    ax.set_xscale('log')\n",
+    "    ax.set_yscale('log')\n",
+    "    ax.set_ylim(5e-1, 2e2);\n",
+    "    ax.set_xlim(5e-5, 3e0);\n",
+    "    ax.set_xlabel('Learning rate $\\\\alpha$', fontsize=16)\n",
+    "    ax.tick_params(axis='both', which='major', labelsize=16)\n",
+    "    \n",
+    "# add legends\n",
+    "\n",
+    "# loss functions\n",
+    "patches = [mlines.Line2D([], [], color=c, label=l) for c, l in\n",
+    "           zip(COLORS, LOSS)]\n",
+    "axes[0].legend(handles=patches, loc='upper left', frameon=False, bbox_to_anchor=(-0.05, -0.15), ncol=len(LOSS), fontsize=14)\n",
+    "\n",
+    "# optimizers\n",
+    "patches = [mlines.Line2D([], [], color='black', ls=ls, label=l) for ls, l in\n",
+    "           zip(['-', ':'], OPTIM)]\n",
+    "axes[1].legend(handles=patches, loc='upper left', frameon=False, bbox_to_anchor=(-1.1, -0.25), ncol=len(OPTIM), fontsize=14)\n",
+    "    \n",
+    "# figure properties\n",
+    "axes[0].set_ylabel('Training loss', fontsize=16) \n",
+    "fig.subplots_adjust(wspace=0.05)\n",
+    "fig.savefig('./Figures/lr_range_test.png', dpi=300, bbox_inches='tight')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d5588c08-1fec-42ba-9463-7e714e0daf03",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# print learning rates at minimum of loss, i.e. max learning rate\n",
+    "for p in PREDICTANDS:\n",
+    "    [[print('({}), {}: {}: Max LR = {:.5f}'.format(p, l, o, lr)) for (o, lr) in max_lrs[p][l].items()] for l in max_lrs[p].keys()]"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/Notebooks/pr_distribution.ipynb b/Notebooks/pr_distribution.ipynb
index 1a41e3d539ccb0dfd5b3310cc214242655de63f8..f8f697499276ade57ca415fd074b7ed57ddbd59b 100644
--- a/Notebooks/pr_distribution.ipynb
+++ b/Notebooks/pr_distribution.ipynb
@@ -62,6 +62,17 @@
     "quantiles = np.arange(0.01, 1, 0.005)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0fd28eaf-2939-40e1-841a-49dabe6778d7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# wet day threshold\n",
+    "WET_DAY_THRESHOLD = 1"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "12382efb-1a3a-4ede-a904-7f762bfe56c7",
@@ -129,10 +140,10 @@
    "outputs": [],
    "source": [
     "# helper function retrieving only valid observations\n",
-    "def valid(ds):\n",
+    "def valid(ds, min_amount=WET_DAY_THRESHOLD):\n",
     "    valid = ds.precipitation.values\n",
     "    valid = valid[~np.isnan(valid)]  # mask missing values\n",
-    "    valid = valid[valid > 0]  # only consider pr > 0\n",
+    "    valid = valid[valid > min_amount]  # only consider pr > 0\n",
     "    return valid"
    ]
   },
@@ -179,8 +190,8 @@
    "outputs": [],
    "source": [
     "# fit generalized pareto distribution to data\n",
-    "alpha, loc, beta = stats.genpareto.fit(y_valid, floc=0)\n",
-    "genpareto = stats.genpareto(alpha, loc=loc, scale=beta)"
+    "# alpha, loc, beta = stats.genpareto.fit(y_valid, floc=0)\n",
+    "# genpareto = stats.genpareto(alpha, loc=loc, scale=beta)"
    ]
   },
   {
@@ -207,6 +218,16 @@
     "weibull = stats.weibull_min(alpha, loc=loc, scale=beta)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c7daee53-d996-42ed-aa84-a4770191cafe",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "alpha, loc, beta"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -217,7 +238,7 @@
     "# empirical quantiles and theoretical quantiles\n",
     "eq = np.quantile(y_valid, quantiles)\n",
     "tq_gamma = gamma.ppf(quantiles)\n",
-    "tq_genpareto = genpareto.ppf(quantiles)\n",
+    "#tq_genpareto = genpareto.ppf(quantiles)\n",
     "tq_expon = expon.ppf(quantiles)\n",
     "tq_lognorm = lognorm.ppf(quantiles)\n",
     "tq_weibull = weibull.ppf(quantiles)\n",
@@ -226,7 +247,7 @@
     "RANGE = 40\n",
     "fig, ax = plt.subplots(1, 1, figsize=(6, 6))\n",
     "ax.scatter(eq, tq_gamma, marker='*', color='k', label='Gamma')\n",
-    "ax.scatter(eq, tq_genpareto, marker='x', color='k', label='GenPareto')\n",
+    "# ax.scatter(eq, tq_genpareto, marker='x', color='k', label='GenPareto')\n",
     "ax.scatter(eq, tq_expon, marker='o', color='k', label='Expon')\n",
     "ax.scatter(eq, tq_lognorm, marker='+', color='k', label='LogNorm')\n",
     "ax.scatter(eq, tq_weibull, marker='^', color='k', label='Weibull')\n",
diff --git a/Notebooks/pr_loss.ipynb b/Notebooks/pr_loss.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..ffbd160b3c7a6c5854bc69e8189d56d555420677
--- /dev/null
+++ b/Notebooks/pr_loss.ipynb
@@ -0,0 +1,193 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "b21026e9-2bd2-4b52-a902-57a3b21875b1",
+   "metadata": {},
+   "source": [
+    "# Loss functions"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6b81f5f2-4f49-4680-b6d2-32e7668ca53c",
+   "metadata": {},
+   "source": [
+    "For precipitation, the network is optimizing the negative log-likelihood of a mixed discrete Bernoulli and a continuous Gamma or Weibull distribution. The loss function is defined as the Log-likelihood $l(p, \\alpha, \\beta \\mid y)$ function of the mixed discrete-continuous distribution $P(y \\mid p, \\alpha, \\beta)$ for $i = 1 \\dots N$ i.i.d observations $y_i$:"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ac97e5ef-1881-44d7-9704-076d26347e15",
+   "metadata": {},
+   "source": [
+    "$$l(p, \\alpha, \\beta \\mid y) = \\log\\left(\\prod_{i=1}^{N} P(y_i \\mid p, \\alpha, \\beta)\\right) = \\sum_{i=1}^{N} \\log\\left(P(y_i \\mid p, \\alpha, \\beta)\\right)$$"
+   ]
+  },
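+  {
+   "cell_type": "markdown",
+   "id": "d4e5f6a7-b8c9-4d0e-9f1a-3b4c5d6e7f8a",
+   "metadata": {},
+   "source": [
+    "Taking the log turns the product into a sum, which is also numerically preferable: the product of many small densities underflows, while the sum of their logs does not. A quick check of the identity with hypothetical density values:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e5f6a7b8-c9d0-4e1f-8a2b-4c5d6e7f8a9b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "# hypothetical values of P(y_i | p, alpha, beta)\n",
+    "densities = np.array([0.2, 0.5, 0.8, 0.1])\n",
+    "\n",
+    "# the log of the product equals the sum of the logs\n",
+    "np.log(np.prod(densities)), np.sum(np.log(densities))"
+   ]
+  },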
+  {
+   "cell_type": "markdown",
+   "id": "b0697679-ec7b-4ad4-a676-1c362f3affff",
+   "metadata": {},
+   "source": [
+    "## Bernoulli-Gamma loss function"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "209fc0e8-fff2-44fe-bf70-470c558068d9",
+   "metadata": {},
+   "source": [
+    "The Bernoulli-Gamma distribution for precipitation was introduced by [Cannon (2008)](http://journals.ametsoc.org/doi/10.1175/2008JHM960.1) and is defined as:"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9289faf7-b006-4a25-8afe-c09a07b710ae",
+   "metadata": {},
+   "source": [
+    "$$P(y \\mid, p, \\alpha, \\beta) = \\begin{cases} 1 - p, & \\text{for } y = 0\\\\ p \\cdot \\frac{y^{\\alpha -1} \\exp\\left(-\\frac{y}{\\beta}\\right)}{\\beta^{\\alpha} \\tau(\\alpha)}, & \\text{for } y > 0\\end{cases}$$"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4d40fbc3-0180-4ed3-8fdf-97326f9cb5c8",
+   "metadata": {},
+   "source": [
+    "with $\\alpha, \\beta > 0$ as the *shape* and *scale* parameters, $p \\in [0, 1]$ the *predicted probability* of precipitation, $y$ the observed precipitation amount, and $\\tau(\\alpha)$ the *gamma function*."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9f8e2eb6-02b3-40a4-be3f-daea5a623235",
+   "metadata": {},
+   "source": [
+    "Log-likelihood function of the Bernoulli-Gamma distribution,"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8c563a68-a154-47e1-bd4c-d0fb1a3bd16f",
+   "metadata": {},
+   "source": [
+    "$$\\mathcal{l}(p, \\alpha, \\beta \\mid y) = \\underbrace{(1 - P(y > 0)) \\log(1 - p)}_{\\text{Bernoulli}} + \\underbrace{P(y > 0) \\cdot \\left(\\log(p) + (\\alpha - 1) \\log(y) - \\frac{y}{\\beta} - \\alpha \\log(\\beta) - \\log(\\tau(\\alpha))\\right)}_{\\text{Gamma}}$$"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8d8ed32a-e7bf-491a-a59f-34cc9be82584",
+   "metadata": {},
+   "source": [
+    "where $P(y > 0) \\in \\{0, 1\\}$ is the *true probability* of precipitation."
+   ]
+  },
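+  {
+   "cell_type": "markdown",
+   "id": "f6a7b8c9-d0e1-4f2a-9b3c-5d6e7f8a9b0c",
+   "metadata": {},
+   "source": [
+    "A minimal PyTorch sketch of the resulting negative log-likelihood; this is a hand-rolled version for illustration, not the `BernoulliGammaLoss` implementation used for training, and the example tensors are assumptions:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a7b8c9d0-e1f2-4a3b-8c4d-6e7f8a9b0c1d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "\n",
+    "def bernoulli_gamma_nll(p, alpha, beta, y, eps=1e-8):\n",
+    "    # true probability of precipitation: P(y > 0) in {0, 1}\n",
+    "    wet = (y > 0).float()\n",
+    "    # Bernoulli part: dry days contribute log(1 - p)\n",
+    "    dry_ll = (1 - wet) * torch.log(1 - p + eps)\n",
+    "    # Gamma part: wet days contribute log(p) plus the Gamma log-density\n",
+    "    wet_ll = wet * (torch.log(p + eps) + (alpha - 1) * torch.log(y + eps)\n",
+    "                    - y / beta - alpha * torch.log(beta) - torch.lgamma(alpha))\n",
+    "    # negative mean log-likelihood\n",
+    "    return -(dry_ll + wet_ll).mean()\n",
+    "\n",
+    "# hypothetical predictions for one wet and one dry day\n",
+    "p = torch.tensor([0.9, 0.1])\n",
+    "alpha = torch.tensor([2.0, 2.0])\n",
+    "beta = torch.tensor([1.5, 1.5])\n",
+    "y = torch.tensor([3.2, 0.0])\n",
+    "bernoulli_gamma_nll(p, alpha, beta, y)"
+   ]
+  },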
+  {
+   "cell_type": "markdown",
+   "id": "8e398417-8e20-4684-bf30-fbd3c186f6c1",
+   "metadata": {},
+   "source": [
+    "## Bernoulli-Weibull loss function"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e39f72f9-58a8-4093-8fce-4870e2364de1",
+   "metadata": {},
+   "source": [
+    "The Bernoulli-Weibull distribution is defined as:"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "41067b17-745e-4f8e-9530-c2c72ec5d263",
+   "metadata": {},
+   "source": [
+    "$$P(y \\mid, p, \\alpha, \\beta) = \\begin{cases} 1 - p, & \\text{for } y = 0\\\\ p \\cdot \\frac{\\alpha}{\\beta}\\left(\\frac{y}{\\beta}\\right)^{\\alpha - 1} \\exp\\left(\\left(-\\frac{y}{\\beta}\\right)^{\\alpha}\\right), & \\text{for } y > 0\\end{cases}$$"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b8d14bcc-3c4c-4cef-bd07-3cd96cb3f79f",
+   "metadata": {},
+   "source": [
+    "with $\\alpha, \\beta > 0$ as the *shape* and *scale* parameters, $p \\in [0, 1]$ the *predicted probability* of precipitation, and $y$ the observed precipitation amount."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e87fe6bc-a666-421e-b04a-941827342244",
+   "metadata": {},
+   "source": [
+    "Log-likelihood function of the Bernoulli-Weibull distribution,"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "db6518d7-a329-432b-b750-8a6cba54d909",
+   "metadata": {},
+   "source": [
+    "$$\\mathcal{l}(p, \\alpha, \\beta \\mid y) = \\underbrace{(1 - P(y > 0)) \\log(1 - p)}_{\\text{Bernoulli}} + \\underbrace{P(y > 0) \\cdot \\left(\\log(\\alpha) - \\alpha \\log(\\beta) + (\\alpha - 1) \\log(y) - \\left(\\frac{y}{\\beta}\\right)^{\\alpha}\\right)}_{\\text{Weibull}}$$"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c955910c-49c6-4e63-b06f-761fc3d7cb1d",
+   "metadata": {},
+   "source": [
+    "where $P(y > 0) \\in \\{0, 1\\}$ is the *true probability* of precipitation."
+   ]
+  },
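+  {
+   "cell_type": "markdown",
+   "id": "b8c9d0e1-f2a3-4b4c-9d5e-7f8a9b0c1d2e",
+   "metadata": {},
+   "source": [
+    "Analogously, a minimal sketch of the Bernoulli-Weibull negative log-likelihood, reusing the hypothetical tensors from the Bernoulli-Gamma example above (again an illustration, not the `BernoulliWeibullLoss` implementation used for training):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c9d0e1f2-a3b4-4c5d-8e6f-8a9b0c1d2e3f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def bernoulli_weibull_nll(p, alpha, beta, y, eps=1e-8):\n",
+    "    # true probability of precipitation: P(y > 0) in {0, 1}\n",
+    "    wet = (y > 0).float()\n",
+    "    # Bernoulli part: dry days contribute log(1 - p)\n",
+    "    dry_ll = (1 - wet) * torch.log(1 - p + eps)\n",
+    "    # Weibull part: wet days contribute log(p) plus the Weibull log-density\n",
+    "    wet_ll = wet * (torch.log(p + eps) + torch.log(alpha) - alpha * torch.log(beta)\n",
+    "                    + (alpha - 1) * torch.log(y + eps) - ((y + eps) / beta) ** alpha)\n",
+    "    # negative mean log-likelihood\n",
+    "    return -(dry_ll + wet_ll).mean()\n",
+    "\n",
+    "bernoulli_weibull_nll(p, alpha, beta, y)"
+   ]
+  },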
+  {
+   "cell_type": "markdown",
+   "id": "f9847ca2-16f2-4286-8eb0-54773a560234",
+   "metadata": {},
+   "source": [
+    "## Loss function with regularization"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e95ff9df-a44b-4d03-bbee-5406835df453",
+   "metadata": {},
+   "source": [
+    "The loss function with **regularization (shrinkage)** term is defined as,"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "fffb81c0-3358-40fa-b3f2-e9a19014cde8",
+   "metadata": {},
+   "source": [
+    "$$\\mathcal{l}(p, \\alpha, \\beta \\mid y, \\theta)_{\\lambda} = \\underbrace{\\mathcal{l}(p, \\alpha, \\beta \\mid y, \\theta)}_{\\text{loss function}} + \\underbrace{\\lambda \\mid\\mid\\theta\\mid\\mid_2^2}_{L_2-\\text{penalty}},$$"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5a62bc1a-c469-4114-812d-588f0916918a",
+   "metadata": {},
+   "source": [
+    "where $\\lambda \\in [0, 1]$ is the weight decay factor."
+   ]
+  }
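+  {
+   "cell_type": "markdown",
+   "id": "d0e1f2a3-b4c5-4d6e-9f7a-9b0c1d2e3f4a",
+   "metadata": {},
+   "source": [
+    "In PyTorch, this $L_2$ penalty is typically not added to the loss explicitly but applied through the optimizer's `weight_decay` argument, which adds $\\lambda\\theta$ to the gradient of each weight tensor. A minimal sketch; the model and the value of $\\lambda$ are placeholders:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e1f2a3b4-c5d6-4e7f-8a8b-0c1d2e3f4a5b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "\n",
+    "# NOTE: placeholder model; weight_decay corresponds to lambda\n",
+    "model = torch.nn.Linear(10, 1)\n",
+    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)"
+   ]
+  }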
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}