diff --git a/Notebooks/pr_sampling.ipynb b/Notebooks/pr_sampling.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..5d23a94fea3627f58d06ce29333aeeb7d7d51264
--- /dev/null
+++ b/Notebooks/pr_sampling.ipynb
@@ -0,0 +1,170 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "d6b83379-c5a8-48c3-bb85-d00a341a37f4",
+   "metadata": {},
+   "source": [
+    "### Imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "eeba7f9b-066a-4843-bd64-5b6326c0b536",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# builtins\n",
+    "import datetime\n",
+    "import warnings\n",
+    "import calendar\n",
+    "\n",
+    "# externals\n",
+    "import xarray as xr\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "import seaborn as sns\n",
+    "import pandas as pd\n",
+    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
+    "import scipy.stats as stats\n",
+    "from IPython.display import Image\n",
+    "from sklearn.metrics import r2_score, roc_curve, auc, classification_report\n",
+    "from sklearn.model_selection import train_test_split\n",
+    "\n",
+    "# locals\n",
+    "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH\n",
+    "from climax.main.config import CALIB_PERIOD\n",
+    "from pysegcnn.core.utils import search_files\n",
+    "from pysegcnn.core.graphics import plot_classification_report"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e75b3217-26f7-4a4a-ae2a-4fbb92a9f2a2",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# observed precipitation NetCDF (only the observations are loaded here)\n",
+    "y_true = xr.open_dataset(search_files(OBS_PATH.joinpath('pr'), 'OBS_pr(.*).nc$').pop())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3aa8466e-84a9-4c2e-ae19-403b6246e27f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# subset to calibration period\n",
+    "y_true = y_true.sel(time=CALIB_PERIOD)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4f1a58a2-8c4c-4d73-a116-e64e68fdd507",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# precipitation threshold defining a wet day\n",
+    "WET_DAY_THRESHOLD = 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5e6696df-8660-4083-9a32-0dd282112948",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# flag each day as wet (1) or dry (0): mean over ('y', 'x') compared to the threshold\n",
+    "wet_days = (y_true.mean(dim=('y', 'x')) >= WET_DAY_THRESHOLD).astype(np.int16)\n",
+    "nwet_days = wet_days.to_array().values.squeeze()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b87accd6-d5e4-4dc6-9532-3ef8aa162d24",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# split training/validation set chronologically\n",
+    "train, valid = train_test_split(CALIB_PERIOD, shuffle=False, test_size=0.25)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "559d1450-09db-4b2f-844a-d572485973e0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# split training/validation set stratified by wet/dry flag (seeded: the shuffle is reproducible)\n",
+    "train_st, valid_st = train_test_split(CALIB_PERIOD, stratify=nwet_days, test_size=0.5, random_state=0)\n",
+    "train_st, valid_st = np.asarray(sorted(train_st)), np.asarray(sorted(valid_st))  # sort chronologically"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7fd013f9-77d0-48de-8d5f-2c6a1cb3ed17",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# plot distribution of wet days in calibration period\n",
+    "fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10, 10))\n",
+    "axes = axes.flatten()\n",
+    "\n",
+    "# not stratified\n",
+    "sns.countplot(x=wet_days.sel(time=train).to_array().values.squeeze(), ax=axes[0])\n",
+    "sns.countplot(x=wet_days.sel(time=valid).to_array().values.squeeze(), ax=axes[2])\n",
+    "\n",
+    "# stratified\n",
+    "sns.countplot(x=wet_days.sel(time=train_st).to_array().values.squeeze(), ax=axes[1])\n",
+    "sns.countplot(x=wet_days.sel(time=valid_st).to_array().values.squeeze(), ax=axes[3])\n",
+    "\n",
+    "# axes properties\n",
+    "for ax in axes:\n",
+    "    ax.set_ylabel('')\n",
+    "for ax in axes[2:]:\n",
+    "    ax.set_xticklabels(['Dry', 'Wet'])\n",
+    "for ax in [axes[0], axes[1]]:\n",
+    "    ax.text(1, ax.get_ylim()[-1] - 5, 'Training', ha='left', va='top', fontsize=12)\n",
+    "for ax in [axes[2], axes[3]]:\n",
+    "    ax.text(1, ax.get_ylim()[-1] - 5, 'Validation', ha='left', va='top', fontsize=12)\n",
+    "axes[0].set_title('Not stratified')\n",
+    "axes[1].set_title('Stratified')\n",
+    "\n",
+    "# adjust subplot\n",
+    "fig.subplots_adjust(wspace=0.1, hspace=0.1)\n",
+    "fig.suptitle('Stratified sampling: wet day threshold {:0d} mm'.format(WET_DAY_THRESHOLD));"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}