From 9bcb0d5a5d6aca2260f2b6b745e55961060bb4fd Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Mon, 27 Sep 2021 15:06:47 +0000
Subject: [PATCH] Notebook visualizing stratified sampling.

---
 Notebooks/pr_sampling.ipynb | 170 ++++++++++++++++++++++++++++++++++++
 1 file changed, 170 insertions(+)
 create mode 100644 Notebooks/pr_sampling.ipynb

diff --git a/Notebooks/pr_sampling.ipynb b/Notebooks/pr_sampling.ipynb
new file mode 100644
index 0000000..5d23a94
--- /dev/null
+++ b/Notebooks/pr_sampling.ipynb
@@ -0,0 +1,170 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "d6b83379-c5a8-48c3-bb85-d00a341a37f4",
+   "metadata": {},
+   "source": [
+    "### Imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "eeba7f9b-066a-4843-bd64-5b6326c0b536",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# builtins\n",
+    "import datetime\n",
+    "import warnings\n",
+    "import calendar\n",
+    "\n",
+    "# externals\n",
+    "import xarray as xr\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "import seaborn as sns\n",
+    "import pandas as pd\n",
+    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
+    "import scipy.stats as stats\n",
+    "from IPython.display import Image\n",
+    "from sklearn.metrics import r2_score, roc_curve, auc, classification_report\n",
+    "from sklearn.model_selection import train_test_split\n",
+    "\n",
+    "# locals\n",
+    "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH\n",
+    "from climax.main.config import CALIB_PERIOD\n",
+    "from pysegcnn.core.utils import search_files\n",
+    "from pysegcnn.core.graphics import plot_classification_report"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e75b3217-26f7-4a4a-ae2a-4fbb92a9f2a2",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# model predictions and observations NetCDF \n",
+    "y_true = xr.open_dataset(search_files(OBS_PATH.joinpath('pr'), 'OBS_pr(.*).nc$').pop())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3aa8466e-84a9-4c2e-ae19-403b6246e27f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# subset to calibration period\n",
+    "y_true = y_true.sel(time=CALIB_PERIOD)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4f1a58a2-8c4c-4d73-a116-e64e68fdd507",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# precipitation threshold defining a wet day\n",
+    "WET_DAY_THRESHOLD = 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5e6696df-8660-4083-9a32-0dd282112948",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# calculate number of wet days in calibration period\n",
+    "wet_days = (y_true.mean(dim=('y', 'x')) >= WET_DAY_THRESHOLD).astype(np.int16)\n",
+    "nwet_days = wet_days.to_array().values.squeeze()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b87accd6-d5e4-4dc6-9532-3ef8aa162d24",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# split training/validation set chronologically\n",
+    "train, valid = train_test_split(CALIB_PERIOD, shuffle=False, test_size=0.25)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "559d1450-09db-4b2f-844a-d572485973e0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# split training/validation set by number of wet days\n",
+    "train_st, valid_st = train_test_split(CALIB_PERIOD, stratify=nwet_days, test_size=0.5)\n",
+    "train_st, valid_st = np.asarray(sorted(train_st)), np.asarray(sorted(valid_st))  # sort chronologically"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7fd013f9-77d0-48de-8d5f-2c6a1cb3ed17",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# plot distribution of wet days in calibration period\n",
+    "fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10, 10))\n",
+    "axes = axes.flatten()\n",
+    "\n",
+    "# not stratified\n",
+    "sns.countplot(x=wet_days.sel(time=train).to_array().values.squeeze(), ax=axes[0])\n",
+    "sns.countplot(x=wet_days.sel(time=valid).to_array().values.squeeze(), ax=axes[2])\n",
+    "\n",
+    "# stratified\n",
+    "sns.countplot(x=wet_days.sel(time=train_st).to_array().values.squeeze(), ax=axes[1])\n",
+    "sns.countplot(x=wet_days.sel(time=valid_st).to_array().values.squeeze(), ax=axes[3])\n",
+    "\n",
+    "# axes properties\n",
+    "for ax in axes:\n",
+    "    ax.set_ylabel('')\n",
+    "for ax in axes[2:]:\n",
+    "    ax.set_xticklabels(['Dry', 'Wet'])\n",
+    "for ax in [axes[0], axes[1]]:\n",
+    "    ax.text(1, ax.get_ylim()[-1] - 5, 'Training', ha='left', va='top', fontsize=12)\n",
+    "for ax in [axes[2], axes[3]]:\n",
+    "    ax.text(1, ax.get_ylim()[-1] - 5, 'Validation', ha='left', va='top', fontsize=12)\n",
+    "axes[0].set_title('Not stratified')\n",
+    "axes[1].set_title('Stratified')\n",
+    "\n",
+    "# adjust subplot\n",
+    "fig.subplots_adjust(wspace=0.1, hspace=0.1)\n",
+    "fig.suptitle('Stratified sampling: wet day threshold {:0d} mm'.format(WET_DAY_THRESHOLD));"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
-- 
GitLab