{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f15afea1-9ea4-4201-bdd7-32ae377db6a9",
   "metadata": {},
   "source": [
    "# Evaluate ERA-5 downscaling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4186c89c-b55b-4559-a818-8b712baaf44e",
   "metadata": {},
   "source": [
    "Define the predictand (target variable) and the name of the model whose predictions are evaluated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb24b6ed-2d0a-44e0-b9a9-abdcb2a8294d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PREDICTAND: target variable to evaluate; MODEL: name of the model whose predictions are loaded\n",
    "PREDICTAND, MODEL = 'tasmin', 'USegNet'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d5d12c5-50fd-4c5c-9240-c3df78e49b44",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1afb0fab-5d2a-4875-9032-29b99c6dec89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# builtins\n",
    "import datetime\n",
    "\n",
    "# externals\n",
    "import xarray as xr\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# locals\n",
    "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f47caa41-9380-4c02-8785-4febcf2cb2d0",
   "metadata": {},
   "source": [
    "### Load datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "020dfe33-ce3c-467f-ad0a-295cc338b1a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model predictions and observations NetCDF\n",
    "y_pred = TARGET_PATH.joinpath(PREDICTAND, '_'.join([MODEL, PREDICTAND]) + '.nc')\n",
    "y_true = OBS_PATH.joinpath(PREDICTAND, '_'.join(['OBS', PREDICTAND, '1980', '2018']) + '.nc')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dc2a386-d63b-4c6a-8e63-00365927559d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load datasets\n",
    "y_pred = xr.open_dataset(y_pred)\n",
    "y_true = xr.open_dataset(y_true).sel(time=y_pred.time)  # subset to time period covered by predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8748e102-ba64-4472-8960-7cd0830fdcf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# the observations must hold exactly one data variable: rename it to the predictand\n",
    "# so that it matches the variable name used for the model predictions\n",
    "if len(y_true.data_vars) != 1:\n",
    "    raise ValueError('Expected a single data variable in the observations, found: {}'.format(list(y_true.data_vars)))\n",
    "y_true = y_true.rename({var: PREDICTAND for var in y_true.data_vars})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "966d85fb-9185-408f-ac2b-1e4ca829ccd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# put both datasets on identical coordinates: join='override' copies the indexes of\n",
    "# the observations onto the predictions (index sizes must already match)\n",
    "y_true, y_pred = xr.align(y_true, y_pred, join='override')\n",
    "\n",
    "# only keep predictions where an observation exists (NaN elsewhere)\n",
    "y_pred = y_pred.where(y_true.notnull())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddebdf9f-862c-461e-aa57-cd344d54eee9",
   "metadata": {},
   "source": [
    "## Model validation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e8be24e-8ca2-4582-98c0-b56c6db289d2",
   "metadata": {},
   "source": [
    "### Overall bias"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4f7177c-7d09-401f-957b-0e493b9ef5d0",
   "metadata": {},
   "source": [
    "Calculate the average bias (prediction minus observation) over the entire reference period:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "746bf95f-a78b-4da8-a063-1fa48e3c5da8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# temporal means over the reference period\n",
    "y_true_avg = y_true.mean(dim='time')\n",
    "y_pred_avg = y_pred.mean(dim='time')\n",
    "\n",
    "# bias map (prediction minus observation) and its average over all remaining dimensions\n",
    "bias = y_pred_avg - y_true_avg\n",
    "bias_overall = bias[PREDICTAND].mean().item()\n",
    "print('Overall average bias: {:.2f}'.format(bias_overall))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "760f86ce-9e04-4938-b24f-d2819fbf622e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot average of observation, prediction, and bias:\n",
    "# observed and predicted maps share one colour scale so their colours are directly comparable,\n",
    "# the bias map uses a diverging colour scale centred on zero\n",
    "vmin = min(ds[PREDICTAND].min().item() for ds in (y_true_avg, y_pred_avg))\n",
    "vmax = max(ds[PREDICTAND].max().item() for ds in (y_true_avg, y_pred_avg))\n",
    "panels = [\n",
    "    (y_true_avg, 'Observed', dict(vmin=vmin, vmax=vmax)),\n",
    "    (y_pred_avg, 'Predicted', dict(vmin=vmin, vmax=vmax)),\n",
    "    (bias, 'Difference', dict(center=0)),\n",
    "]\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(24, 6))\n",
    "for ax, (ds, title, plot_kwargs) in zip(axes, panels):\n",
    "    ds[PREDICTAND].plot(ax=ax, **plot_kwargs)\n",
    "    ax.set_title(title)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa4ef730-5a9d-40dc-a318-2f43a4cf1cd2",
   "metadata": {},
   "source": [
    "### Seasonal bias"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eda455a2-e8ee-4644-bb85-b0cf76acd11a",
   "metadata": {},
   "source": [
    "Calculate the seasonal bias (prediction minus observation) for DJF, MAM, JJA and SON:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24aadff5-b19d-4f4b-a4c5-32ee656e64cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# seasonal climatologies (DJF, MAM, JJA, SON) of observations and predictions\n",
    "y_true_snl, y_pred_snl = (ds.groupby('time.season').mean(dim='time') for ds in (y_true, y_pred))\n",
    "\n",
    "# seasonal bias (prediction minus observation)\n",
    "bias_snl = y_pred_snl - y_true_snl"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4232ae9-a557-4d61-8ab3-d7eda6201f98",
   "metadata": {},
   "source": [
    "Plot seasonal differences, adapted from the [xarray documentation](https://xarray.pydata.org/en/stable/examples/monthly-means.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b39a6cc0-614c-452d-bb85-bc10e5179948",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot seasonal differences\n",
    "# symmetric colour limit of the bias maps (in units of the predictand)\n",
    "BIAS_LIM = 1\n",
    "\n",
    "fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(14,12))\n",
    "for i, season in enumerate(('DJF', 'MAM', 'JJA', 'SON')):\n",
    "    obs = y_true_snl[PREDICTAND].sel(season=season)\n",
    "    pred = y_pred_snl[PREDICTAND].sel(season=season)\n",
    "\n",
    "    # observed and predicted maps of one season share a colour scale so that\n",
    "    # their colours are directly comparable\n",
    "    vmin = min(obs.min().item(), pred.min().item())\n",
    "    vmax = max(obs.max().item(), pred.max().item())\n",
    "    for j, field in enumerate((obs, pred)):\n",
    "        field.plot.pcolormesh(ax=axes[i, j], vmin=vmin, vmax=vmax,\n",
    "                              add_colorbar=True, extend='both')\n",
    "\n",
    "    bias_snl[PREDICTAND].sel(season=season).plot.pcolormesh(\n",
    "        ax=axes[i, 2], vmin=-BIAS_LIM, vmax=BIAS_LIM, add_colorbar=True,\n",
    "        extend='both')\n",
    "\n",
    "    # the season is shown once, as the label of the row\n",
    "    axes[i, 0].set_ylabel(season)\n",
    "    axes[i, 1].set_ylabel('')\n",
    "    axes[i, 2].set_ylabel('')\n",
    "\n",
    "# hide tick labels to keep the 4 x 3 panel grid compact\n",
    "for ax in axes.flat:\n",
    "    ax.set_xticklabels([])\n",
    "    ax.set_yticklabels([])\n",
    "    ax.axis('tight')\n",
    "    ax.set_xlabel('')\n",
    "\n",
    "axes[0, 0].set_title('Observed')\n",
    "axes[0, 1].set_title('Predicted')\n",
    "axes[0, 2].set_title('Difference')\n",
    "\n",
    "plt.tight_layout()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}