### Imports

In [None]:
# builtins
import datetime
import warnings
import calendar

# externals
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import scipy.stats as stats
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import scipy.stats as stats
from IPython.display import Image
from sklearn.metrics import r2_score, roc_curve, auc, classification_report
from sklearn.model_selection import train_test_split

# locals
from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH
from climax.main.config import CALIB_PERIOD, VALID_PERIOD
from pysegcnn.core.utils import search_files
from pysegcnn.core.graphics import plot_classification_report

In [None]:
# Full reference period: the calibration time steps followed by the validation
# time steps, joined end-to-end along the first axis.
REFERENCE_PERIOD = np.concatenate((CALIB_PERIOD, VALID_PERIOD))

### Load observations

In [None]:
# Open the observed NetCDF datasets (precipitation, maximum and minimum
# temperature), one per variable sub-folder of OBS_PATH.
# NOTE(review): .pop() takes the last entry returned by search_files, so if
# several files match, one is picked silently.
y_true_pr, y_true_tmax, y_true_tmin = (
    xr.open_dataset(search_files(OBS_PATH.joinpath(subdir), pattern).pop())
    for subdir, pattern in (('pr', 'OBS_pr(.*).nc$'),
                            ('tasmax', 'OBS_tasmax(.*).nc$'),
                            ('tasmin', 'OBS_tasmin(.*).nc$'))
)

### Load ERA-5 reference dataset

In [None]:
# Open the ERA-5 reference datasets: total precipitation, 2 m maximum and
# 2 m minimum temperature, each from its own folder below ERA5_PATH/ERA5.
# NOTE(review): .pop() takes the last match returned by search_files.
y_refe_pr, y_refe_tmax, y_refe_tmin = (
    xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', folder), '.nc$').pop())
    for folder in ('total_precipitation',
                   '2m_{}_temperature'.format('max'),
                   '2m_{}_temperature'.format('min'))
)

In [None]:
# Convert ERA-5 2 m temperatures from Kelvin to °C.
# Only the temperature variable 't2m' is shifted. Subtracting from the whole
# Dataset would also shift every other numeric data variable (such as the
# 'lambert_azimuthal_equal_area' variable dropped further down) and would fail
# on non-numeric ones.
y_refe_tmax = y_refe_tmax.assign(t2m=y_refe_tmax['t2m'] - 273.15)
y_refe_tmin = y_refe_tmin.assign(t2m=y_refe_tmin['t2m'] - 273.15)

### Select time period

In [None]:
# Time period analysed in the rest of the notebook. It is used as the
# `time` selector for both observations and ERA-5. Here it is the full
# reference period (calibration + validation).
PERIOD = REFERENCE_PERIOD

In [None]:
# Restrict the observations to PERIOD and keep only the data variable of
# each dataset as a DataArray.
y_true_pr, y_true_tmax, y_true_tmin = (
    y_true_pr.sel(time=PERIOD)['precipitation'],
    y_true_tmax.sel(time=PERIOD)['tasmax'],
    y_true_tmin.sel(time=PERIOD)['tasmin'],
)

In [None]:
# Restrict ERA-5 to PERIOD and drop the 'lambert_azimuthal_equal_area' variable
# (presumably the projection / grid-mapping variable). Rename each ERA-5
# variable to its observed counterpart and keep it as a DataArray.
y_refe_pr = (y_refe_pr
             .sel(time=PERIOD)
             .drop_vars('lambert_azimuthal_equal_area')
             .rename({'tp': 'precipitation'})['precipitation'])
y_refe_tmax = (y_refe_tmax
               .sel(time=PERIOD)
               .drop_vars('lambert_azimuthal_equal_area')
               .rename({'t2m': 'tasmax'})['tasmax'])
y_refe_tmin = (y_refe_tmin
               .sel(time=PERIOD)
               .drop_vars('lambert_azimuthal_equal_area')
               .rename({'t2m': 'tasmin'})['tasmin'])

## Align datasets

In [None]:
# Precipitation.
# join='override' replaces ERA-5's indexes with the observation indexes
# (sizes must match along each shared dimension). ERA-5 is then masked to NaN
# wherever the observations are missing.
y_true_pr, y_refe_pr = xr.align(y_true_pr, y_refe_pr, join='override')
y_refe_pr = y_refe_pr.where(y_true_pr.notnull())

In [None]:
# Maximum temperature.
# join='override' replaces ERA-5's indexes with the observation indexes
# (sizes must match along each shared dimension). ERA-5 is then masked to NaN
# wherever the observations are missing.
y_true_tmax, y_refe_tmax = xr.align(y_true_tmax, y_refe_tmax, join='override')
y_refe_tmax = y_refe_tmax.where(y_true_tmax.notnull())

In [None]:
# Minimum temperature.
# join='override' replaces ERA-5's indexes with the observation indexes
# (sizes must match along each shared dimension). ERA-5 is then masked to NaN
# wherever the observations are missing.
y_true_tmin, y_refe_tmin = xr.align(y_true_tmin, y_refe_tmin, join='override')
y_refe_tmin = y_refe_tmin.where(y_true_tmin.notnull())

### Plot ERA-5 vs. Observed

In [None]:
# Monthly precipitation climatology for ERA-5 and for the observations.
# First compute monthly totals (skipna=False: a month with any missing day
# becomes NaN), then average those totals per calendar month across all years.
y_refe_values, y_true_values = (
    series.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time')
    for series in (y_refe_pr, y_true_pr)
)

In [None]:
# Relative bias of ERA-5 with respect to the observations, in percent, per
# calendar month and grid cell. The relative bias is undefined where the
# observed climatology is zero. Those cells are masked to NaN instead of
# yielding +/-inf, which would otherwise swamp any subsequent averaging.
bias_pr = ((y_refe_values - y_true_values) / y_true_values.where(y_true_values != 0)) * 100

In [None]:
# Three-panel figure: ERA-5 climatology, observed climatology and relative bias,
# each averaged over the calendar-month dimension. The figure is saved as a PNG.
# Shared colour limits for the two precipitation panels.
vmin, vmax = 0, 150
fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)
axes = axes.flatten()

# plot Era-5 reanalysis (panel 0), observations (panel 1) and relative bias (panel 2)
im1 = axes[0].imshow(y_refe_values.mean(dim='month'), origin='lower', cmap='viridis_r', vmin=vmin, vmax=vmax)
im2 = axes[1].imshow(y_true_values.mean(dim='month'), origin='lower', cmap='viridis_r', vmin=vmin, vmax=vmax)
# NOTE: this is the mean over months of the monthly relative biases (bias_pr),
# not the relative bias of the mean climatology; colour range fixed to +/-60 %.
im3 = axes[2].imshow(bias_pr.mean(dim='month'), origin='lower', cmap='RdBu_r', vmin=-60, vmax=60)
axes[0].set_title('ERA-5 reanalysis', fontsize=16, pad=10);
axes[1].set_title('Observations', fontsize=16, pad=10);
axes[2].set_title('Bias: ERA-5 - Observations', fontsize=16, pad=10)

# adjust axes: strip ticks/labels and hide the axis frames
# (set_axis_off() alone already hides ticks and labels)
for ax in axes.flat:
    ax.axes.get_xaxis().set_ticklabels([])
    ax.axes.get_xaxis().set_ticks([])
    ax.axes.get_yaxis().set_ticklabels([])
    ax.axes.get_yaxis().set_ticks([])
    ax.axes.axis('tight')
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_axis_off()

# adjust figure
# NOTE(review): the year range in the title is hard-coded, not derived from
# PERIOD; confirm it matches REFERENCE_PERIOD.
fig.suptitle('Average monthly precipitation (mm): 1980 - 2010', fontsize=20);
fig.subplots_adjust(hspace=0, wspace=0, top=0.85)

# add colorbar for bias: a thin vertical axis just right of the last panel,
# positioned from the panel geometry after subplots_adjust above (order matters)
axes = axes.flatten()  # no-op: axes is already 1-D
cbar_ax_bias = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,
                             0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])
cbar_bias = fig.colorbar(im3, cax=cbar_ax_bias)
cbar_bias.set_label(label='Relative bias / (%)', fontsize=16)
cbar_bias.ax.tick_params(labelsize=14)

# add colorbar for predictand: horizontal axis below the figure, spanning from the
# left edge of the first panel to the left edge of the last one (under panels 0-1)
cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1,
                                   axes[-1].get_position().x0 - axes[0].get_position().x0,
                                   0.05])
cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal')
cbar_predictand.set_label(label='Precipitation / (mm month$^{-1}$)', fontsize=16)
cbar_predictand.ax.tick_params(labelsize=14)

# save figure (relative path; the target directory must already exist)
fig.savefig('../Notebooks/Figures/capstone_pr.png', dpi=300, bbox_inches='tight')