{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fde8874d-299f-4f48-a10a-9fb6a00b43b9",
   "metadata": {},
   "source": [
    "# Evaluate bootstrapped model results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "969d063b-5262-4324-901f-0a48630c4f27",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Imports and constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8af00ae4-4aeb-4ff8-a46a-65966b28c440",
   "metadata": {},
   "outputs": [],
   "source": [
    "# builtins\n",
    "import pathlib\n",
    "\n",
    "# externals\n",
    "import numpy as np\n",
    "import xarray as xr\n",
    "import pandas as pd\n",
    "import seaborn as sns  # used by all bar plots and color palettes below\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import gridspec\n",
    "\n",
    "# locals\n",
    "from climax.main.io import OBS_PATH, ERA5_PATH\n",
    "from climax.main.config import VALID_PERIOD\n",
    "from pysegcnn.core.utils import search_files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bc74835-dc59-46ed-849b-3ff614e53eee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# mapping from predictands to variable names\n",
    "NAMES = {'tasmin': 'minimum temperature', 'tasmax': 'maximum temperature', 'pr': 'precipitation'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8a63ef3-35ef-4ffa-b1f3-5c2986eb7eb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# path to bootstrapped model results\n",
    "RESULTS = pathlib.Path('/mnt/CEPH_PROJECTS/FACT_CLIMAX/ERA5_PRED/bootstrap')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7eae545b-4d8a-4689-a6c0-4aba2cb9104e",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Search model configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b83c9f3-7081-4cec-8f23-c4de007839d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# predictand to evaluate: one of the keys of NAMES ('tasmin', 'tasmax', 'pr')\n",
    "PREDICTAND = 'tasmin'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e856f80-14fd-405f-a44e-cc77863f8e5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# loss functions and optimizer: the BernoulliGammaLoss is only evaluated for precipitation\n",
    "LOSS = ['L1Loss', 'MSELoss', 'BernoulliGammaLoss'] if PREDICTAND == 'pr' else ['L1Loss', 'MSELoss']\n",
    "OPTIM = 'Adam'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "011b792d-7349-44ad-997d-11f236472a11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# name of the model trained with each loss function\n",
    "# (models trained with the BernoulliGammaLoss carry an additional '1mm' tag in their name)\n",
    "model_names = {loss: 'USegNet_{}_ztuvq_500_850_p_dem_doy_1mm_{}_{}'.format(PREDICTAND, loss, OPTIM)\n",
    "               if loss == 'BernoulliGammaLoss' else\n",
    "               'USegNet_{}_ztuvq_500_850_p_dem_doy_{}_{}'.format(PREDICTAND, loss, OPTIM) for loss in LOSS}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc4ca6f0-5490-4522-8661-e36bd1be11b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get bootstrapped models: prediction files of each loss function, sorted by the integer\n",
    "# bootstrap index at the end of the file name (may be empty if a model was not bootstrapped)\n",
    "models = {loss: sorted(search_files(RESULTS.joinpath(PREDICTAND), name + '(.*).nc$'),\n",
    "                       key=lambda x: int(x.stem.split('_')[-1])) for loss, name in model_names.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a64795a-6e5c-409a-8b3b-c738a96fa255",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Load datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e790ed9f-451c-4368-849d-06d9c50f797c",
   "metadata": {},
   "source": [
    "### Load observations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0862e0c8-06df-45d6-bc1b-002ffb6e9915",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load observations, restricted to the time period covered by the predictions\n",
    "obs_file = OBS_PATH.joinpath(PREDICTAND, 'OBS_{}_1980_2018.nc'.format(PREDICTAND))\n",
    "y_true = xr.open_dataset(obs_file, chunks={'time': 365}).sel(time=VALID_PERIOD)\n",
    "\n",
    "# precipitation observations store the variable under its long name: harmonize to the predictand name\n",
    "if PREDICTAND == 'pr':\n",
    "    y_true = y_true.rename({NAMES[PREDICTAND]: PREDICTAND})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aba38642-85d1-404a-81f3-65d23985fb7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# boolean mask of observations without a value (NaN)\n",
    "missing = y_true[PREDICTAND].isnull()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4512ed2-d503-4bc1-ae76-84560c101a14",
   "metadata": {},
   "source": [
    "### Load reference data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f90f6abf-5fd6-49c0-a1ad-f62242b3d3a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ERA-5 reference dataset\n",
    "if PREDICTAND == 'pr':\n",
    "    era5_file = search_files(ERA5_PATH.joinpath('ERA5', 'total_precipitation'), '.nc$').pop()\n",
    "    y_refe = xr.open_dataset(era5_file, chunks={'time': 365})\n",
    "    y_refe = y_refe.rename({'tp': 'pr'})\n",
    "else:\n",
    "    # 'tasmin' -> 'min', 'tasmax' -> 'max': slice off the prefix (str.lstrip would strip a set of characters)\n",
    "    era5_var = '2m_{}_temperature'.format(PREDICTAND[len('tas'):])\n",
    "    era5_file = search_files(ERA5_PATH.joinpath('ERA5', era5_var), '.nc$').pop()\n",
    "    y_refe = xr.open_dataset(era5_file, chunks={'time': 365})\n",
    "    y_refe = y_refe - 273.15  # convert from K to °C\n",
    "    y_refe = y_refe.rename({'t2m': PREDICTAND})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea6d5f56-4f39-4e9a-976d-00ff28fce95c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# subset to time period covered by predictions and drop the grid mapping variable\n",
    "y_refe = y_refe.sel(time=VALID_PERIOD).drop_vars('lambert_azimuthal_equal_area')\n",
    "y_refe = y_refe.transpose('time', 'y', 'x')  # change order of dimensions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d37702de-da5f-4306-acc1-e569471c1f12",
   "metadata": {},
   "source": [
    "### Load QM-adjusted reference data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fffbd267-d08b-44f4-869c-7056c4f19c28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# QM-adjusted ERA-5 reference with dimensions ordered as (time, y, x)\n",
    "qm_file = ERA5_PATH.joinpath('QM_ERA5_{}_day_19912010.nc'.format(PREDICTAND))\n",
    "y_refe_qm = xr.open_dataset(qm_file, chunks={'time': 365}).transpose('time', 'y', 'x')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16fa580e-27a7-4758-9164-7f607df7179d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# center hours at 00:00:00 rather than 12:00:00: truncate the time stamps to daily resolution\n",
    "y_refe_qm['time'] = y_refe_qm.time.values.astype('datetime64[D]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6789791f-006b-49b3-aa04-34e4ed8e1571",
   "metadata": {},
   "outputs": [],
   "source": [
    "# restrict to the validation period and drop the grid mapping variable\n",
    "y_refe_qm = y_refe_qm.sel(time=VALID_PERIOD).drop_vars('lambert_azimuthal_equal_area')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b51cfb3f-caa8-413e-a12d-47bbafcef1df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# select the predictand and copy the coordinates of the observations onto the references (join='override')\n",
    "y_true, y_refe, y_refe_qm = xr.align(y_true[PREDICTAND], y_refe[PREDICTAND], y_refe_qm[PREDICTAND], join='override')\n",
    "\n",
    "# hide reference values wherever the observations are missing\n",
    "y_refe, y_refe_qm = (da.where(~missing, other=np.nan) for da in (y_refe, y_refe_qm))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4a6c286-6b88-487d-866c-3cb633686dac",
   "metadata": {},
   "source": [
    "### Load model predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb889059-17e4-4d8c-b796-e8b1e2d0bf8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the predictions of every bootstrapped model, grouped by loss function\n",
    "y_pred_raw = {k: [xr.open_dataset(v, chunks={'time': 365}) for v in models[k]] for k in models.keys()}\n",
    "if PREDICTAND == 'pr':\n",
    "    # harmonize the variable name and the order of the dimensions\n",
    "    y_pred_raw = {k: [v.rename({NAMES[PREDICTAND]: PREDICTAND}).transpose('time', 'y', 'x') for v in y_pred_raw[k]]\n",
    "                  for k in y_pred_raw.keys()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "534e020d-96b2-403c-b8e4-86de98fbbe3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# align datasets and mask missing values\n",
    "y_prob = {}\n",
    "y_pred = {}\n",
    "for loss, members in y_pred_raw.items():\n",
    "    # loop variable must not be called `models`: that name holds the file lists used by the cell above\n",
    "    y_pred[loss], y_prob[loss] = [], []\n",
    "    for y_p in members:\n",
    "        # predictions with more than one variable (precipitation) also carry a probability `prob`\n",
    "        if len(y_p.data_vars) > 1:\n",
    "            _, _, y_p, y_p_prob = xr.align(y_true, y_refe, y_p[PREDICTAND], y_p.prob, join='override')\n",
    "            y_p_prob = y_p_prob.where(~missing, other=np.nan)  # mask missing values\n",
    "            y_prob[loss].append(y_p_prob)\n",
    "        else:\n",
    "            _, _, y_p = xr.align(y_true, y_refe, y_p[PREDICTAND], join='override')\n",
    "\n",
    "        # mask missing values\n",
    "        y_p = y_p.where(~missing, other=np.nan)\n",
    "        y_pred[loss].append(y_p)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a718ea3-54d3-400a-8c89-76d04347de2d",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Ensemble predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a6c0bfe-c1d2-4e43-9f8e-35c63c46bb10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create ensemble dataset: one array per loss function, members stacked along the new 'members' dimension\n",
    "# (loss functions without any bootstrapped predictions are skipped)\n",
    "ensemble = {k: xr.Dataset({'Member-{}'.format(i): member for i, member in enumerate(y_pred[k])}).to_array('members')\n",
    "            for k in y_pred.keys() if y_pred[k]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e526227-cd4c-4a1c-ab72-51b72a4f821f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# full ensemble mean prediction and standard deviation\n",
    "ensemble_mean_full = {k: v.mean(dim='members') for k, v in ensemble.items()}\n",
    "ensemble_std_full = {k: v.std(dim='members') for k, v in ensemble.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4a70701-2823-4106-ad6a-3272b678d0f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# seed for the random selection of members and number of members available for every loss function\n",
    "SEED = 42\n",
    "n_members = min(len(v.members) for v in ensemble.values())\n",
    "\n",
    "# ensemble mean prediction using three distinct, randomly selected members (same members for each loss)\n",
    "ensemble_3 = np.random.default_rng(SEED).choice(n_members, size=min(3, n_members), replace=False)\n",
    "ensemble_mean_3 = {k: v.isel(members=ensemble_3).mean(dim='members') for k, v in ensemble.items()}\n",
    "ensemble_std_3 = {k: v.isel(members=ensemble_3).std(dim='members') for k, v in ensemble.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d18814-1340-4ed4-8102-2ccd6f0c2664",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ensemble mean prediction using five distinct, randomly selected members (same members for each loss)\n",
    "ensemble_5 = np.random.default_rng(SEED).choice(n_members, size=min(5, n_members), replace=False)\n",
    "ensemble_mean_5 = {k: v.isel(members=ensemble_5).mean(dim='members') for k, v in ensemble.items()}\n",
    "ensemble_std_5 = {k: v.isel(members=ensemble_5).std(dim='members') for k, v in ensemble.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8b31e39-d4b9-4347-953f-87af04c0dd7a",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Model validation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e6ecc98-f32f-42f7-9971-64b270aa5453",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Bias, MAE, and RMSE for reference data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "671cd3c0-8d6c-41c1-bf8e-93f5943bf9aa",
   "metadata": {},
   "source": [
    "Calculate the yearly average bias, MAE, and RMSE over the entire reference period for ERA-5 and QM-adjusted ERA-5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7939a4d2-4eff-4507-86f8-dba7c0b635df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# yearly average values over validation period\n",
    "y_refe_yearly_avg = y_refe.groupby('time.year').mean(dim='time')\n",
    "y_refe_qm_yearly_avg = y_refe_qm.groupby('time.year').mean(dim='time')\n",
    "y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64e29db7-998d-4952-84b0-1c79016ab9a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# yearly average bias, absolute error and squared error for ERA-5\n",
    "# (rmse_* holds squared errors: the square root of their mean is taken when tabulating the metrics)\n",
    "bias_refe = y_refe_yearly_avg - y_true_yearly_avg\n",
    "mae_refe = np.abs(y_refe_yearly_avg - y_true_yearly_avg)\n",
    "rmse_refe = (y_refe_yearly_avg - y_true_yearly_avg) ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0d4c974-876f-45e6-85cc-df91501ead20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# yearly average bias, absolute error and squared error for QM-Adjusted ERA-5\n",
    "bias_refe_qm = y_refe_qm_yearly_avg - y_true_yearly_avg\n",
    "mae_refe_qm = np.abs(y_refe_qm_yearly_avg - y_true_yearly_avg)\n",
    "rmse_refe_qm = (y_refe_qm_yearly_avg - y_true_yearly_avg) ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6efe5b9-3a6d-41ea-9f26-295b167cf0af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute validation metrics for reference datasets\n",
    "filename = RESULTS.joinpath(PREDICTAND, 'reference.csv')\n",
    "if filename.exists():\n",
    "    # validation metrics for reference already exist: load from disk\n",
    "    df_ref = pd.read_csv(filename)\n",
    "else:\n",
    "    # one record per reference product; the frame is built once\n",
    "    # (DataFrame.append was removed in pandas 2.0)\n",
    "    records = []\n",
    "    for product, (m_bias, m_mae, m_se) in zip(['Era-5', 'Era-5 QM'], [[bias_refe, mae_refe, rmse_refe],\n",
    "                                                                     [bias_refe_qm, mae_refe_qm, rmse_refe_qm]]):\n",
    "        records.append({'bias': m_bias.mean().values.item(),\n",
    "                        'mae': m_mae.mean().values.item(),\n",
    "                        'rmse': np.sqrt(m_se.mean().values.item()),\n",
    "                        'product': product})\n",
    "    df_ref = pd.DataFrame(records, columns=['bias', 'mae', 'rmse', 'product'])\n",
    "\n",
    "    # save metrics to disk\n",
    "    df_ref.to_csv(filename, index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "258cb3c6-c2fc-457d-885e-28eaf48f1d5b",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Bias, MAE, and RMSE for model predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "630ce1c5-b018-437f-a7cf-8c8d99cd8f84",
   "metadata": {},
   "source": [
    "Calculate the yearly average bias, MAE, and RMSE over the entire reference period for each ensemble member and for the ensemble mean predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6980833a-3848-43ca-bcca-d759b4fd9f69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# yearly average bias, absolute error and squared error for each ensemble member\n",
    "y_pred_yearly_avg = {k: v.groupby('time.year').mean(dim='time') for k, v in ensemble.items()}\n",
    "bias_pred = {k: v - y_true_yearly_avg for k, v in y_pred_yearly_avg.items()}\n",
    "mae_pred = {k: np.abs(v - y_true_yearly_avg) for k, v in y_pred_yearly_avg.items()}\n",
    "rmse_pred = {k: (v - y_true_yearly_avg) ** 2 for k, v in y_pred_yearly_avg.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64f7a0b9-a772-4a03-9160-7839a48e56cd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def yearly_scores(y_yearly_avg):\n",
    "    \"\"\"Return bias, mae and rmse of yearly averages with respect to the observed yearly averages.\"\"\"\n",
    "    error = y_yearly_avg - y_true_yearly_avg\n",
    "    return {'bias': error.mean().values.item(),\n",
    "            'mae': np.abs(error).mean().values.item(),\n",
    "            'rmse': np.sqrt((error ** 2).mean().values.item())}\n",
    "\n",
    "\n",
    "# compute validation metrics for model predictions\n",
    "filename = RESULTS.joinpath(PREDICTAND, 'prediction.csv')\n",
    "if filename.exists():\n",
    "    # validation metrics for predictions already exist: load from disk\n",
    "    df_pred = pd.read_csv(filename)\n",
    "else:\n",
    "    # one record per row; the frame is built once (DataFrame.append was removed in pandas 2.0)\n",
    "    records = []\n",
    "\n",
    "    # validation metrics for each ensemble member\n",
    "    for k, v in y_pred_yearly_avg.items():\n",
    "        for i in range(v.sizes['members']):\n",
    "            member = v.isel(members=i)\n",
    "            records.append(dict(yearly_scores(member), product=member.members.item(), loss=k))\n",
    "\n",
    "    # validation metrics for ensemble means; the full ensemble is labelled by its number of members\n",
    "    for size, ens in zip([3, 5, None], [ensemble_mean_3, ensemble_mean_5, ensemble_mean_full]):\n",
    "        for k, v in ens.items():\n",
    "            name = 'Ensemble-{:d}'.format(len(ensemble[k].members) if size is None else size)\n",
    "            records.append(dict(yearly_scores(v.groupby('time.year').mean(dim='time')), product=name, loss=k))\n",
    "\n",
    "    df_pred = pd.DataFrame(records, columns=['bias', 'mae', 'rmse', 'product', 'loss'])\n",
    "\n",
    "    # save metrics to disk\n",
    "    df_pred.to_csv(filename, index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "902e299c-a927-41b1-b2ae-987c30dee8cf",
   "metadata": {},
   "source": [
    "## Plot results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdca9b54-3e05-49c8-b1b2-b8c782017306",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sequential colormap for the bar plots: one color per loss function\n",
    "palette = sns.color_palette('Blues', len(LOSS))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cfcd2de-cd37-42d5-b53d-e8abfd21e242",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Absolute values: single members vs. ensemble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48751f7f-9c26-471d-a75e-b7bb2fcb71be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataframes of single members and of ensemble means only\n",
    "# (new names: `ensemble` still holds the ensemble predictions used by the cells above)\n",
    "is_member = df_pred['product'].str.startswith('Member-')\n",
    "df_members = df_pred[is_member]\n",
    "df_ensemble = df_pred[~is_member]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7c8e987-0257-4263-ac4b-718a614c458f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize figure\n",
    "fig = plt.figure(figsize=(16, 5))\n",
    "\n",
    "# create grid for different subplots: the unused column grid[2] separates the bias and the mae panels\n",
    "grid = gridspec.GridSpec(ncols=5, nrows=1, width_ratios=[3, 1, 1, 3, 1], wspace=0.05, hspace=0)\n",
    "\n",
    "# add subplots\n",
    "ax1 = fig.add_subplot(grid[0])\n",
    "ax2 = fig.add_subplot(grid[1], sharey=ax1)\n",
    "ax3 = fig.add_subplot(grid[3])\n",
    "ax4 = fig.add_subplot(grid[4], sharey=ax3)\n",
    "axes = [ax1, ax2, ax3, ax4]\n",
    "\n",
    "# plot bias: single members vs. ensemble\n",
    "sns.barplot(x='product', y='bias', hue='loss', data=df_members, palette=palette, ax=ax1);\n",
    "sns.barplot(x='product', y='bias', hue='loss', data=df_ensemble, palette=palette, ax=ax2);\n",
    "\n",
    "# plot mae: single members vs. ensemble\n",
    "sns.barplot(x='product', y='mae', hue='loss', data=df_members, palette=palette, ax=ax3);\n",
    "sns.barplot(x='product', y='mae', hue='loss', data=df_ensemble, palette=palette, ax=ax4);\n",
    "\n",
    "# axes limits and tick spacing: the tick step must fit the limits, since set_yticks expands the view limits\n",
    "# NOTE(review): for precipitation the bias is labelled in %, but it is computed as an absolute difference - confirm\n",
    "y_lim_bias, y_step_bias = ((-50, 50), 10) if PREDICTAND == 'pr' else ((-1, 1), 0.2)\n",
    "y_lim_mae, y_step_mae = ((0, 2), 0.5) if PREDICTAND == 'pr' else ((0, 1), 0.2)\n",
    "y_ticks_bias = np.arange(y_lim_bias[0], y_lim_bias[1] + y_step_bias, y_step_bias)\n",
    "y_ticks_mae = np.arange(y_lim_mae[0], y_lim_mae[1] + y_step_mae, y_step_mae)\n",
    "\n",
    "# axis for bias\n",
    "ax1.set_ylabel('Bias (%)' if PREDICTAND == 'pr' else 'Bias (°C)')\n",
    "ax1.set_ylim(y_lim_bias)\n",
    "ax1.set_yticks(y_ticks_bias)\n",
    "\n",
    "# axis for mae\n",
    "ax3.set_ylabel('Mean absolute error (mm)' if PREDICTAND == 'pr' else 'Mean absolute error (°C)')\n",
    "ax3.set_ylim(y_lim_mae)\n",
    "ax3.set_yticks(y_ticks_mae)\n",
    "\n",
    "# adjust axis for ensemble predictions: ticks on the right of the panels sharing the y-axis\n",
    "for ax in [ax2, ax4]:\n",
    "    ax.yaxis.tick_right()\n",
    "    ax.set_ylabel('')\n",
    "\n",
    "# axis fontsize; remove the per-axis legends\n",
    "for ax in axes:\n",
    "    ax.tick_params('both', labelsize=14)\n",
    "    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)\n",
    "    ax.yaxis.label.set_size(14)\n",
    "    ax.set_xlabel('')\n",
    "    ax.get_legend().remove()\n",
    "\n",
    "# show single legend\n",
    "ax4.legend(bbox_to_anchor=(1.3, 1.05), loc=2, frameon=False, fontsize=14);\n",
    "\n",
    "# save figure\n",
    "fig.savefig('./Figures/{}_members_vs_ensemble.pdf'.format(PREDICTAND), bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "590ffbaf-0e8d-4b63-9264-ad86078d50c9",
   "metadata": {},
   "source": [
    "### Absolute values: ensemble vs. reference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1a9b1b7-9cd7-4998-afbb-11e64e91b333",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize figure\n",
    "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
    "\n",
    "# plot bias, mae and rmse: ensemble predictions vs. reference\n",
    "for ax, metric in zip(axes, ['bias', 'mae', 'rmse']):\n",
    "    sns.barplot(x='product', y=metric, hue='loss', data=df_ensemble, palette=palette, ax=ax)\n",
    "\n",
    "# add metrics for reference as horizontal lines spanning all ensemble bars\n",
    "n_products = df_ensemble['product'].nunique()\n",
    "for ax, metric in zip(axes, ['bias', 'mae', 'rmse']):\n",
    "    for product, ls in zip(df_ref['product'], ['-', '--']):\n",
    "        ax.hlines(df_ref[metric][df_ref['product'] == product].item(), xmin=-0.5, xmax=n_products - 0.5,\n",
    "                  color='k', ls=ls, label=product)\n",
    "\n",
    "# axis for bias\n",
    "axes[0].set_ylabel('Bias (%)' if PREDICTAND == 'pr' else 'Bias (°C)')\n",
    "axes[0].set_ylim(y_lim_bias)\n",
    "axes[0].set_yticks(y_ticks_bias)\n",
    "\n",
    "# axis for mae\n",
    "axes[1].set_ylabel('Mean absolute error (mm)' if PREDICTAND == 'pr' else 'Mean absolute error (°C)')\n",
    "axes[1].set_ylim(y_lim_mae)\n",
    "axes[1].set_yticks(y_ticks_mae)\n",
    "\n",
    "# axis for rmse: same limits as the mae\n",
    "axes[2].set_ylabel('RMSE (mm)' if PREDICTAND == 'pr' else 'RMSE (°C)')\n",
    "axes[2].set_ylim(y_lim_mae)\n",
    "axes[2].set_yticks(y_ticks_mae)\n",
    "\n",
    "# axis fontsize; remove the per-axis legends\n",
    "for ax in axes:\n",
    "    ax.tick_params('both', labelsize=14)\n",
    "    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)\n",
    "    ax.yaxis.label.set_size(14)\n",
    "    ax.set_xlabel('')\n",
    "    ax.get_legend().remove()\n",
    "\n",
    "# show single legend\n",
    "axes[-1].legend(bbox_to_anchor=(1.05, 1.05), loc=2, frameon=False, fontsize=14);\n",
    "\n",
    "# save figure\n",
    "fig.subplots_adjust(wspace=0.25)\n",
    "fig.savefig('./Figures/{}_ensemble_vs_reference.pdf'.format(PREDICTAND), bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "775a3c92-1027-49d2-9681-dd53e0af70ac",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Regional time series"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbe5db42-c31c-493b-a3b8-42c794cde6d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# whether to compute rolling or hard mean\n",
    "ROLLING = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fae4d70c-276c-4ba6-b6b6-ba6eb1793e0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define scale of mean time series (only used when ROLLING is False)\n",
    "# scale = '1M'  # monthly\n",
    "scale = '1Y'  # yearly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eaaaf2f-d4c4-4f30-b124-66d04d6db2b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def regional_time_series(da):\n",
    "    \"\"\"Average `da` over the spatial dimensions ('y', 'x') and smooth it in time.\n",
    "\n",
    "    With ROLLING, a centered 365-day rolling mean is applied before the spatial mean and the\n",
    "    incomplete windows at the edges are dropped; otherwise the data is averaged per period `scale`.\n",
    "    \"\"\"\n",
    "    if ROLLING:\n",
    "        return da.rolling(time=365, center=True).mean().mean(dim=('y', 'x')).dropna('time').compute()\n",
    "    return da.resample(time=scale).mean(dim=('time', 'y', 'x')).compute()\n",
    "\n",
    "\n",
    "# mean time series over entire grid and validation period;\n",
    "# the ensemble mean and standard deviation are dictionaries with one entry per loss function\n",
    "y_pred_ts = {k: regional_time_series(v) for k, v in ensemble_mean_full.items()}\n",
    "y_pred_ts_var = {k: regional_time_series(v) for k, v in ensemble_std_full.items()}\n",
    "y_true_ts = regional_time_series(y_true)\n",
    "y_refe_ts = regional_time_series(y_refe)\n",
    "y_refe_qm_ts = regional_time_series(y_refe_qm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28bc3177-b6a0-4938-9e74-59be2491fa56",
   "metadata": {},
   "outputs": [],
   "source": [
    "# color palette: ERA-5, QM-adjusted ERA-5 and one color per loss function\n",
    "palette = sns.color_palette('viridis', 2 + len(ensemble_mean_full))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07375015-4205-4dfb-9bd2-0f37d5e56672",
   "metadata": {},
   "outputs": [],
   "source": [
    "# factor of standard deviation to plot as uncertainty around ensemble mean prediction\n",
    "std_factor = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ca32179-66ed-4f9d-a8f6-92cb547afe4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize figure\n",
    "fig, ax = plt.subplots(1, 1, figsize=(16, 9))\n",
    "\n",
    "# time to plot on x-axis\n",
    "time = y_true_ts.time if ROLLING else [t.astype('datetime64[{}]'.format(scale.lstrip('1'))) for t in y_true_ts.time.values]\n",
    "xticks = [t.astype('datetime64[Y]') for t in list(y_true_ts.time.resample(time='1Y').groups.keys())]\n",
    "\n",
    "# plot reference: observations, ERA-5, ERA-5 QM-adjusted\n",
    "ax.plot(time, y_true_ts, label='Observed', ls='-', color='k');\n",
    "ax.plot(time, y_refe_ts, label='ERA-5', ls='-', color=palette[0]);\n",
    "ax.plot(time, y_refe_qm_ts, label='ERA-5 QM-adjusted', ls='-', color=palette[1]);\n",
    "\n",
    "# plot model predictions of each loss function: ensemble mean +/- std_factor ensemble standard deviations\n",
    "for color, loss in zip(palette[2:], y_pred_ts):\n",
    "    ax.plot(time, y_pred_ts[loss], label='Prediction ({}): Ensemble mean'.format(loss), color=color)\n",
    "    ax.fill_between(x=time, y1=y_pred_ts[loss] - std_factor * y_pred_ts_var[loss],\n",
    "                    y2=y_pred_ts[loss] + std_factor * y_pred_ts_var[loss],\n",
    "                    alpha=0.3, label='Prediction ({}): Ensemble std'.format(loss), color=color);\n",
    "\n",
    "# add legend\n",
    "ax.legend(frameon=False, loc='lower right', fontsize=12)\n",
    "\n",
    "# axis limits and ticks\n",
    "ax.set_xticks(xticks)\n",
    "ax.set_xticklabels(xticks)\n",
    "ax.tick_params(axis='both', labelsize=12)\n",
    "\n",
    "# save figure: loss functions are joined into the file name (not the repr of the LOSS list)\n",
    "fig.savefig('./Figures/{}_{}_{}_bootstrap_time_series_{}.png'.format(PREDICTAND, '_'.join(LOSS), OPTIM,\n",
    "                                                                      scale if not ROLLING else 'rolling'),\n",
    "            bbox_inches='tight', dpi=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "923762ca-6ebc-4ffa-9b65-2faaf816fc05",
   "metadata": {},
   "source": [
    "### Spatial distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a520127b-0dbc-4217-9a00-68cef41afe83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute ensemble mean yearly mean bias of each grid point, for each loss function\n",
    "pred = {k: (v.groupby('time.year').mean(dim='time') - y_true_yearly_avg).mean(dim='year').compute()\n",
    "        for k, v in ensemble_mean_full.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e917db7e-ae9b-48e8-bb23-58905c47a910",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot yearly average bias of references and of the ensemble mean prediction of each loss function\n",
    "vmin, vmax = -1, 1\n",
    "ncols = 2 + len(pred)\n",
    "fig, axes = plt.subplots(1, ncols, figsize=(8 * ncols, 8), sharex=True, sharey=True)\n",
    "\n",
    "# plot bias of ERA-5 reference\n",
    "era5 = bias_refe.mean(dim='year')\n",
    "im = axes[0].imshow(era5.values, origin='lower', cmap='RdBu_r', vmin=vmin, vmax=vmax)\n",
    "axes[0].set_title('Era-5', fontsize=14, pad=10)\n",
    "\n",
    "# plot bias of ERA-5 QM-adjusted reference\n",
    "era5_qm = bias_refe_qm.mean(dim='year')\n",
    "axes[1].imshow(era5_qm.values, origin='lower', cmap='RdBu_r', vmin=vmin, vmax=vmax)\n",
    "axes[1].set_title('Era-5: QM-adjusted', fontsize=14, pad=10)\n",
    "\n",
    "# plot bias of the ensemble mean prediction of each loss function\n",
    "for ax, (loss, bias) in zip(axes[2:], pred.items()):\n",
    "    ax.imshow(bias.values, origin='lower', cmap='RdBu_r', vmin=vmin, vmax=vmax)\n",
    "    ax.set_title('Predictions ({}): Ensemble mean'.format(loss), fontsize=14, pad=10)\n",
    "\n",
    "# adjust axes\n",
    "for ax in axes.flat:\n",
    "    ax.axes.get_xaxis().set_ticklabels([])\n",
    "    ax.axes.get_xaxis().set_ticks([])\n",
    "    ax.axes.get_yaxis().set_ticklabels([])\n",
    "    ax.axes.get_yaxis().set_ticks([])\n",
    "    ax.axes.axis('tight')\n",
    "    ax.set_xlabel('')\n",
    "    ax.set_ylabel('')\n",
    "    ax.set_axis_off()\n",
    "\n",
    "# adjust figure\n",
    "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)\n",
    "\n",
    "# add colorbar: all panels share the same colormap and limits, so any image defines it\n",
    "axes = axes.flatten()\n",
    "cbar_ax_bias = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n",
    "                             0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])\n",
    "cbar_bias = fig.colorbar(im, cax=cbar_ax_bias)\n",
    "cbar_bias.set_label(label='Bias (mm)' if PREDICTAND == 'pr' else 'Bias (°C)', fontsize=14)\n",
    "cbar_bias.ax.tick_params(labelsize=14, pad=10)\n",
    "\n",
    "# save figure: same output folder as the other figures, loss functions joined into the file name\n",
    "fig.savefig('./Figures/{}_{}_{}_bootstrap_bias.png'.format(PREDICTAND, '_'.join(LOSS), OPTIM), dpi=300, bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}