{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1f0e7c2a-5b3d-4e8f-9a61-2c4d8b7e0f01",
   "metadata": {},
   "source": [
    "## Wet-day stratified sampling\n",
    "\n",
    "Compares two ways of splitting the calibration period into a training and a validation set:\n",
    "\n",
    "- **Not stratified**: chronological split, the last 25% of the days are used for validation.\n",
    "- **Stratified**: random 50/50 split preserving the proportion of wet and dry days.\n",
    "\n",
    "A day is *wet* if the spatially averaged observed precipitation is at least `WET_DAY_THRESHOLD` mm."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6b83379-c5a8-48c3-bb85-d00a341a37f4",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeba7f9b-066a-4843-bd64-5b6326c0b536",
   "metadata": {},
   "outputs": [],
   "source": [
    "# builtins\n",
    "import datetime\n",
    "import warnings\n",
    "import calendar\n",
    "\n",
    "# externals\n",
    "import xarray as xr\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
    "import scipy.stats as stats\n",
    "from IPython.display import Image\n",
    "from sklearn.metrics import r2_score, roc_curve, auc, classification_report\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# locals\n",
    "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH\n",
    "from climax.main.config import CALIB_PERIOD\n",
    "from pysegcnn.core.utils import search_files\n",
    "from pysegcnn.core.graphics import plot_classification_report"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a7c9e14-6d3b-4f52-8e0a-3b5d9c8f1a02",
   "metadata": {},
   "source": [
    "### Observations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e75b3217-26f7-4a4a-ae2a-4fbb92a9f2a2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# observed precipitation NetCDF\n",
    "obs_files = search_files(OBS_PATH.joinpath('pr'), 'OBS_pr(.*).nc$')\n",
    "assert obs_files, 'No observed precipitation file found in {}'.format(OBS_PATH.joinpath('pr'))\n",
    "y_true = xr.open_dataset(obs_files.pop())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aa8466e-84a9-4c2e-ae19-403b6246e27f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# subset observations to the calibration period\n",
    "y_calib = y_true.sel(time=CALIB_PERIOD)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b8d0f25-7e4c-4a63-9f1b-4c6e0d9a2b03",
   "metadata": {},
   "source": [
    "### Wet days"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f1a58a2-8c4c-4d73-a116-e64e68fdd507",
   "metadata": {},
   "outputs": [],
   "source": [
    "# precipitation threshold defining a wet day (mm)\n",
    "WET_DAY_THRESHOLD = 1\n",
    "\n",
    "# random seed of the stratified training/validation split (reproducible figure)\n",
    "SEED = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e6696df-8660-4083-9a32-0dd282112948",
   "metadata": {},
   "outputs": [],
   "source": [
    "# classify each day as dry (0) or wet (1): a day is wet if the precipitation,\n",
    "# averaged over the spatial domain, reaches the wet day threshold\n",
    "wet_days = (y_calib.mean(dim=('y', 'x')) >= WET_DAY_THRESHOLD).astype(np.int16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c9e1a36-8f5d-4b74-a02c-5d7f1eab3c04",
   "metadata": {},
   "outputs": [],
   "source": [
    "def wet_day_labels(dates):\n",
    "    \"\"\"Select the dry (0) / wet (1) labels of the given dates.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dates : array-like\n",
    "        Dates to select from ``wet_days``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    labels : numpy.ndarray\n",
    "        The labels in the order of ``dates``, with singleton dimensions\n",
    "        removed (one value per date if ``wet_days`` holds a single variable).\n",
    "\n",
    "    \"\"\"\n",
    "    return wet_days.sel(time=dates).to_array().values.squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d0f2b47-9a6e-4c85-b13d-6e8a2fbc4d05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dry/wet label of each day of the calibration period, used to stratify the split\n",
    "calib_labels = wet_day_labels(CALIB_PERIOD)\n",
    "assert len(calib_labels) == len(CALIB_PERIOD), 'Expected one dry/wet label per calibration day.'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e1a3c58-0b7f-4d96-824e-7f9b3acd5e06",
   "metadata": {},
   "source": [
    "### Training/validation split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b87accd6-d5e4-4dc6-9532-3ef8aa162d24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# split training/validation set chronologically: the last 25% of the days are used for validation\n",
    "train, valid = train_test_split(CALIB_PERIOD, shuffle=False, test_size=0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "559d1450-09db-4b2f-844a-d572485973e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# split training/validation set randomly, preserving the proportion of wet and dry days\n",
    "# NOTE(review): test_size differs from the chronological split (0.5 vs 0.25), so the\n",
    "# validation counts of both approaches are not directly comparable -- confirm this is intended\n",
    "train_st, valid_st = train_test_split(\n",
    "    CALIB_PERIOD, stratify=calib_labels, test_size=0.5, random_state=SEED)\n",
    "train_st, valid_st = np.sort(train_st), np.sort(valid_st)  # sort chronologically"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f2b4d69-1c8a-4ea7-935f-8a0c4bde6f07",
   "metadata": {},
   "source": [
    "### Distribution of wet and dry days"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fd013f9-77d0-48de-8d5f-2c6a1cb3ed17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot distribution of wet and dry days in the training and validation sets\n",
    "fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10, 10))\n",
    "axes = axes.flatten()\n",
    "\n",
    "# columns: not stratified (left), stratified (right); rows: training (top), validation (bottom)\n",
    "splits = [(train, 'Training'), (train_st, 'Training'), (valid, 'Validation'), (valid_st, 'Validation')]\n",
    "for ax, (dates, label) in zip(axes, splits):\n",
    "    # fixed category order: the 'Dry'/'Wet' tick labels stay correct even if a class is missing\n",
    "    sns.countplot(x=wet_day_labels(dates), order=[0, 1], ax=ax)\n",
    "    ax.set_ylabel('')\n",
    "    # axes-fraction coordinates: label position does not depend on the count range\n",
    "    ax.text(0.75, 0.98, label, transform=ax.transAxes, ha='left', va='top', fontsize=12)\n",
    "\n",
    "# axes properties\n",
    "for ax in axes[::2]:\n",
    "    ax.set_ylabel('Number of days')\n",
    "for ax in axes[2:]:\n",
    "    ax.set_xticklabels(['Dry', 'Wet'])\n",
    "axes[0].set_title('Not stratified')\n",
    "axes[1].set_title('Stratified')\n",
    "\n",
    "# adjust subplot\n",
    "fig.subplots_adjust(wspace=0.1, hspace=0.1)\n",
    "fig.suptitle('Stratified sampling: wet day threshold {:g} mm'.format(WET_DAY_THRESHOLD));"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}