{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fde8874d-299f-4f48-a10a-9fb6a00b43b9",
   "metadata": {},
   "source": [
    "# Evaluate bootstrapped model results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "969d063b-5262-4324-901f-0a48630c4f27",
   "metadata": {},
   "source": [
    "### Imports and constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8af00ae4-4aeb-4ff8-a46a-65966b28c440",
   "metadata": {},
   "outputs": [],
   "source": [
    "# builtins\n",
    "import pathlib\n",
    "\n",
    "# externals\n",
    "import numpy as np\n",
    "import xarray as xr\n",
    "\n",
    "# locals\n",
    "from climax.main.io import OBS_PATH, ERA5_PATH\n",
    "from climax.main.config import VALID_PERIOD\n",
    "from pysegcnn.core.utils import search_files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bc74835-dc59-46ed-849b-3ff614e53eee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# mapping from predictands to variable names\n",
    "NAMES = {'tasmin': 'minimum temperature', 'tasmax': 'maximum temperature', 'pr': 'precipitation'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8a63ef3-35ef-4ffa-b1f3-5c2986eb7eb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# path to bootstrapped model results\n",
    "RESULTS = pathlib.Path('/mnt/CEPH_PROJECTS/FACT_CLIMAX/ERA5_PRED/bootstrap')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7eae545b-4d8a-4689-a6c0-4aba2cb9104e",
   "metadata": {},
   "source": [
    "## Search model configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e856f80-14fd-405f-a44e-cc77863f8e5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# predictand to evaluate\n",
    "PREDICTAND = 'tasmin'\n",
    "LOSS = 'L1Loss'\n",
    "OPTIM = 'Adam'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "011b792d-7349-44ad-997d-11f236472a11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model to evaluate\n",
    "model = 'USegNet_{}_ztuvq_500_850_p_dem_doy_{}_{}'.format(PREDICTAND, LOSS, OPTIM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc4ca6f0-5490-4522-8661-e36bd1be11b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get bootstrapped models\n",
    "models = sorted(search_files(RESULTS.joinpath(PREDICTAND), model + '(.*).nc$'),\n",
    "                key=lambda x: int(x.stem.split('_')[-1]))\n",
    "models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e790ed9f-451c-4368-849d-06d9c50f797c",
   "metadata": {},
   "source": [
    "### Load observations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0862e0c8-06df-45d6-bc1b-002ffb6e9915",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load observations\n",
    "y_true = xr.open_dataset(OBS_PATH.joinpath(PREDICTAND, 'OBS_{}_1980_2018.nc'.format(PREDICTAND)),\n",
    "                         chunks={'time': 365})\n",
    "y_true = y_true.sel(time=VALID_PERIOD)  # subset to time period covered by predictions\n",
    "y_true = y_true.rename({NAMES[PREDICTAND]: PREDICTAND}) if PREDICTAND == 'pr' else y_true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aba38642-85d1-404a-81f3-65d23985fb7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# mask of missing values\n",
    "missing = np.isnan(y_true[PREDICTAND])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4512ed2-d503-4bc1-ae76-84560c101a14",
   "metadata": {},
   "source": [
    "### Load reference data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f90f6abf-5fd6-49c0-a1ad-f62242b3d3a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ERA-5 reference dataset\n",
    "if PREDICTAND == 'pr':\n",
    "    y_refe = xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', 'total_precipitation'), '.nc$').pop(),\n",
    "                             chunks={'time': 365})\n",
    "    y_refe = y_refe.rename({'tp': 'pr'})\n",
    "else:\n",
    "    y_refe = xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', '2m_{}_temperature'.format(PREDICTAND.lstrip('tas'))), '.nc$').pop(),\n",
    "                             chunks={'time': 365})\n",
    "    y_refe = y_refe - 273.15  # convert to °C\n",
    "    y_refe = y_refe.rename({'t2m': PREDICTAND})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea6d5f56-4f39-4e9a-976d-00ff28fce95c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# subset to time period covered by predictions\n",
    "y_refe = y_refe.sel(time=VALID_PERIOD).drop_vars('lambert_azimuthal_equal_area')\n",
    "y_refe = y_refe.transpose('time', 'y', 'x')  # change order of dimensions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d37702de-da5f-4306-acc1-e569471c1f12",
   "metadata": {},
   "source": [
    "### Load QM-adjusted reference data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fffbd267-d08b-44f4-869c-7056c4f19c28",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_refe_qm = xr.open_dataset(ERA5_PATH.joinpath('QM_ERA5_{}_day_19912010.nc'.format(PREDICTAND)), chunks={'time': 365})\n",
    "y_refe_qm = y_refe_qm.transpose('time', 'x', 'y')  # change order of dimensions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16fa580e-27a7-4758-9164-7f607df7179d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# center hours at 00:00:00 rather than 12:00:00\n",
    "y_refe_qm['time'] = np.asarray([t.astype('datetime64[D]') for t in y_refe_qm.time.values])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6789791f-006b-49b3-aa04-34e4ed8e1571",
   "metadata": {},
   "outputs": [],
   "source": [
    "# subset to time period covered by predictions\n",
    "y_refe_qm = y_refe_qm.sel(time=VALID_PERIOD).drop_vars('lambert_azimuthal_equal_area')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b51cfb3f-caa8-413e-a12d-47bbafcef1df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# align datasets and mask missing values\n",
    "y_true, y_refe, y_refe_qm = xr.align(y_true[PREDICTAND], y_refe[PREDICTAND], y_refe_qm[PREDICTAND], join='override')\n",
    "y_refe = y_refe.where(~missing, other=np.nan)\n",
    "y_refe_qm = y_refe_qm.where(~missing, other=np.nan)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4a6c286-6b88-487d-866c-3cb633686dac",
   "metadata": {},
   "source": [
    "### Load model predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccaf0118-da51-43b0-a2b6-56ba4b252999",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = [xr.open_dataset(sim, chunks={'time': 365}) for sim in models]\n",
    "if PREDICTAND == 'pr':\n",
    "    y_pred = [y_p.rename({NAMES[PREDICTAND]: PREDICTAND}) for y_p in y_pred]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df3f018e-4723-4878-b72a-0586b68e6dfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# align datasets and mask missing values\n",
    "y_prob = []\n",
    "for i, y_p in enumerate(y_pred):\n",
    "    \n",
    "    # check whether evaluating precipitation or temperatures\n",
    "    if len(y_p.data_vars) > 1:\n",
    "        _, _, y_p, y_p_prob = xr.align(y_true, y_refe, y_p[PREDICTAND], y_p.prob, join='override')\n",
    "        y_p_prob = y_p_prob.where(~missing, other=np.nan)  # mask missing values\n",
    "        y_prob.append(y_p_prob)\n",
    "    else:\n",
    "        _, _, y_p = xr.align(y_true, y_refe, y_p[PREDICTAND], join='override')\n",
    "    \n",
    "    # mask missing values\n",
    "    y_p = y_p.where(~missing, other=np.nan)\n",
    "    y_pred[i] = y_p"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7effbf83-7977-4a47-aa6d-d57c4c4c22c6",
   "metadata": {},
   "source": [
    "## Bias, MAE, and RMSE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8c23b7b-ccdf-412a-a30d-ac686af03d9f",
   "metadata": {},
   "source": [
    "Calculate the yearly average bias, MAE, and RMSE over the entire validation period for the model predictions, ERA-5, and QM-adjusted ERA-5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7939a4d2-4eff-4507-86f8-dba7c0b635df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# yearly average values over validation period\n",
    "y_pred_yearly_avg = [y_p.groupby('time.year').mean(dim='time') for y_p in y_pred]\n",
    "y_refe_yearly_avg = y_refe.groupby('time.year').mean(dim='time')\n",
    "y_refe_qm_yearly_avg = y_refe_qm.groupby('time.year').mean(dim='time')\n",
    "y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cce7ffbf-7e16-45f1-a2b6-5bc688595ee7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# yearly average bias, mae, and rmse for model predictions\n",
    "bias_pred = [y_p - y_true_yearly_avg for y_p in y_pred_yearly_avg]\n",
    "mae_pred = [np.abs(y_p - y_true_yearly_avg) for y_p in y_pred_yearly_avg]\n",
    "rmse_pred = [np.sqrt((y_p - y_true_yearly_avg) ** 2) for y_p in y_pred_yearly_avg]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64e29db7-998d-4952-84b0-1c79016ab9a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# yearly average bias, mae, and rmse for ERA-5\n",
    "bias_refe = y_refe_yearly_avg - y_true_yearly_avg\n",
    "mae_refe = np.abs(y_refe_yearly_avg - y_true_yearly_avg)\n",
    "rmse_refe = np.sqrt((y_refe_yearly_avg - y_true_yearly_avg) ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb7dc5c2-c518-4251-8d27-5b52e20e0397",
   "metadata": {},
   "outputs": [],
   "source": [
    "# yearly average bias, mae, and rmse for QM-Adjusted ERA-5\n",
    "bias_refe_qm = y_refe_qm_yearly_avg - y_true_yearly_avg\n",
    "mae_refe_qm = np.abs(y_refe_qm_yearly_avg - y_true_yearly_avg)\n",
    "rmse_refe_qm = np.sqrt((y_refe_qm_yearly_avg - y_true_yearly_avg) ** 2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}