diff --git a/Notebooks/eval_capstone.ipynb b/Notebooks/eval_capstone.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..93b40901353071d30379e31902b646055fc9effe --- /dev/null +++ b/Notebooks/eval_capstone.ipynb @@ -0,0 +1,415 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "7a50f2fe-b66a-4056-955f-2b3d40fca3be", + "metadata": {}, + "source": [ + "# Evaluate ERA-5 downscaling: minimum and maximum temperatures" + ] + }, + { + "cell_type": "markdown", + "id": "9ed2c2d9-30b8-4813-9707-ee08b5dfde5c", + "metadata": {}, + "source": [ + "We used **1981-1991 as training** period and **1991-2010 as reference** period. The results shown in this notebook are based on the model predictions on the reference period." + ] + }, + { + "cell_type": "markdown", + "id": "2e7da988-24d7-4623-8a27-24188b04638f", + "metadata": {}, + "source": [ + "**Predictors on pressure levels (500, 850)**:\n", + "- Geopotential (z)\n", + "- Temperature (t)\n", + "- Zonal wind (u)\n", + "- Meridional wind (v)\n", + "- Specific humidity (q)\n", + "\n", + "**Predictors on surface**:\n", + "- Surface pressure (p)\n", + "\n", + "**Auxiliary predictors**:\n", + "- Elevation from Copernicus EU-DEM v1.1 (dem)\n", + "- Day of the year (doy)" + ] + }, + { + "cell_type": "markdown", + "id": "84929d31-fd5d-4564-ab5d-239a05556f02", + "metadata": {}, + "source": [ + "Define the predictand and the model to evaluate:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5969d002-1bc9-4e0e-8c98-6757c3e61a8a", + "metadata": {}, + "outputs": [], + "source": [ + "# define the predictand\n", + "PREDICTAND = 'tasmin' # 'tasmin' or 'tasmax'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a7a2d388-9fd4-4e1d-8c05-705ab66ad098", + "metadata": {}, + "outputs": [], + "source": [ + "# model filename to evaluate\n", + "FILENAME = 'USegNet_tasmin_ztuvq_500_850_p_dem_doy_L1Loss_Adam_d1e-05.nc' # change me!" + ] + }, + { + "cell_type": "markdown", + "id": "f9653fb9-030e-4665-8759-664211569293", + "metadata": {}, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c2f1cdb-c918-4997-8a55-4ca7c7092209", + "metadata": {}, + "outputs": [], + "source": [ + "# builtins\n", + "import datetime\n", + "import warnings\n", + "import calendar\n", + "\n", + "# externals\n", + "import xarray as xr\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib import cm\n", + "from matplotlib.colors import ListedColormap\n", + "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n", + "import scipy.stats as stats\n", + "from IPython.display import Image\n", + "from sklearn.metrics import r2_score\n", + "\n", + "# locals\n", + "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH, MODEL_PATH\n", + "from climax.core.utils import plot_loss\n", + "from climax.core.dataset import ERA5Dataset\n", + "from pysegcnn.core.utils import search_files" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c5a084dd-1853-4cc8-b413-e4acac534de9", + "metadata": {}, + "outputs": [], + "source": [ + "# mapping from predictands to variable names\n", + "NAMES = {'tasmin': 'minimum temperature', 'tasmax': 'maximum temperature', 'pr': 'precipitation'}" + ] + }, + { + "cell_type": "markdown", + "id": "8b921da4-e4d5-44d3-9d34-00b676f47891", + "metadata": {}, + "source": [ + "### Model architecture" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dda4e69d-0cef-4f04-b6b0-499a2021fd27", + "metadata": {}, + "outputs": [], + "source": [ + "Image(\"./Figures/architecture.png\", width=900, height=400)" + ] + }, + { + "cell_type": "markdown", + "id": "8aaf6fa7-25c5-43d1-9a5b-e5604a7becce", + "metadata": {}, + "source": [ + "### Load datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d476c21-a797-46d6-816c-0af0a15168ee", + "metadata": {}, + "outputs": [], + "source": [ + "# model predictions\n", + "y_pred = xr.open_dataset(TARGET_PATH.joinpath(PREDICTAND, FILENAME))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ce62388-f7f1-444c-96e9-02055cef9a9f", + "metadata": {}, + "outputs": [], + "source": [ + "# target values: observations\n", + "y_true = xr.open_dataset(search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57f4ed1a-86b1-4f25-bb74-03b84d2ac7cf", + "metadata": {}, + "outputs": [], + "source": [ + "# subset to time period covered by predictions\n", + "y_true = y_true.sel(time=y_pred.time)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48fe83a2-236d-4d19-8dff-77429077e6be", + "metadata": {}, + "outputs": [], + "source": [ + "# align datasets and mask missing values in model predictions\n", + "y_true, y_pred = xr.align(y_true[PREDICTAND], y_pred[PREDICTAND], join='override')\n", + "y_pred = y_pred.where(~np.isnan(y_true), other=np.nan)" + ] + }, + { + "cell_type": "markdown", + "id": "edcccc9c-e4b2-42ba-ac4c-113de4614c51", + "metadata": {}, + "source": [ + "## Model validation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e36acbd7-15b7-48f7-9b2d-64cf1b2094a4", + "metadata": {}, + "outputs": [], + "source": [ + "# calculate monthly means\n", + "y_pred_mm = y_pred.groupby('time.month').mean(dim=('time'))\n", + "y_true_mm = y_true.groupby('time.month').mean(dim=('time'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77397c0b-3238-489a-b798-d7cc9af1ac76", + "metadata": {}, + "outputs": [], + "source": [ + "# calculate mean annual cycle\n", + "y_pred_ac = y_pred_mm.mean(dim=('x', 'y'))\n", + "y_true_ac = y_true_mm.mean(dim=('x', 'y'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdb22b20-e19d-4da6-9373-4724fae48d19", + "metadata": {}, + "outputs": [], + "source": [ + "# compute daily anomalies\n", + "y_pred_anom = ERA5Dataset.anomalies(y_pred, timescale='time.month')\n", + "y_true_anom = ERA5Dataset.anomalies(y_true, timescale='time.month')" + ] + }, + { + "cell_type": "markdown", + "id": "36e15e61-6552-4a76-9db3-166b9769b343", + "metadata": {}, + "source": [ + "### Coefficient of determination" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c076bffa-30f9-40de-97fa-eca80ea2c272", + "metadata": {}, + "outputs": [], + "source": [ + "# get predicted and observed daily anomalies\n", + "y_pred_av = y_pred_anom.values.flatten()\n", + "y_true_av = y_true_anom.values.flatten()\n", + "\n", + "# apply mask of valid pixels\n", + "mask = (~np.isnan(y_pred_av) & ~np.isnan(y_true_av))\n", + "y_pred_av = y_pred_av[mask]\n", + "y_true_av = y_true_av[mask]\n", + "\n", + "# get predicted and observed monthly means\n", + "y_pred_mv = y_pred_mm.values.flatten()\n", + "y_true_mv = y_true_mm.values.flatten()\n", + "\n", + "# apply mask of valid pixels\n", + "mask = (~np.isnan(y_pred_mv) & ~np.isnan(y_true_mv))\n", + "y_pred_mv = y_pred_mv[mask]\n", + "y_true_mv = y_true_mv[mask]\n", + "\n", + "# calculate coefficient of determination on monthly means\n", + "r2_mm = r2_score(y_true_mv, y_pred_mv)\n", + "print('R2 on monthly means: {:.2f}'.format(r2_mm))\n", + "\n", + "# calculate coefficient of determination on daily anomalies\n", + "r2_anom = r2_score(y_true_av, y_pred_av)\n", + "print('R2 on daily anomalies: {:.2f}'.format(r2_anom))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e34f28dc-af4e-463f-b668-3259b107be45", + "metadata": {}, + "outputs": [], + "source": [ + "# scatter plot of observations vs. predictions\n", + "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n", + "\n", + "# plot entire dataset\n", + "ax.plot(y_true_mv, y_pred_mv, 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n", + "\n", + "# plot 1:1 mapping line\n", + "if PREDICTAND == 'tasmin':\n", + " interval = np.arange(-25, 30, 5)\n", + "else:\n", + " interval = np.arange(-15, 45, 5)\n", + "ax.plot(interval, interval, color='k', lw=2, ls='--')\n", + "\n", + "# add coefficients of determination\n", + "ax.text(interval[-1] - 0.5, interval[0] + 0.5, s='R$^2$ (monthly means)= {:.2f}'.format(r2_mm), ha='right', fontsize=18)\n", + "ax.text(interval[-1] - 0.5, interval[0] + 2.5, s='R$^2$ (daily anomalies) = {:.2f}'.format(r2_anom), ha='right', fontsize=18)\n", + "\n", + "# format axes\n", + "ax.set_ylim(interval[0], interval[-1])\n", + "ax.set_xlim(interval[0], interval[-1])\n", + "ax.set_xticks(interval)\n", + "ax.set_xticklabels(interval, fontsize=16)\n", + "ax.set_yticks(interval)\n", + "ax.set_yticklabels(interval, fontsize=16)\n", + "ax.set_xlabel('Observed', fontsize=18)\n", + "ax.set_ylabel('Predicted', fontsize=18)\n", + "ax.set_title('Monthly mean {} (°C)'.format(NAMES[PREDICTAND]), fontsize=20, pad=10);\n", + "\n", + "# add axis for annual cycle\n", + "axins = inset_axes(ax, width=\"30%\", height=\"40%\", loc=2, borderpad=1)\n", + "axins.plot(y_pred_ac.values, ls='--', color='k', label='Predicted')\n", + "axins.plot(y_true_ac.values, ls='-', color='k', label='Observed')\n", + "axins.legend(frameon=False, fontsize=12, loc='lower center');\n", + "axins.yaxis.tick_right()\n", + "axins.set_yticks(np.arange(-10, 11, 2) if PREDICTAND == 'tasmin' else np.arange(0, 20, 2))\n", + "axins.set_yticklabels(np.arange(-10, 11, 2) if PREDICTAND == 'tasmin' else np.arange(0, 20, 2), fontsize=12)\n", + "axins.set_xticks(np.arange(0, 12))\n", + "axins.set_xticklabels([calendar.month_name[i + 1] for i in np.arange(0, 12)], rotation=90, fontsize=12)\n", + "axins.set_title('Mean annual cycle', fontsize=14, pad=5);" + ] + }, + { + "cell_type": "markdown", + "id": "e503c143-3844-4d25-9c3e-52704f9bd02b", + "metadata": {}, + "source": [ + "### Mean error (Bias)" + ] + }, + { + "cell_type": "markdown", + "id": "1d5fdda2-40ae-4a97-a5ce-c7bf87e0824e", + "metadata": {}, + "source": [ + "Calculate yearly average bias over entire reference period:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c976976d-f90e-4877-abc4-225d99d08f94", + "metadata": {}, + "outputs": [], + "source": [ + "# yearly average bias over reference period\n", + "y_pred_yearly_avg = y_pred.groupby('time.year').mean(dim='time')\n", + "y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time')\n", + "bias_yearly_avg = y_pred_yearly_avg - y_true_yearly_avg\n", + "print('(Model) Yearly average bias of {}: {:.2f}°C'.format(PREDICTAND, bias_yearly_avg.mean().item()))" + ] + }, + { + "cell_type": "markdown", + "id": "9b896c84-6a53-4929-8ec5-964a9775e1e8", + "metadata": {}, + "source": [ + "### Mean absolute error (MAE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd1f3693-9c17-463a-801f-a54196a534d8", + "metadata": {}, + "outputs": [], + "source": [ + "# mean absolute error over reference period\n", + "mae_avg = np.abs(y_pred_yearly_avg - y_true_yearly_avg).mean()\n", + "print('(Model) Yearly average MAE of {}: {:.2f}°C'.format(PREDICTAND, mae_avg.mean().item()))" + ] + }, + { + "cell_type": "markdown", + "id": "2c8f1e82-cc7e-4b99-a01d-5227eedceb5c", + "metadata": {}, + "source": [ + "### Root mean squared error (RMSE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05fe6ca3-705f-4047-a7af-449fdf3f72df", + "metadata": {}, + "outputs": [], + "source": [ + "# root mean squared error over reference period\n", + "rmse_avg = np.sqrt(((y_pred_yearly_avg - y_true_yearly_avg) ** 2).mean())\n", + "print('(Model) Yearly average RMSE of {}: {:.2f}°C'.format(PREDICTAND, rmse_avg.mean().item()))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}