### Imports

In [None]:
# builtins
import datetime
import warnings
import calendar

# externals
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import scipy.stats as stats
from IPython.display import Image
from sklearn.metrics import r2_score, roc_curve, auc, classification_report
from sklearn.model_selection import train_test_split

# locals
from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH
from climax.main.config import CALIB_PERIOD
from pysegcnn.core.utils import search_files
from pysegcnn.core.graphics import plot_classification_report

In [None]:
# observations NetCDF: look up the precipitation files matching the pattern
# in the 'pr' subdirectory of OBS_PATH and open the last match returned
obs_files = search_files(OBS_PATH.joinpath('pr'), 'OBS_pr(.*).nc$')
y_true = xr.open_dataset(obs_files.pop())

In [None]:
# restrict the observations to the time steps of the calibration period
y_true = y_true.sel({'time': CALIB_PERIOD})

In [None]:
# precipitation threshold defining a wet day: a time step is flagged as wet
# when the spatially averaged precipitation is >= this value
# NOTE(review): unit presumably mm (the figure title prints it as "mm") - confirm
WET_DAY_THRESHOLD = 1

In [None]:
# flag each time step of the calibration period as wet (1) or dry (0):
# average over the spatial dimensions, then compare against the threshold
areal_mean = y_true.mean(dim=('y', 'x'))
wet_days = (areal_mean >= WET_DAY_THRESHOLD).astype(np.int16)

# the same flags as plain NumPy array with singleton dimensions removed
nwet_days = np.squeeze(wet_days.to_array().values)

In [None]:
# chronological split without shuffling: the leading 75% of the calibration
# period is used for training, the trailing 25% for validation
train, valid = train_test_split(
    CALIB_PERIOD,
    test_size=0.25,
    shuffle=False,
)

In [None]:
# split training/validation set stratified by the wet/dry day flag, so that
# both sets preserve the wet day frequency of the calibration period
# NOTE: stratification requires shuffling; the random state is fixed so that
# the split (and the figure derived from it) is reproducible between runs
# NOTE(review): test_size is 0.5 here, but 0.25 for the chronological split,
# so the two splits compared in the figure differ in size - confirm intended
train_st, valid_st = train_test_split(
    CALIB_PERIOD, stratify=nwet_days, test_size=0.5, random_state=0)

# the shuffled splits are unordered: sort chronologically
train_st = np.asarray(sorted(train_st))
valid_st = np.asarray(sorted(valid_st))

In [None]:
# plot distribution of wet days in calibration period on a 2x2 grid:
# rows are training/validation, columns are not stratified/stratified
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10, 10))
axes = axes.flatten()

# (position in flattened grid, time steps of the split): the left column
# holds the not stratified split, the right column the stratified split
panels = ((0, train), (2, valid), (1, train_st), (3, valid_st))
for idx, period in panels:
    flags = wet_days.sel(time=period).to_array().values.squeeze()
    sns.countplot(x=flags, ax=axes[idx])

# axes properties: no y-axis labels, class names on the bottom row
for ax in axes:
    ax.set_ylabel('')
for ax in axes[2:]:
    ax.set_xticklabels(['Dry', 'Wet'])

# annotate each row with the set it shows, anchored just below the upper
# y-axis limit at the position of the second bar
for label, row in (('Training', axes[:2]), ('Validation', axes[2:])):
    for ax in row:
        upper = ax.get_ylim()[-1]
        ax.text(1, upper - 5, label, ha='left', va='top', fontsize=12)

# column titles
axes[0].set_title('Not stratified')
axes[1].set_title('Stratified')

# adjust subplot
fig.subplots_adjust(wspace=0.1, hspace=0.1)
title = 'Stratified sampling: wet day threshold {:0d} mm'.format(WET_DAY_THRESHOLD)
fig.suptitle(title);