{ "cells": [ { "cell_type": "markdown", "id": "fde8874d-299f-4f48-a10a-9fb6a00b43b9", "metadata": {}, "source": [ "# Evaluate bootstrapped model results" ] }, { "cell_type": "markdown", "id": "969d063b-5262-4324-901f-0a48630c4f27", "metadata": { "tags": [] }, "source": [ "## Imports and constants" ] }, { "cell_type": "code", "execution_count": null, "id": "8af00ae4-4aeb-4ff8-a46a-65966b28c440", "metadata": {}, "outputs": [], "source": [ "# builtins\n", "import pathlib\n", "\n", "# externals\n", "import numpy as np\n", "import xarray as xr\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from matplotlib import gridspec\n", "import seaborn as sns\n", "\n", "# locals\n", "from climax.main.io import OBS_PATH, ERA5_PATH\n", "from climax.main.config import VALID_PERIOD\n", "from pysegcnn.core.utils import search_files" ] }, { "cell_type": "code", "execution_count": null, "id": "5bc74835-dc59-46ed-849b-3ff614e53eee", "metadata": {}, "outputs": [], "source": [ "# mapping from predictands to variable names\n", "NAMES = {'tasmin': 'minimum temperature', 'tasmax': 'maximum temperature', 'pr': 'precipitation'}" ] }, { "cell_type": "code", "execution_count": null, "id": "c8a63ef3-35ef-4ffa-b1f3-5c2986eb7eb1", "metadata": {}, "outputs": [], "source": [ "# path to bootstrapped model results\n", "RESULTS = pathlib.Path('/mnt/CEPH_PROJECTS/FACT_CLIMAX/ERA5_PRED/bootstrap')" ] }, { "cell_type": "markdown", "id": "7eae545b-4d8a-4689-a6c0-4aba2cb9104e", "metadata": { "tags": [] }, "source": [ "## Search model configurations" ] }, { "cell_type": "code", "execution_count": null, "id": "3b83c9f3-7081-4cec-8f23-c4de007839d7", "metadata": {}, "outputs": [], "source": [ "# predictand to evaluate\n", "PREDICTAND = 'tasmin'" ] }, { "cell_type": "code", "execution_count": null, "id": "3e856f80-14fd-405f-a44e-cc77863f8e5b", "metadata": {}, "outputs": [], "source": [ "# loss function and optimizer\n", "LOSS = ['L1Loss', 'MSELoss', 'BernoulliGammaLoss'] if PREDICTAND == 'pr' else ['L1Loss', 'MSELoss']\n", "OPTIM = 'Adam'" ] }, { "cell_type": "code", "execution_count": null, "id": "011b792d-7349-44ad-997d-11f236472a11", "metadata": {}, "outputs": [], "source": [ "# model to evaluate\n", "models = ['USegNet_{}_ztuvq_500_850_p_dem_doy_1mm_{}_{}'.format(PREDICTAND, loss, OPTIM) if loss == 'BernoulliGammaLoss' else\n", " 'USegNet_{}_ztuvq_500_850_p_dem_doy_{}_{}'.format(PREDICTAND, loss, OPTIM) for loss in LOSS]" ] }, { "cell_type": "code", "execution_count": null, "id": "dc4ca6f0-5490-4522-8661-e36bd1be11b7", "metadata": {}, "outputs": [], "source": [ "# get bootstrapped models\n", "models = {loss: sorted(search_files(RESULTS.joinpath(PREDICTAND), model + '(.*).nc$'),\n", " key=lambda x: int(x.stem.split('_')[-1])) for loss, model in zip(LOSS, models)}\n", "models" ] }, { "cell_type": "markdown", "id": "5a64795a-6e5c-409a-8b3b-c738a96fa255", "metadata": { "tags": [] }, "source": [ "## Load datasets" ] }, { "cell_type": "markdown", "id": "e790ed9f-451c-4368-849d-06d9c50f797c", "metadata": {}, "source": [ "### Load observations" ] }, { "cell_type": "code", "execution_count": null, "id": "0862e0c8-06df-45d6-bc1b-002ffb6e9915", "metadata": {}, "outputs": [], "source": [ "# load observations\n", "y_true = xr.open_dataset(OBS_PATH.joinpath(PREDICTAND, 'OBS_{}_1980_2018.nc'.format(PREDICTAND)),\n", " chunks={'time': 365})\n", "y_true = y_true.sel(time=VALID_PERIOD) # subset to time period covered by predictions\n", "y_true = y_true.rename({NAMES[PREDICTAND]: PREDICTAND}) if PREDICTAND == 'pr' else y_true" ] }, { 
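"cell_type": "markdown", "id": "obs-sanity-check-md", "metadata": {}, "source": [ "Optional sanity check: confirm that the subset covers the validation period and that the predictand is available under the expected variable name." ] }, {
"cell_type": "code", "execution_count": null, "id": "obs-sanity-check-code", "metadata": {}, "outputs": [], "source": [ "# optional sanity check: variables and time coverage of the loaded observations\n", "print('Variables: {}'.format(list(y_true.data_vars)))\n", "print('Period: {} to {}'.format(y_true.time.values.min(), y_true.time.values.max()))" ] }, {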
"cell_type": "code", "execution_count": null, "id": "aba38642-85d1-404a-81f3-65d23985fb7a", "metadata": {}, "outputs": [], "source": [ "# mask of missing values\n", "missing = np.isnan(y_true[PREDICTAND])" ] }, { "cell_type": "markdown", "id": "d4512ed2-d503-4bc1-ae76-84560c101a14", "metadata": {}, "source": [ "### Load reference data" ] }, { "cell_type": "code", "execution_count": null, "id": "f90f6abf-5fd6-49c0-a1ad-f62242b3d3a0", "metadata": {}, "outputs": [], "source": [ "# ERA-5 reference dataset\n", "if PREDICTAND == 'pr':\n", " y_refe = xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', 'total_precipitation'), '.nc$').pop(),\n", " chunks={'time': 365})\n", " y_refe = y_refe.rename({'tp': 'pr'})\n", "else:\n", " y_refe = xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', '2m_{}_temperature'.format(PREDICTAND.lstrip('tas'))), '.nc$').pop(),\n", " chunks={'time': 365})\n", " y_refe = y_refe - 273.15 # convert to °C\n", " y_refe = y_refe.rename({'t2m': PREDICTAND})" ] }, { "cell_type": "code", "execution_count": null, "id": "ea6d5f56-4f39-4e9a-976d-00ff28fce95c", "metadata": {}, "outputs": [], "source": [ "# subset to time period covered by predictions\n", "y_refe = y_refe.sel(time=VALID_PERIOD).drop_vars('lambert_azimuthal_equal_area')\n", "y_refe = y_refe.transpose('time', 'y', 'x') # change order of dimensions" ] }, { "cell_type": "markdown", "id": "d37702de-da5f-4306-acc1-e569471c1f12", "metadata": {}, "source": [ "### Load QM-adjusted reference data" ] }, { "cell_type": "code", "execution_count": null, "id": "fffbd267-d08b-44f4-869c-7056c4f19c28", "metadata": {}, "outputs": [], "source": [ "y_refe_qm = xr.open_dataset(ERA5_PATH.joinpath('QM_ERA5_{}_day_19912010.nc'.format(PREDICTAND)), chunks={'time': 365})\n", "y_refe_qm = y_refe_qm.transpose('time', 'y', 'x') # change order of dimensions" ] }, { "cell_type": "code", "execution_count": null, "id": "16fa580e-27a7-4758-9164-7f607df7179d", "metadata": {}, "outputs": [], "source": [ "# center hours at 00:00:00 rather than 12:00:00\n", "y_refe_qm['time'] = np.asarray([t.astype('datetime64[D]') for t in y_refe_qm.time.values])" ] }, { "cell_type": "code", "execution_count": null, "id": "6789791f-006b-49b3-aa04-34e4ed8e1571", "metadata": {}, "outputs": [], "source": [ "# subset to time period covered by predictions\n", "y_refe_qm = y_refe_qm.sel(time=VALID_PERIOD).drop_vars('lambert_azimuthal_equal_area')" ] }, { "cell_type": "code", "execution_count": null, "id": "b51cfb3f-caa8-413e-a12d-47bbafcef1df", "metadata": {}, "outputs": [], "source": [ "# align datasets and mask missing values\n", "y_true, y_refe, y_refe_qm = xr.align(y_true[PREDICTAND], y_refe[PREDICTAND], y_refe_qm[PREDICTAND], join='override')\n", "y_refe = y_refe.where(~missing, other=np.nan)\n", "y_refe_qm = y_refe_qm.where(~missing, other=np.nan)" ] }, { "cell_type": "markdown", "id": "b4a6c286-6b88-487d-866c-3cb633686dac", "metadata": {}, "source": [ "### Load model predictions" ] }, { "cell_type": "code", "execution_count": null, "id": "eb889059-17e4-4d8c-b796-e8b1e2d0bf8c", "metadata": {}, "outputs": [], "source": [ "y_pred_raw = {k: [xr.open_dataset(v, chunks={'time': 365}) for v in models[k]] for k in models.keys()}\n", "if PREDICTAND == 'pr':\n", " y_pred_raw = {k: [v.rename({NAMES[PREDICTAND]: PREDICTAND}) for v in y_pred_raw[k]] for k in y_pred_raw.keys()}\n", " y_pred_raw = {k: [v.transpose('time', 'y', 'x') for v in y_pred_raw[k]] for k in y_pred_raw.keys()}" ] }, { "cell_type": "code", "execution_count": null, "id": 
"534e020d-96b2-403c-b8e4-86de98fbbe3b", "metadata": {}, "outputs": [], "source": [ "# align datasets and mask missing values\n", "y_prob = {}\n", "y_pred = {}\n", "for loss, models in y_pred_raw.items():\n", " y_pred[loss], y_prob[loss] = [], []\n", " for y_p in models:\n", " # check whether evaluating precipitation or temperatures\n", " if len(y_p.data_vars) > 1:\n", " _, _, y_p, y_p_prob = xr.align(y_true, y_refe, y_p[PREDICTAND], y_p.prob, join='override')\n", " y_p_prob = y_p_prob.where(~missing, other=np.nan) # mask missing values\n", " y_prob[loss].append(y_p_prob)\n", " else:\n", " _, _, y_p = xr.align(y_true, y_refe, y_p[PREDICTAND], join='override')\n", "\n", " # mask missing values\n", " y_p = y_p.where(~missing, other=np.nan)\n", " y_pred[loss].append(y_p)" ] }, { "cell_type": "markdown", "id": "6a718ea3-54d3-400a-8c89-76d04347de2d", "metadata": { "tags": [] }, "source": [ "## Ensemble predictions" ] }, { "cell_type": "code", "execution_count": null, "id": "5a6c0bfe-c1d2-4e43-9f8e-35c63c46bb10", "metadata": {}, "outputs": [], "source": [ "# create ensemble dataset\n", "ensemble = {k: xr.Dataset({'Member-{}'.format(i): member for i, member in enumerate(y_pred[k])}).to_array('members')\n", " for k in y_pred.keys() if y_pred[k]}" ] }, { "cell_type": "code", "execution_count": null, "id": "0e526227-cd4c-4a1c-ab72-51b72a4f821f", "metadata": {}, "outputs": [], "source": [ "# full ensemble mean prediction and standard deviation\n", "ensemble_mean_full = {k: v.mean(dim='members') for k, v in ensemble.items()}\n", "ensemble_std_full = {k: v.std(dim='members') for k, v in ensemble.items()}" ] }, { "cell_type": "code", "execution_count": null, "id": "d4a70701-2823-4106-ad6a-3272b678d0f9", "metadata": {}, "outputs": [], "source": [ "# ensemble mean prediction using three random members\n", "ensemble_3 = np.random.randint(0, len(ensemble['L1Loss'].members), size=3)\n", "ensemble_mean_3 = {k: v[ensemble_3, ...].mean(dim='members') for k, v in ensemble.items()}\n", "ensemble_std_3 = {k: v[ensemble_3, ...].std(dim='members') for k, v in ensemble.items()}" ] }, { "cell_type": "code", "execution_count": null, "id": "c4d18814-1340-4ed4-8102-2ccd6f0c2664", "metadata": {}, "outputs": [], "source": [ "# ensemble mean prediction using five random members\n", "ensemble_5 = np.random.randint(0, len(ensemble['L1Loss'].members), size=5)\n", "ensemble_mean_5 = {k: v[ensemble_5, ...].mean(dim='members') for k, v in ensemble.items()}\n", "ensemble_std_5 = {k: v[ensemble_5, ...].std(dim='members') for k, v in ensemble.items()}" ] }, { "cell_type": "markdown", "id": "f8b31e39-d4b9-4347-953f-87af04c0dd7a", "metadata": { "tags": [] }, "source": [ "# Model validation" ] }, { "cell_type": "markdown", "id": "3e6ecc98-f32f-42f7-9971-64b270aa5453", "metadata": { "tags": [] }, "source": [ "## Bias, MAE, and RMSE for reference data" ] }, { "cell_type": "markdown", "id": "671cd3c0-8d6c-41c1-bf8e-93f5943bf9aa", "metadata": {}, "source": [ "Calculate yearly average bias, MAE, and RMSE over entire reference period for model predictions, ERA-5, and QM-adjusted ERA-5." 
] }, { "cell_type": "code", "execution_count": null, "id": "7939a4d2-4eff-4507-86f8-dba7c0b635df", "metadata": {}, "outputs": [], "source": [ "# yearly average values over validation period\n", "y_refe_yearly_avg = y_refe.groupby('time.year').mean(dim='time')\n", "y_refe_qm_yearly_avg = y_refe_qm.groupby('time.year').mean(dim='time')\n", "y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time')" ] }, { "cell_type": "code", "execution_count": null, "id": "64e29db7-998d-4952-84b0-1c79016ab9a9", "metadata": {}, "outputs": [], "source": [ "# yearly average bias, mae, and rmse for ERA-5\n", "bias_refe = y_refe_yearly_avg - y_true_yearly_avg\n", "mae_refe = np.abs(y_refe_yearly_avg - y_true_yearly_avg)\n", "rmse_refe = (y_refe_yearly_avg - y_true_yearly_avg) ** 2" ] }, { "cell_type": "code", "execution_count": null, "id": "d0d4c974-876f-45e6-85cc-df91501ead20", "metadata": {}, "outputs": [], "source": [ "# yearly average bias, mae, and rmse for QM-Adjusted ERA-5\n", "bias_refe_qm = y_refe_qm_yearly_avg - y_true_yearly_avg\n", "mae_refe_qm = np.abs(y_refe_qm_yearly_avg - y_true_yearly_avg)\n", "rmse_refe_qm = (y_refe_qm_yearly_avg - y_true_yearly_avg) ** 2" ] }, { "cell_type": "code", "execution_count": null, "id": "d6efe5b9-3a6d-41ea-9f26-295b167cf0af", "metadata": {}, "outputs": [], "source": [ "# compute validation metrics for reference datasets\n", "filename = RESULTS.joinpath(PREDICTAND, 'reference.csv')\n", "if filename.exists():\n", " # check if validation metrics for reference already exist\n", " df_ref = pd.read_csv(filename)\n", "else:\n", " # compute validation metrics\n", " df_ref = pd.DataFrame([], columns=['bias', 'mae', 'rmse', 'product'])\n", " for product, metrics in zip(['Era-5', 'Era-5 QM'], [[bias_refe, mae_refe, rmse_refe], [bias_refe_qm, mae_refe_qm, rmse_refe_qm]]):\n", " values = pd.DataFrame([[np.sqrt(m.mean().values.item()) if name == 'rmse' else m.mean().values.item() for\n", " name, m in zip(['bias', 'mae', 'rmse'], metrics)] + [product]], columns=df_ref.columns)\n", " df_ref = df_ref.append(values, ignore_index=True)\n", " \n", " # save metrics to disk\n", " df_ref.to_csv(filename, index=False)" ] }, { "cell_type": "markdown", "id": "258cb3c6-c2fc-457d-885e-28eaf48f1d5b", "metadata": { "tags": [] }, "source": [ "## Bias, MAE, and RMSE for model predictions" ] }, { "cell_type": "markdown", "id": "630ce1c5-b018-437f-a7cf-8c8d99cd8f84", "metadata": {}, "source": [ "Calculate yearly average bias, MAE, and RMSE over entire reference period for model predictions." 
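, "\n", "\nWith $\\bar{y}_{s,t}$ and $\\bar{o}_{s,t}$ denoting the yearly means of the predictions and observations at grid point $s$ and year $t$, the reported scores are the spatio-temporal averages\n", "\n", "$$\\mathrm{Bias} = \\langle \\bar{y}_{s,t} - \\bar{o}_{s,t} \\rangle, \\qquad \\mathrm{MAE} = \\langle \\lvert \\bar{y}_{s,t} - \\bar{o}_{s,t} \\rvert \\rangle, \\qquad \\mathrm{RMSE} = \\sqrt{\\langle (\\bar{y}_{s,t} - \\bar{o}_{s,t})^2 \\rangle},$$\n", "\n", "where $\\langle \\cdot \\rangle$ denotes the mean over all grid points and years. The metrics are computed for each bootstrapped ensemble member individually as well as for the ensemble mean $\\bar{y} = \\frac{1}{N} \\sum_{i=1}^{N} y_i$ over $N = 3$ and $N = 5$ randomly drawn members and over the full ensemble."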
] }, { "cell_type": "code", "execution_count": null, "id": "6980833a-3848-43ca-bcca-d759b4fd9f69", "metadata": {}, "outputs": [], "source": [ "# yearly average bias, mae, and rmse for each ensemble member\n", "y_pred_yearly_avg = {k: v.groupby('time.year').mean(dim='time') for k, v in ensemble.items()}\n", "bias_pred = {k: v - y_true_yearly_avg for k, v in y_pred_yearly_avg.items()}\n", "mae_pred = {k: np.abs(v - y_true_yearly_avg) for k, v in y_pred_yearly_avg.items()}\n", "rmse_pred = {k: (v - y_true_yearly_avg) ** 2 for k, v in y_pred_yearly_avg.items()}" ] }, { "cell_type": "code", "execution_count": null, "id": "64f7a0b9-a772-4a03-9160-7839a48e56cd", "metadata": { "tags": [] }, "outputs": [], "source": [ "# compute validation metrics for model predictions\n", "filename = RESULTS.joinpath(PREDICTAND, 'prediction.csv')\n", "if filename.exists():\n", " # check if validation metrics for predictions already exist\n", " df_pred = pd.read_csv(filename)\n", "else:\n", " # validation metrics for each ensemble member\n", " df_pred = pd.DataFrame([], columns=['bias', 'mae', 'rmse', 'product', 'loss'])\n", " for k in y_pred_yearly_avg.keys():\n", " for i in range(len(bias_pred[k])):\n", " values = pd.DataFrame([[np.sqrt(m.mean().values.item()) if name == 'rmse' else m.mean().values.item()\n", " for name, m in zip(['bias', 'mae', 'rmse'], [bias_pred[k][i], mae_pred[k][i], rmse_pred[k][i]])] +\n", " [bias_pred[k][i].members.item()] + [k]],\n", " columns=df_pred.columns)\n", " df_pred = df_pred.append(values, ignore_index=True)\n", " \n", " # validation metrics for ensembles\n", " for name, ens in zip(['Ensemble-3', 'Ensemble-5', 'Ensemble-{:d}'.format(len(ensemble['L1Loss']))],\n", " [ensemble_mean_3, ensemble_mean_5, ensemble_mean_full]):\n", " for k, v in ens.items():\n", " yearly_avg = v.groupby('time.year').mean(dim='time')\n", " bias = (yearly_avg - y_true_yearly_avg).mean().values.item()\n", " mae = np.abs(yearly_avg - y_true_yearly_avg).mean().values.item()\n", " rmse = np.sqrt(((yearly_avg - y_true_yearly_avg) ** 2).mean().values.item())\n", " values = pd.DataFrame([[bias, mae, rmse, name, k]], columns=df_pred.columns)\n", " df_pred = df_pred.append(values, ignore_index=True)\n", " \n", " # save metrics to disk\n", " df_pred.to_csv(filename, index=False)" ] }, { "cell_type": "markdown", "id": "902e299c-a927-41b1-b2ae-987c30dee8cf", "metadata": {}, "source": [ "## Plot results" ] }, { "cell_type": "code", "execution_count": null, "id": "bdca9b54-3e05-49c8-b1b2-b8c782017306", "metadata": {}, "outputs": [], "source": [ "# create a sequential colormap: for reference data, single ensemble members, and ensemble mean predictions\n", "# palette = sns.color_palette('YlOrRd_r', 10) + sns.color_palette('Greens', 3)\n", "palette = sns.color_palette('Blues', len(LOSS))" ] }, { "cell_type": "markdown", "id": "3cfcd2de-cd37-42d5-b53d-e8abfd21e242", "metadata": { "tags": [] }, "source": [ "### Absolute values: single members vs. 
ensemble" ] }, { "cell_type": "code", "execution_count": null, "id": "48751f7f-9c26-471d-a75e-b7bb2fcb71be", "metadata": {}, "outputs": [], "source": [ "# dataframe of single members and ensembles only\n", "members = df_pred[~np.isin(df_pred['product'], ['Ensemble-{}'.format(i) for i in [3, 5, 10]])]\n", "ensemble = df_pred[~np.isin(df_pred['product'], ['Member-{}'.format(i) for i in range(10)])]" ] }, { "cell_type": "code", "execution_count": null, "id": "d7c8e987-0257-4263-ac4b-718a614c458f", "metadata": {}, "outputs": [], "source": [ "# initialize figure\n", "fig = plt.figure(figsize=(16, 5))\n", "\n", "# create grid for different subplots\n", "grid = gridspec.GridSpec(ncols=5, nrows=1, width_ratios=[3, 1, 1, 3, 1], wspace=0.05, hspace=0)\n", "\n", "# add subplots\n", "ax1 = fig.add_subplot(grid[0])\n", "ax2 = fig.add_subplot(grid[1], sharey=ax1)\n", "ax3 = fig.add_subplot(grid[3])\n", "ax4 = fig.add_subplot(grid[4], sharey=ax3)\n", "axes = [ax1, ax2, ax3, ax4]\n", "\n", "# plot bias: single members vs. ensemble\n", "sns.barplot(x='product', y='bias', hue='loss', data=members, palette=palette, ax=ax1);\n", "sns.barplot(x='product', y='bias', hue='loss', data=ensemble, palette=palette, ax=ax2);\n", "\n", "# plot mae: single members vs. ensemble\n", "sns.barplot(x='product', y='mae', hue='loss', data=members, palette=palette, ax=ax3);\n", "sns.barplot(x='product', y='mae', hue='loss', data=ensemble, palette=palette, ax=ax4);\n", "\n", "# axes limits and ticks\n", "y_lim_bias = (-50, 50) if PREDICTAND == 'pr' else (-1, 1)\n", "y_lim_mae = (0, 2) if PREDICTAND == 'pr' else (0, 1)\n", "y_ticks_bias = (np.arange(y_lim_bias[0], y_lim_bias[1] + 10, 10) if PREDICTAND == 'pr' else\n", " np.arange(y_lim_bias[0], y_lim_bias[1] + 0.2, 0.2))\n", "y_ticks_mae = (np.arange(y_lim_mae[0], y_lim_mae[1] + 10, 10) if PREDICTAND == 'pr' else\n", " np.arange(y_lim_mae[0], y_lim_mae[1] + 0.2, 0.2))\n", "\n", "# axis for bias\n", "ax1.set_ylabel('Bias (%)' if PREDICTAND == 'pr' else 'Bias (°C)')\n", "ax1.set_ylim(y_lim_bias)\n", "ax1.set_yticks(y_ticks_bias)\n", "\n", "# axis for mae\n", "ax3.set_ylabel('Mean absolute error (mm)' if PREDICTAND == 'pr' else 'Mean absolute error (°C)')\n", "ax3.set_ylim(y_lim_mae)\n", "ax3.set_yticks(y_ticks_mae)\n", "\n", "# adjust axis for ensemble predictions\n", "for ax in [ax2, ax4]:\n", " ax.yaxis.tick_right()\n", " ax.set_ylabel('')\n", "\n", "# axis fontsize and legend\n", "for ax in axes:\n", " ax.tick_params('both', labelsize=14)\n", " ax.set_xticklabels(ax.get_xticklabels(), rotation=90)\n", " ax.yaxis.label.set_size(14)\n", " ax.set_xlabel('')\n", " \n", " # adjust legend\n", " h, _ = ax.get_legend_handles_labels()\n", " ax.get_legend().remove()\n", "\n", "# show single legend\n", "ax4.legend(bbox_to_anchor=(1.3, 1.05), loc=2, frameon=False, fontsize=14);\n", "\n", "# save figure\n", "fig.savefig('./Figures/{}_members_vs_ensemble.pdf'.format(PREDICTAND), bbox_inches='tight')" ] }, { "cell_type": "markdown", "id": "590ffbaf-0e8d-4b63-9264-ad86078d50c9", "metadata": {}, "source": [ "### Absolute values: ensemble vs. reference" ] }, { "cell_type": "code", "execution_count": null, "id": "b1a9b1b7-9cd7-4998-afbb-11e64e91b333", "metadata": {}, "outputs": [], "source": [ "# initialize figure\n", "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n", "\n", "# plot bias: ensemble predictions vs. reference\n", "sns.barplot(x='product', y='bias', hue='loss', data=ensemble, palette=palette, ax=axes[0]);\n", "\n", "# plot mae: ensemble predictions vs. 
reference\n", "sns.barplot(x='product', y='mae', hue='loss', data=ensemble, palette=palette, ax=axes[1]);\n", "\n",
"# plot rmse: ensemble predictions vs. reference\n", "sns.barplot(x='product', y='rmse', hue='loss', data=ensemble, palette=palette, ax=axes[2]);\n", "\n",
"# add reference metrics as horizontal lines\n", "for ax, metric in zip(axes, ['bias', 'mae', 'rmse']):\n", "    for product, ls in zip(df_ref['product'], ['-', '--']):\n", "        ax.hlines(df_ref[metric][df_ref['product'] == product].item(), xmin=-0.5, xmax=2.5,\n", "                  color='k', ls=ls, label=product)\n", "\n",
"# axis for bias\n", "axes[0].set_ylabel('Bias (%)' if PREDICTAND == 'pr' else 'Bias (°C)')\n", "axes[0].set_ylim(y_lim_bias)\n", "axes[0].set_yticks(y_ticks_bias)\n", "\n",
"# axis for mae\n", "axes[1].set_ylabel('Mean absolute error (mm)' if PREDICTAND == 'pr' else 'Mean absolute error (°C)')\n", "axes[1].set_ylim(y_lim_mae)\n", "axes[1].set_yticks(y_ticks_mae)\n", "\n",
"# axis for rmse\n", "axes[2].set_ylabel('RMSE (mm)' if PREDICTAND == 'pr' else 'RMSE (°C)')\n", "axes[2].set_ylim(y_lim_mae)\n", "axes[2].set_yticks(y_ticks_mae)\n", "\n",
"# axis fontsize and legend\n", "for ax in axes:\n", "    ax.tick_params('both', labelsize=14)\n", "    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)\n", "    ax.yaxis.label.set_size(14)\n", "    ax.set_xlabel('')\n", "\n", "    # adjust legend\n", "    h, _ = ax.get_legend_handles_labels()\n", "    ax.get_legend().remove()\n", "\n",
"# show single legend\n", "axes[-1].legend(bbox_to_anchor=(1.05, 1.05), loc=2, frameon=False, fontsize=14);\n", "\n",
"# save figure\n", "fig.subplots_adjust(wspace=0.25)\n", "fig.savefig('./Figures/{}_ensemble_vs_reference.pdf'.format(PREDICTAND), bbox_inches='tight')" ] }, {
"cell_type": "markdown", "id": "775a3c92-1027-49d2-9681-dd53e0af70ac", "metadata": { "tags": [] }, "source": [ "### Regional time series" ] }, {
"cell_type": "code", "execution_count": null, "id": "dbe5db42-c31c-493b-a3b8-42c794cde6d9", "metadata": {}, "outputs": [], "source": [ "# whether to compute a rolling mean or a block (resampled) mean\n", "ROLLING = True" ] }, {
"cell_type": "code", "execution_count": null, "id": "fae4d70c-276c-4ba6-b6b6-ba6eb1793e0c", "metadata": {}, "outputs": [], "source": [ "# define scale of mean time series\n", "# scale = '1M' # monthly\n", "scale = '1Y' # yearly" ] }, {
"cell_type": "code", "execution_count": null, "id": "5eaaaf2f-d4c4-4f30-b124-66d04d6db2b9", "metadata": {}, "outputs": [], "source": [ "# mean time series over entire grid and validation period\n", "# ensemble_mean_full and ensemble_std_full are dicts keyed by loss function:\n", "# plot the ensemble of the first loss function in LOSS\n", "if ROLLING:\n", "    y_pred_ts = ensemble_mean_full[LOSS[0]].rolling(time=365, center=True).mean().mean(dim=('y', 'x')).dropna('time').compute()\n", "    y_pred_ts_var = ensemble_std_full[LOSS[0]].rolling(time=365, center=True).mean().mean(dim=('y', 'x')).dropna('time').compute()\n", "    y_true_ts = y_true.rolling(time=365, center=True).mean().mean(dim=('y', 'x')).dropna('time').compute()\n", "    y_refe_ts = y_refe.rolling(time=365, center=True).mean().mean(dim=('y', 'x')).dropna('time').compute()\n", "    y_refe_qm_ts = y_refe_qm.rolling(time=365, center=True).mean().mean(dim=('y', 'x')).dropna('time').compute()\n", "else:\n", "    y_pred_ts = ensemble_mean_full[LOSS[0]].resample(time=scale).mean(dim=('time', 'y', 'x')).compute()\n", "    y_pred_ts_var = ensemble_std_full[LOSS[0]].resample(time=scale).mean(dim=('time', 'y', 'x')).compute()\n", "    y_true_ts = y_true.resample(time=scale).mean(dim=('time', 'y', 'x')).compute()\n", "    y_refe_ts = y_refe.resample(time=scale).mean(dim=('time', 'y', 'x')).compute()\n", "    y_refe_qm_ts = y_refe_qm.resample(time=scale).mean(dim=('time', 'y', 
'x')).compute()" ] }, {
"cell_type": "code", "execution_count": null, "id": "28bc3177-b6a0-4938-9e74-59be2491fa56", "metadata": {}, "outputs": [], "source": [ "# color palette\n", "palette = sns.color_palette('viridis', 3)" ] }, {
"cell_type": "code", "execution_count": null, "id": "07375015-4205-4dfb-9bd2-0f37d5e56672", "metadata": {}, "outputs": [], "source": [ "# factor of standard deviation to plot as uncertainty around ensemble mean prediction\n", "std_factor = 1" ] }, {
"cell_type": "code", "execution_count": null, "id": "8ca32179-66ed-4f9d-a8f6-92cb547afe4a", "metadata": {}, "outputs": [], "source": [ "# initialize figure\n", "fig, ax = plt.subplots(1, 1, figsize=(16, 9))\n", "\n",
"# time to plot on x-axis\n", "time = y_true_ts.time if ROLLING else [t.astype('datetime64[{}]'.format(scale.lstrip('1'))) for t in y_true_ts.time.values]\n", "xticks = [t.astype('datetime64[Y]') for t in list(y_true_ts.time.resample(time='1Y').groups.keys())]\n", "\n",
"# plot reference: observations, ERA-5, ERA-5 QM-adjusted\n", "ax.plot(time, y_true_ts, label='Observed', ls='-', color='k');\n", "ax.plot(time, y_refe_ts, label='ERA-5', ls='-', color=palette[0]);\n", "ax.plot(time, y_refe_qm_ts, label='ERA-5 QM-adjusted', ls='-', color=palette[1]);\n", "\n",
"# plot model predictions: ensemble mean and standard deviation\n", "ax.plot(time, y_pred_ts, label='Prediction: Ensemble mean', color=palette[-1])\n", "ax.fill_between(x=time, y1=y_pred_ts - std_factor * y_pred_ts_var, y2=y_pred_ts + std_factor * y_pred_ts_var,\n", "                alpha=0.3, label='Prediction: Ensemble std', color=palette[-1]);\n", "\n",
"# add legend\n", "ax.legend(frameon=False, loc='lower right', fontsize=12)\n", "\n",
"# axis limits and ticks\n", "ax.set_xticks(xticks)\n", "ax.set_xticklabels(xticks)\n", "ax.tick_params(axis='both', labelsize=12)\n", "\n",
"# save figure\n", "fig.savefig('./Figures/{}_{}_{}_bootstrap_time_series_{}.png'.format(PREDICTAND, LOSS[0], OPTIM, scale if not ROLLING else 'rolling'),\n", "            bbox_inches='tight', dpi=300)" ] }, {
"cell_type": "markdown", "id": "923762ca-6ebc-4ffa-9b65-2faaf816fc05", "metadata": {}, "source": [ "### Spatial distributions" ] }, {
"cell_type": "code", "execution_count": null, "id": "a520127b-0dbc-4217-9a00-68cef41afe83", "metadata": {}, "outputs": [], "source": [ "# compute yearly mean bias of the ensemble mean prediction at each grid point\n", "# (as above, using the ensemble of the first loss function in LOSS)\n", "pred = (ensemble_mean_full[LOSS[0]].groupby('time.year').mean(dim='time') - y_true_yearly_avg).mean(dim='year').compute()" ] }, {
"cell_type": "code", "execution_count": null, "id": "e917db7e-ae9b-48e8-bb23-58905c47a910", "metadata": {}, "outputs": [], "source": [ "# plot yearly average bias of references and predictions\n", "vmin, vmax = -1, 1\n", "fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)\n", "\n",
"# plot bias of ERA-5 reference\n", "era5 = bias_refe.mean(dim='year')\n", "im1 = axes[0].imshow(era5.values, origin='lower', cmap='RdBu_r', vmin=vmin, vmax=vmax)\n", "\n",
"# plot bias of ERA-5 QM-adjusted reference\n", "era5_qm = bias_refe_qm.mean(dim='year')\n", "im2 = axes[1].imshow(era5_qm.values, origin='lower', cmap='RdBu_r', vmin=vmin, vmax=vmax)\n", "\n",
"# plot bias of ensemble model prediction\n", "im3 = axes[2].imshow(pred, origin='lower', cmap='RdBu_r', vmin=vmin, vmax=vmax)\n", "\n",
"# set titles\n", "axes[0].set_title('ERA-5', fontsize=14, pad=10);\n", "axes[1].set_title('ERA-5: QM-adjusted', fontsize=14, pad=10);\n", "axes[2].set_title('Predictions: Ensemble mean', fontsize=14, pad=10)\n", "\n",
"# adjust axes\n", "for ax in axes.flat:\n", "    ax.axes.get_xaxis().set_ticklabels([])\n", "    ax.axes.get_xaxis().set_ticks([])\n", "    ax.axes.get_yaxis().set_ticklabels([])\n", "    ax.axes.get_yaxis().set_ticks([])\n", "    ax.axes.axis('tight')\n", "    ax.set_xlabel('')\n", "    ax.set_ylabel('')\n", "    ax.set_axis_off()\n", "\n",
"# adjust figure\n", "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)\n", "\n",
"# add colorbar\n", "axes = axes.flatten()\n", "cbar_ax_bias = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n", "                             0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])\n", "cbar_bias = fig.colorbar(im3, cax=cbar_ax_bias)\n", "cbar_bias.set_label(label='Bias (°C)', fontsize=14)\n", "cbar_bias.ax.tick_params(labelsize=14, pad=10)\n", "\n",
"# save figure\n", "fig.savefig('./Figures/{}_{}_{}_bootstrap_bias.png'.format(PREDICTAND, LOSS[0], OPTIM), dpi=300, bbox_inches='tight')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" } }, "nbformat": 4, "nbformat_minor": 5 }