{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d6b83379-c5a8-48c3-bb85-d00a341a37f4",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeba7f9b-066a-4843-bd64-5b6326c0b536",
   "metadata": {},
   "outputs": [],
   "source": [
    "# builtins\n",
    "import datetime\n",
    "import warnings\n",
    "import calendar\n",
    "\n",
    "# externals\n",
    "import xarray as xr\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
    "import scipy.stats as stats\n",
    "from IPython.display import Image\n",
    "from sklearn.metrics import r2_score, roc_curve, auc, classification_report\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# locals\n",
    "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH\n",
    "from climax.main.config import CALIB_PERIOD\n",
    "from pysegcnn.core.utils import search_files\n",
    "from pysegcnn.core.graphics import plot_classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e75b3217-26f7-4a4a-ae2a-4fbb92a9f2a2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# observed precipitation NetCDF\n",
    "obs_dir = OBS_PATH.joinpath('pr')\n",
    "obs_files = search_files(obs_dir, 'OBS_pr(.*).nc$')\n",
    "\n",
    "# fail with an explicit message instead of a bare IndexError from pop()\n",
    "if not obs_files:\n",
    "    raise FileNotFoundError('No observation file matching OBS_pr(.*).nc$ found in {}'.format(obs_dir))\n",
    "\n",
    "# NOTE(review): if several files match, pop() takes the last one returned by search_files - confirm only one is expected\n",
    "y_true = xr.open_dataset(obs_files.pop())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aa8466e-84a9-4c2e-ae19-403b6246e27f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# restrict the observations to the dates of the calibration period\n",
    "y_true = y_true.sel({'time': CALIB_PERIOD})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f1a58a2-8c4c-4d73-a116-e64e68fdd507",
   "metadata": {},
   "outputs": [],
   "source": [
    "# a day counts as wet when its precipitation is at or above this threshold (mm)\n",
    "WET_DAY_THRESHOLD: int = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e6696df-8660-4083-9a32-0dd282112948",
   "metadata": {},
   "outputs": [],
   "source": [
    "# flag every day of the calibration period as wet (1) or dry (0), based on the\n",
    "# precipitation averaged over the (y, x) grid\n",
    "wet_days = (y_true.mean(dim=['y', 'x']) >= WET_DAY_THRESHOLD).astype('int16')\n",
    "\n",
    "# daily wet/dry flags as a plain array (used as stratification labels);\n",
    "# NOTE(review): to_array() stacks all data variables, so this assumes the dataset holds a single one\n",
    "nwet_days = np.squeeze(wet_days.to_array().values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b87accd6-d5e4-4dc6-9532-3ef8aa162d24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# chronological split without shuffling: the leading part of the calibration\n",
    "# period is used for training, the trailing 25% for validation\n",
    "train, valid = train_test_split(CALIB_PERIOD, test_size=0.25, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "559d1450-09db-4b2f-844a-d572485973e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# split training/validation set stratified by the daily wet/dry flag, so that\n",
    "# both sets contain (nearly) the same proportion of wet days\n",
    "# NOTE(review): test_size=0.5 differs from the 0.25 of the chronological split,\n",
    "# so the two columns of the comparison figure differ in size - confirm this is intended\n",
    "RANDOM_STATE = 42  # fixed seed: the stratified split shuffles, keep it reproducible\n",
    "train_st, valid_st = train_test_split(\n",
    "    CALIB_PERIOD, stratify=nwet_days, test_size=0.5, random_state=RANDOM_STATE)\n",
    "\n",
    "# the split returns shuffled dates; sort both subsets chronologically\n",
    "train_st, valid_st = np.asarray(sorted(train_st)), np.asarray(sorted(valid_st))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fd013f9-77d0-48de-8d5f-2c6a1cb3ed17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot the distribution of wet and dry days in the training/validation sets\n",
    "# of the chronological (left column) and stratified (right column) splits\n",
    "fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10, 10))\n",
    "axes = axes.flatten()\n",
    "\n",
    "# (dates, label) per panel, in the order of the flattened axes:\n",
    "# top row = training, bottom row = validation\n",
    "panels = [(train, 'Training'), (train_st, 'Training'),\n",
    "          (valid, 'Validation'), (valid_st, 'Validation')]\n",
    "for ax, (dates, label) in zip(axes, panels):\n",
    "    flags = wet_days.sel(time=dates).to_array().values.squeeze()\n",
    "    sns.countplot(x=flags, ax=ax)\n",
    "    ax.set_ylabel('')\n",
    "    # axes-fraction coordinates keep the label inside the panel, independent\n",
    "    # of the magnitude of the day counts\n",
    "    ax.text(0.95, 0.95, label, transform=ax.transAxes, ha='right', va='top', fontsize=12)\n",
    "\n",
    "# axes properties\n",
    "for ax in axes[::2]:\n",
    "    ax.set_ylabel('Number of days')\n",
    "for ax in axes[2:]:\n",
    "    ax.set_xticklabels(['Dry', 'Wet'])\n",
    "axes[0].set_title('Not stratified')\n",
    "axes[1].set_title('Stratified')\n",
    "\n",
    "# adjust subplot\n",
    "fig.subplots_adjust(wspace=0.1, hspace=0.1)\n",
    "# '{:g}' accepts integer and fractional thresholds alike ('{:0d}' rejects floats)\n",
    "fig.suptitle('Stratified sampling: wet day threshold {:g} mm'.format(WET_DAY_THRESHOLD));"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}