# Evaluate ERA-5 downscaling: precipitation

We used **1981-1991 as training** period and **1991-2010 as reference** period. The results shown in this notebook are based on the model predictions on the reference period.

**Predictors on pressure levels (500, 850)**:
- Geopotential (z)
- Temperature (t)
- Zonal wind (u)
- Meridional wind (v)
- Specific humidity (q)

**Predictors on surface**:
- Mean sea level pressure (msl)

**Auxiliary predictors**:
- Elevation from Copernicus EU-DEM v1.1 (dem)
- Day of the year (doy)

Define the predictand and the model to evaluate:

In [None]:
# model configuration: these settings are combined further below into the
# file name pattern used to locate the predictions of the evaluated model
PREDICTAND = 'pr'            # key into NAMES: precipitation
MODEL = 'USegNet'            # name of the model to evaluate
PPREDICTORS = 'ztuvq'        # predictors on pressure levels
PLEVELS = ['500', '850']     # pressure levels of the predictors
SPREDICTORS = 'p'            # predictors on surface
DEM = 'dem'                  # elevation predictor
DEM_FEATURES = ''            # empty: not part of the file name pattern
DOY = 'doy'                  # day of the year predictor

### Imports

In [None]:
# builtins
import datetime
import warnings
import calendar

# externals
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
from IPython.display import Image
from sklearn.metrics import r2_score, roc_curve, auc, classification_report

# locals
from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH
from pysegcnn.core.utils import search_files
from pysegcnn.core.graphics import plot_classification_report

In [None]:
# human-readable variable names, keyed by predictand identifier
NAMES = {
    'tasmin': 'minimum temperature',
    'tasmax': 'maximum temperature',
    'pr': 'precipitation',
}

### Model architecture

In [None]:
# show the architecture figure stored next to the notebook; kept as a bare
# expression (presumably so that the notebook renders it as the cell output)
Image("./Figures/architecture.png", width=900, height=400)

### Loss function

For precipitation, the network is optimizing the negative log-likelihood of a Bernoulli-Gamma distribution after [Cannon (2008)](http://journals.ametsoc.org/doi/10.1175/2008JHM960.1).

Bernoulli-Gamma distribution:

$$P(y \mid p, \alpha, \beta) = \begin{cases} 1 - p, & \text{for } y = 0\\ p \cdot \frac{y^{\alpha -1} \exp(-y/\beta)}{\beta^{\alpha} \Gamma(\alpha)}, & \text{for } y > 0\end{cases}$$

Log-likelihood function:

$$\mathcal{J}(p, \alpha, \beta \mid y) = \underbrace{(1 - P(y > 0)) \log(1 - p)}_{\text{Bernoulli}} + \underbrace{P(y > 0) \cdot \left(\log(p) + (\alpha - 1) \log(y) - \frac{y}{\beta} - \alpha \log(\beta) - \log(\Gamma(\alpha))\right)}_{\text{Gamma}}$$

### Load datasets

In [None]:
# file name pattern of the model predictions: the mandatory components come
# first, followed by every optional component that is set to a non-empty string
PATTERN = '_'.join(
    [MODEL, PREDICTAND, PPREDICTORS, *PLEVELS] +
    [component for component in (SPREDICTORS, DEM, DEM_FEATURES, DOY) if component]
)

In [None]:
# model predictions: NetCDF file whose name matches the pattern of the evaluated model
pred_files = search_files(TARGET_PATH.joinpath(PREDICTAND), '.'.join([PATTERN, 'nc$']))
y_pred = xr.open_dataset(pred_files.pop())

# observations: NetCDF file found in the directory of the predictand
obs_files = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$')
y_true = xr.open_dataset(obs_files.pop())

In [None]:
# keep only the observations at the time steps of the model predictions
y_true = y_true.sel(time=y_pred['time'])

In [None]:
# split the predictions into amount and probability of precipitation and
# align both with the observations (join='override')
y_true, y_pred_pr, y_pred_prob = xr.align(
    y_true,
    y_pred.precipitation.to_dataset(),
    y_pred.prob.to_dataset(),
    join='override',
)

# predictions are set to NaN wherever the observation is missing
observed = y_true.precipitation.notnull()
y_pred_pr = y_pred_pr.where(observed, other=np.nan)
y_pred_prob = y_pred_prob.where(observed, other=np.nan)

## Model validation: precipitation amount

### Coefficient of determination: monthly mean

In [None]:
# mean of each calendar month over the whole time series, flattened to one
# value per month and grid point: predictions first, observations second
monthly_means = [ds[NAMES[PREDICTAND]].groupby('time.month').mean(dim='time')
                 for ds in (y_pred_pr, y_true)]
y_pred_values, y_true_values = (da.values.flatten() for da in monthly_means)

In [None]:
# keep only the samples where neither prediction nor observation is NaN
mask = ~(np.isnan(y_pred_values) | np.isnan(y_true_values))
y_pred_values, y_true_values = y_pred_values[mask], y_true_values[mask]

In [None]:
# coefficient of determination (R^2) of the predicted monthly means
r2 = r2_score(y_true=y_true_values, y_pred=y_pred_values)

In [None]:
# observed vs. predicted monthly means: one marker per valid sample
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))

# the entire dataset is drawn; if the figure gets overloaded, draw a random
# subset instead, e.g. indices from np.random.choice(..., replace=False)
ax.plot(y_true_values, y_pred_values, 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3)

# value range shared by both axes, their ticks and the 1:1 reference line
interval = np.arange(0, 11)
ax.plot(interval, interval, color='k', lw=2, ls='--')

# annotate the lower right corner with R^2 (calculated on the entire dataset)
r2_label = 'Coefficient of determination R$^2$ = {:.2f}'.format(r2)
ax.text(interval[-1], interval[0], s=r2_label, ha='right', fontsize=14)

# y-axis: predictions
ax.set_ylim(interval[0], interval[-1])
ax.set_yticks(interval)
ax.set_yticklabels(interval, fontsize=14)
ax.set_ylabel('Predicted', fontsize=14)

# x-axis: observations
ax.set_xlim(interval[0], interval[-1])
ax.set_xticks(interval)
ax.set_xticklabels(interval, fontsize=14)
ax.set_xlabel('Observed', fontsize=14)

ax.set_title('Monthly mean {} (mm / day): 1991 - 2010'.format(NAMES[PREDICTAND]), fontsize=16, pad=10)

# write the figure to disk
fig.savefig('../Notebooks/Figures/{}_r2.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')

### Bias

Calculate yearly average bias over entire reference period:

In [None]:
# mean of every year of the reference period: predictions and observations
y_pred_yearly_avg = y_pred_pr.groupby('time.year').mean(dim='time')
y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time')

# bias relative to the observed yearly mean, in percent
bias_yearly_avg = (y_pred_yearly_avg - y_true_yearly_avg) / y_true_yearly_avg * 100

# report the relative bias averaged over all years and grid points
for var in bias_yearly_avg:
    avg_bias = bias_yearly_avg[var].mean().item()
    print('Yearly average relative bias of {}: {:.2f}%'.format(var, avg_bias))

In [None]:
# mean absolute difference between predicted and observed yearly means
mae_avg = abs(y_pred_yearly_avg - y_true_yearly_avg).mean()
for var in mae_avg:
    message = 'Yearly average MAE of {}: {:.2f} '.format(var, mae_avg[var].item())
    print(message + 'mm / day')

In [None]:
# root mean squared error of the yearly means over the reference period:
# the square root of the mean squared difference is taken, so that the result
# is in the units of the predictand (mm / day), as stated in the printed message
rmse_avg = ((y_pred_yearly_avg - y_true_yearly_avg) ** 2).mean() ** 0.5
for var in rmse_avg:
    print('Yearly average RMSE of {}: {:.2f} '.format(var, rmse_avg[var].item()) + 'mm / day')

In [None]:
# Pearson's correlation coefficient between predicted and observed yearly
# means, calculated over the grid points of each year and averaged over years
for var in y_pred_yearly_avg:
    correlations = []
    for year in y_pred_yearly_avg.year:
        y_p = y_pred_yearly_avg[var].sel(year=year).values
        y_t = y_true_yearly_avg[var].sel(year=year).values

        # joint mask: both samples must be valid at the same grid points,
        # otherwise the two inputs could differ in length or be misaligned
        valid = ~np.isnan(y_p) & ~np.isnan(y_t)
        r, _ = stats.pearsonr(y_p[valid], y_t[valid])
        correlations.append(r)

    # average over all years (not only the last one), reported per variable
    print('Yearly average Pearson correlation coefficient for {}: {:.2f}'.format(var, np.mean(correlations)))

In [None]:
# plot average of observation, prediction, and bias: one row per variable,
# the columns are (observed, predicted, relative bias)
vmin, vmax = 0, 5
fig, axes = plt.subplots(len(y_pred_yearly_avg.data_vars), 3, figsize=(24, len(y_pred_yearly_avg.data_vars) * 6),
                         sharex=True, sharey=True)
axes = axes.reshape(len(y_pred_yearly_avg.data_vars), -1)
for i, var in enumerate(y_pred_yearly_avg):
    for ds, ax in zip([y_true_yearly_avg, y_pred_yearly_avg, bias_yearly_avg], axes[i, ...]):
        if ds is bias_yearly_avg:
            ds = ds[var].mean(dim='year')
            im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
            # imshow draws the array columns along the x-axis: the label is
            # anchored to the number of columns (shape[1]), i.e. the right edge
            ax.text(x=ds.shape[1] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')
        else:
            im1 = ax.imshow(ds[var].mean(dim='year').values, origin='lower', cmap='BuPu', vmin=vmin, vmax=vmax)

# set titles
axes[0, 0].set_title('Observed', fontsize=16, pad=10);
axes[0, 1].set_title('Predicted', fontsize=16, pad=10);
axes[0, 2].set_title('Bias', fontsize=16, pad=10);

# adjust axes: remove ticks and labels
for ax in axes.flat:
    ax.axes.get_xaxis().set_ticklabels([])
    ax.axes.get_xaxis().set_ticks([])
    ax.axes.get_yaxis().set_ticklabels([])
    ax.axes.get_yaxis().set_ticks([])
    ax.axes.axis('tight')
    ax.set_xlabel('')
    ax.set_ylabel('')

# adjust figure
fig.suptitle('Average {}: 1991 - 2010'.format(NAMES[PREDICTAND]), fontsize=20);
fig.subplots_adjust(hspace=0, wspace=0, top=0.85)

# add colorbar for bias: to the right of the last axis
axes = axes.flatten()
cbar_ax_bias = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,
                             0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])
cbar_bias = fig.colorbar(im2, cax=cbar_ax_bias)
cbar_bias.set_label(label='Relative bias / (%)', fontsize=16)
cbar_bias.ax.tick_params(labelsize=14)

# add colorbar for predictand: below the first axis
cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1,
                                   axes[-1].get_position().x0 - axes[0].get_position().x0,
                                   0.05])
cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal')
cbar_predictand.set_label(label='{} / '.format(NAMES[PREDICTAND].capitalize()) + '(mm day$^{-1}$)', fontsize=16)
cbar_predictand.ax.tick_params(labelsize=14)

# add metrics: MAE and RMSE; ds still refers to the bias map of the last
# variable and is only used for the horizontal position of the labels
# NOTE(review): the unit of the RMSE label is squared (mm^2 day^-2), which
# would fit a mean squared error rather than its root -- confirm
axes[1].text(x=ds.shape[1] - 2, y=2, s='MAE = {:.1f}'.format(mae_avg[NAMES[PREDICTAND]].item()) + 'mm day$^{-1}$', fontsize=14, ha='right')
axes[1].text(x=ds.shape[1] - 2, y=12, s='RMSE = {:.1f}'.format(rmse_avg[NAMES[PREDICTAND]].item()) + 'mm$^2$ day$^{-2}$', fontsize=14, ha='right')

# save figure
fig.savefig('../Notebooks/Figures/{}_average_bias.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')

### Seasonal bias

Calculate seasonal bias:

In [None]:
# seasonal means (DJF, MAM, JJA, SON) of predictions and observations
y_pred_snl = y_pred_pr.groupby('time.season').mean(dim='time')
y_true_snl = y_true.groupby('time.season').mean(dim='time')

# bias relative to the observed seasonal mean, in percent
bias_snl = (y_pred_snl - y_true_snl) / y_true_snl * 100

In [None]:
# report the relative bias of each season, averaged over all grid points
for var in bias_snl.data_vars:
    for season in bias_snl[NAMES[PREDICTAND]].season:
        season_name = season.values.item()
        season_bias = bias_snl[var].sel(season=season).mean().item()
        print('Average bias of mean {} for season {}: {:.1f}%'.format(var, season_name, season_bias))

Plot seasonal differences, taken from the [xarray documentation](http://xarray.pydata.org/en/stable/examples/monthly-means.html).

In [None]:
# plot relative bias: annual average followed by the selected seasons
seasons = ('DJF', 'JJA')
fig, axes = plt.subplots(nrows=1, ncols=len(seasons) + 1, figsize=(24,8), sharex=True, sharey=True)
axes = axes.flatten()

# plot annual average bias; the image is kept as mappable for the colorbar of
# this figure (all panels share the same colormap and value range)
ds = bias_yearly_avg[NAMES[PREDICTAND]].mean(dim='year')
im = axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
axes[0].set_title('Annual', fontsize=16);
# imshow draws the array columns along the x-axis: anchor the label to shape[1]
axes[0].text(x=ds.shape[1] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')

# plot seasonal average bias
for ax, season in zip(axes[1:], seasons):
    ds = bias_snl[NAMES[PREDICTAND]].sel(season=season)
    ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
    ax.set_title(season, fontsize=16);
    ax.text(x=ds.shape[1] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')

# adjust axes: remove ticks and labels
for ax in axes.flat:
    ax.axes.get_xaxis().set_ticklabels([])
    ax.axes.get_xaxis().set_ticks([])
    ax.axes.get_yaxis().set_ticklabels([])
    ax.axes.get_yaxis().set_ticks([])
    ax.axes.axis('tight')
    ax.set_xlabel('')
    ax.set_ylabel('')

# adjust figure
fig.suptitle('Average bias of {}: 1991 - 2010'.format(NAMES[PREDICTAND]), fontsize=20);
fig.subplots_adjust(hspace=0, wspace=0, top=0.85)

# add colorbar for bias: uses the image drawn in this figure instead of a
# mappable left over from a previously executed cell
axes = axes.flatten()
cbar_ax = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,
                        0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.set_label(label='Relative bias / (%)', fontsize=16)
cbar.ax.tick_params(labelsize=14)

# save figure
fig.savefig('../Notebooks/Figures/{}_average_bias_seasonal.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')

### Bias of extreme values

In [None]:
# extreme quantile of interest, as a fraction in [0, 1]: it is passed to the
# quantile calculations below and shown as percentile (P98) in messages and titles
quantile = 0.98

In [None]:
# extreme quantile of every year: observations and predictions
# RuntimeWarnings are silenced while the quantiles are computed (presumably
# raised for grid points that contain only missing values)
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    y_pred_by_year = y_pred_pr.groupby('time.year')
    y_pred_ex = y_pred_by_year.quantile(quantile, dim='time')
    y_true_by_year = y_true.groupby('time.year')
    y_true_ex = y_true_by_year.quantile(quantile, dim='time')

In [None]:
# bias of the yearly extreme quantile relative to the observed one, in percent
bias_ex = (y_pred_ex - y_true_ex) / y_true_ex * 100
for var in bias_ex:
    avg_bias_ex = bias_ex[var].mean().item()
    print('Yearly average bias for P{:.0f} of {}: {:.1f}%'.format(quantile * 100, var, avg_bias_ex))

In [None]:
# mean absolute difference between predicted and observed yearly extreme quantile
mae_ex = abs(y_pred_ex - y_true_ex).mean()
for var in mae_ex:
    avg_mae_ex = mae_ex[var].item()
    print('Yearly average MAE for P{:.0f} of {}: {:.1f} mm / day'.format(quantile * 100, var, avg_mae_ex))

In [None]:
# root mean squared error in extreme quantile: the square root of the mean
# squared difference is taken, so that the result is in the units of the
# predictand (mm / day), as stated in the printed message
rmse_ex = ((y_pred_ex - y_true_ex) ** 2).mean() ** 0.5
for var in rmse_ex:
    print('Yearly average RMSE for P{:.0f} of {}: {:.1f} mm / day'.format(quantile * 100, var, rmse_ex[var].item()))

In [None]:
# plot extremes of observation, prediction, and bias: one row per variable,
# the columns are (observed, predicted, relative bias)
vmin, vmax = 10, 40
fig, axes = plt.subplots(len(y_pred_ex.data_vars), 3, figsize=(24, len(y_pred_ex.data_vars) * 6),
                         sharex=True, sharey=True)
axes = axes.reshape(len(y_pred_ex.data_vars), -1)
for i, var in enumerate(y_pred_ex):
    for ds, ax in zip([y_true_ex, y_pred_ex, bias_ex], axes[i, ...]):
        if ds is bias_ex:
            ds = ds[var].mean(dim='year')
            im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
            # imshow draws the array columns along the x-axis: the label is
            # anchored to the number of columns (shape[1]), i.e. the right edge
            ax.text(x=ds.shape[1] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')
        else:
            im1 = ax.imshow(ds[var].mean(dim='year').values, origin='lower', cmap='BuPu', vmin=vmin, vmax=vmax)

# set titles
axes[0, 0].set_title('Observed', fontsize=16, pad=10);
axes[0, 1].set_title('Predicted', fontsize=16, pad=10);
axes[0, 2].set_title('Bias', fontsize=16, pad=10);

# adjust axes: remove ticks and labels
for ax in axes.flat:
    ax.axes.get_xaxis().set_ticklabels([])
    ax.axes.get_xaxis().set_ticks([])
    ax.axes.get_yaxis().set_ticklabels([])
    ax.axes.get_yaxis().set_ticks([])
    ax.axes.axis('tight')
    ax.set_xlabel('')
    ax.set_ylabel('')

# adjust figure
fig.suptitle('Average P{:.0f} of {}: 1991 - 2010'.format(quantile * 100, NAMES[PREDICTAND]), fontsize=20);
fig.subplots_adjust(hspace=0, wspace=0, top=0.85)

# add colorbar for bias: to the right of the last axis
axes = axes.flatten()
cbar_ax = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,
                        0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])
cbar = fig.colorbar(im2, cax=cbar_ax)
cbar.set_label(label='Relative bias / (%)', fontsize=16)
cbar.ax.tick_params(labelsize=14)

# add colorbar for predictand: below the first axis
cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1,
                                   axes[-1].get_position().x0 - axes[0].get_position().x0,
                                   0.05])
cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal')
cbar_predictand.set_label(label='{} / '.format(NAMES[PREDICTAND].capitalize()) + '(mm day$^{-1}$)', fontsize=16)
cbar_predictand.ax.tick_params(labelsize=14)

# add metrics: MAE and RMSE; ds still refers to the bias map of the last
# variable and is only used for the horizontal position of the labels
# NOTE(review): the unit of the RMSE label is squared (mm^2 day^-2), which
# would fit a mean squared error rather than its root -- confirm
axes[1].text(x=ds.shape[1] - 2, y=2, s='MAE = {:.1f}'.format(mae_ex[NAMES[PREDICTAND]].item())  + 'mm day$^{-1}$', fontsize=14, ha='right')
axes[1].text(x=ds.shape[1] - 2, y=12, s='RMSE = {:.1f}'.format(rmse_ex[NAMES[PREDICTAND]].item())  + 'mm$^2$ day$^{-2}$', fontsize=14, ha='right')

# save figure
fig.savefig('../Notebooks/Figures/{}_average_bias_p{:.0f}.png'.format(PREDICTAND, quantile * 100), dpi=300, bbox_inches='tight')

### Bias of extremes: winter vs. summer

In [None]:
# extreme quantile of every season: observations and predictions
# RuntimeWarnings are silenced while the quantiles are computed
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    y_true_by_season = y_true.groupby('time.season')
    y_true_ex_snl = y_true_by_season.quantile(quantile, dim='time')
    y_pred_by_season = y_pred_pr.groupby('time.season')
    y_pred_ex_snl = y_pred_by_season.quantile(quantile, dim='time')

In [None]:
# bias of the seasonal extreme quantile relative to the observed one, in percent
bias_ex_snl = (y_pred_ex_snl - y_true_ex_snl) / y_true_ex_snl * 100

In [None]:
# report the relative bias of the extreme quantile of each season, averaged
# over all grid points
for var in bias_ex_snl.data_vars:
    for season in bias_ex_snl[NAMES[PREDICTAND]].season:
        season_name = season.values.item()
        season_bias_ex = bias_ex_snl[var].sel(season=season).mean().item()
        print('Average bias of P{:.0f} {} for season {}: {:.1f}%'.format(quantile * 100, var, season_name, season_bias_ex))

In [None]:
# plot relative bias of the extreme quantile: annual average followed by the
# selected seasons
seasons = ('DJF', 'JJA')
fig, axes = plt.subplots(nrows=1, ncols=len(seasons) + 1, figsize=(24,8), sharex=True, sharey=True)
axes = axes.flatten()

# plot annual average bias of extreme; the image is kept as mappable for the
# colorbar of this figure (all panels share the same colormap and value range)
ds = bias_ex[NAMES[PREDICTAND]].mean(dim='year')
im = axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
axes[0].set_title('Annual', fontsize=16);
# imshow draws the array columns along the x-axis: anchor the label to shape[1]
axes[0].text(x=ds.shape[1] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')

# plot seasonal average bias of extreme
for ax, season in zip(axes[1:], seasons):
    ds = bias_ex_snl[NAMES[PREDICTAND]].sel(season=season)
    ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
    ax.set_title(season, fontsize=16);
    ax.text(x=ds.shape[1] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')

# adjust axes: remove ticks and labels
for ax in axes.flat:
    ax.axes.get_xaxis().set_ticklabels([])
    ax.axes.get_xaxis().set_ticks([])
    ax.axes.get_yaxis().set_ticklabels([])
    ax.axes.get_yaxis().set_ticks([])
    ax.axes.axis('tight')
    ax.set_xlabel('')
    ax.set_ylabel('')

# adjust figure
fig.suptitle('Average bias of P{:.0f} of {}: 1991 - 2010'.format(quantile * 100, NAMES[PREDICTAND]), fontsize=20);
fig.subplots_adjust(hspace=0, wspace=0, top=0.85)

# add colorbar for bias: uses the image drawn in this figure instead of a
# mappable left over from a previously executed cell
axes = axes.flatten()
cbar_ax = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,
                        0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.set_label(label='Relative bias / (%)', fontsize=16)
cbar.ax.tick_params(labelsize=14)

# save figure
fig.savefig('../Notebooks/Figures/{}_average_bias_seasonal_ex.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')

### Frequency of wet days

In [None]:
# minimum precipitation (mm / day) defining a wet day: days with at least this
# amount (>=) are counted as wet in the wet-day statistics below
WET_DAY_THRESHOLD = 1

In [None]:
# valid samples: neither the observation nor the prediction is missing
mask = ~(np.isnan(y_true) | np.isnan(y_pred_pr))

# wet-day indicator of observations and predictions (1: wet, 0: dry), set to
# NaN wherever a sample is not valid
wet_days_true, wet_days_pred = (
    (ds >= WET_DAY_THRESHOLD).where(mask, other=np.nan).astype(np.float32)
    for ds in (y_true, y_pred_pr)
)

In [None]:
# total number of wet days in the reference period; skipna=False: the sum is
# NaN for grid points with missing values
n_wet_days_true, n_wet_days_pred = (
    wet_days.sum(dim='time', skipna=False) for wet_days in (wet_days_true, wet_days_pred)
)

In [None]:
# share of wet days among all time steps of the reference period, in percent
f_wet_days_true = n_wet_days_true / wet_days_true.time.size * 100
f_wet_days_pred = n_wet_days_pred / wet_days_pred.time.size * 100

In [None]:
# seasonal frequency of wet days: mean of the wet-day indicator per season,
# i.e. a fraction (not multiplied by 100)
f_wet_days_true_snl, f_wet_days_pred_snl = (
    wet_days.groupby('time.season').mean(dim='time', skipna=False)
    for wet_days in (wet_days_true, wet_days_pred)
)

In [None]:
# bias of the wet-day frequency relative to the observed frequency, in percent
# annual:
bias_wet = (f_wet_days_pred - f_wet_days_true) / f_wet_days_true * 100

# seasonal:
bias_wet_snl = (f_wet_days_pred_snl - f_wet_days_true_snl) / f_wet_days_true_snl * 100

In [None]:
# plot relative bias of the wet-day frequency: annual panel followed by one
# panel per season on a 2 x 3 grid
fig, axes = plt.subplots(2, 3, figsize=(24, 12), sharex=True, sharey=True)
axes = axes.flatten()

# plot annual bias of the wet-day frequency
ds = bias_wet[NAMES[PREDICTAND]]
im = axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
axes[0].set_title('Annual', fontsize=16);
# imshow draws the array columns along the x-axis: anchor the label to shape[1]
axes[0].text(x=ds.shape[1] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')

# plot seasonal bias of the wet-day frequency
for ax, season in zip(axes[1:], bias_wet_snl.season):
    ds = bias_wet_snl[NAMES[PREDICTAND]].sel(season=season)
    ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
    ax.set_title(season.item(), fontsize=16);
    ax.text(x=ds.shape[1] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')

# adjust axes: remove ticks and labels
for ax in axes.flat:
    ax.axes.get_xaxis().set_ticklabels([])
    ax.axes.get_xaxis().set_ticks([])
    ax.axes.get_yaxis().set_ticklabels([])
    ax.axes.get_yaxis().set_ticks([])
    ax.axes.axis('tight')
    ax.set_xlabel('')
    ax.set_ylabel('')
    
# turn off last axis
axes[-1].set_visible(False)

# adjust figure
fig.suptitle('Frequency of wet days (>= {:.1f} mm): 1991 - 2010'.format(WET_DAY_THRESHOLD), fontsize=20);
fig.subplots_adjust(hspace=0.1, wspace=0, top=0.925)

# add colorbar: to the right of the grid, from the bottom of the last axis to
# the top of the first axis
cbar_ax_predictand = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,
                                   0.01, axes[0].get_position().y1 - axes[-1].get_position().y0])
cbar_predictand = fig.colorbar(im, cax=cbar_ax_predictand)
cbar_predictand.set_label(label='Relative bias / (%)', fontsize=16)
cbar_predictand.ax.tick_params(labelsize=14)

# save figure
fig.savefig('../Notebooks/Figures/{}_bias_wet_days.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')

### Mean wet day precipitation

In [None]:
# mean wet day precipitation: precipitation accumulated on wet days only
# (the wet-day indicator zeroes out dry days), divided by the number of wet days
wet_sum_true = (y_true * wet_days_true).sum(dim='time', skipna=False)
dii_true = wet_sum_true / n_wet_days_true

wet_sum_pred = (y_pred_pr * wet_days_pred).sum(dim='time', skipna=False)
dii_pred = wet_sum_pred / n_wet_days_pred

In [None]:
# bias of the mean wet day precipitation relative to the observed one, in percent
bias_dii = (dii_pred - dii_true) / dii_true * 100

In [None]:
# plot mean wet day precipitation: observed, predicted, and relative bias
fig, axes = plt.subplots(1, 3, figsize=(24, 6), sharex=True, sharey=True)
for var in dii_true:
    for ds, ax in zip([dii_true, dii_pred, bias_dii], axes):
        if ds is bias_dii:
            im2 = ax.imshow(ds[var].values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
            # imshow draws the array columns along the x-axis: the label is
            # anchored to the number of columns (shape[1]), i.e. the right edge
            ax.text(x=ds[var].shape[1] - 2, y=2, s='Average: {:.1f}%'.format(ds[var].mean().item()), fontsize=14, ha='right')
        else:
            im1 = ax.imshow(ds[var].values, origin='lower', cmap='BuPu', vmin=0, vmax=15)

# set titles
axes[0].set_title('Observed', fontsize=16, pad=10);
axes[1].set_title('Predicted', fontsize=16, pad=10);
axes[2].set_title('Bias', fontsize=16, pad=10);

# adjust axes: remove ticks and labels
for ax in axes.flat:
    ax.axes.get_xaxis().set_ticklabels([])
    ax.axes.get_xaxis().set_ticks([])
    ax.axes.get_yaxis().set_ticklabels([])
    ax.axes.get_yaxis().set_ticks([])
    ax.axes.axis('tight')
    ax.set_xlabel('')
    ax.set_ylabel('')

# adjust figure
fig.suptitle('Mean wet day (>= {:.1f} mm) precipitation: 1991 - 2010'.format(WET_DAY_THRESHOLD), fontsize=20);
fig.subplots_adjust(hspace=0, wspace=0, top=0.85)

# add colorbar for bias: to the right of the last axis
axes = axes.flatten()
cbar_ax_bias = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,
                             0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])
cbar_bias = fig.colorbar(im2, cax=cbar_ax_bias)
cbar_bias.set_label(label='Relative bias / (%)', fontsize=16)
cbar_bias.ax.tick_params(labelsize=14)

# add colorbar for predictand: below the first axis
cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1,
                                   axes[-1].get_position().x0 - axes[0].get_position().x0,
                                   0.05])
cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal')
cbar_predictand.set_label(label='Mean wet day precipitation / (mm day$^{-1}$)', fontsize=16)
cbar_predictand.ax.tick_params(labelsize=14)

# save figure
fig.savefig('../Notebooks/Figures/{}_bias_wet_days_p.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')

## Model validation: precipitation probability

### ROC: Receiver operating characteristics

In [None]:
# observed occurrence of precipitation (1: precipitation > 0, 0: none)
# missing observations are kept as NaN instead of being turned into "no
# precipitation" by the comparison, so that a NaN mask can exclude them
obs = y_true[NAMES[PREDICTAND]].values.flatten()
p_true = np.where(np.isnan(obs), np.nan, (obs > 0).astype(float))

# predicted probability of precipitation
p_pred = y_pred_prob.prob.values.flatten()

In [None]:
# keep only the samples where neither the observed occurrence nor the
# predicted probability is NaN; the observed occurrence is cast to float
mask = ~(np.isnan(p_true) | np.isnan(p_pred))
p_true, p_pred = p_true[mask].astype(float), p_pred[mask]

In [None]:
# ROC curve: false positive rate vs. true positive rate (thresholds discarded)
fpr, tpr, _ = roc_curve(y_true=p_true, y_score=p_pred)

# area under the ROC curve
area = auc(x=fpr, y=tpr)

# ROC skill score (cf. https://journals.ametsoc.org/view/journals/clim/16/24/1520-0442_2003_016_4145_otrsop_2.0.co_2.xml)
rocss = 2 * area - 1

In [None]:
# ROC curve of the predicted precipitation probability
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
roc_label = 'Area={:.2f}, ROCSS={:.2f}'.format(area, rocss)
ax.plot(fpr, tpr, lw=2, label=roc_label, color='k')

# axis range (slightly larger than [0, 1]) and its center
interval = np.arange(-0.05, 1.1, 0.05)
center = np.median(interval)

# reference: classifier with no skill (diagonal)
ax.plot([0, 1], [0, 1], lw=2, linestyle='--', color='k')
ax.text(0.95, 0.975, 'Random Classifier', ha='right', va='top', rotation=45, fontsize=12)

# reference: perfect classifier (upper left corner)
ax.plot(0, 1, '-o', markersize=5, markerfacecolor='k', markeredgecolor='none')
ax.text(0.02, 1, 'Perfect classifier', va='center', fontsize=12)

# arrows from the center of the diagonal: towards the lower right the
# classifier gets worse, towards the upper left it gets better
ax.arrow(center, center, 0.1, -0.1, head_width=0.01, facecolor='k')
ax.arrow(center, center, -0.1, 0.1, head_width=0.01, facecolor='k')
ax.text(center + 0.05, center - 0.05, s='Worse', rotation=45, ha='left', fontsize=12)
ax.text(center - 0.05, center + 0.05, s='Better', rotation=45, ha='left', fontsize=12)

# ticks every 0.1, labelled with two decimals
ticks = np.arange(0, 1.1, 0.1)
tick_labels = ['{:.2f}'.format(i) for i in ticks]

# x-axis
ax.set_xticks(ticks)
ax.set_xticklabels(tick_labels, fontsize=12)
ax.set_xlim(interval[0], interval[-1])
ax.set_xlabel('False Positive Rate', fontsize=14)

# y-axis
ax.set_yticks(ticks)
ax.set_yticklabels(tick_labels, fontsize=12)
ax.set_ylim(interval[0], interval[-1])
ax.set_ylabel('True Positive Rate', fontsize=14)

ax.set_title('ROC of precipitation probability: 1991 - 2010', fontsize=14, pad=10)
ax.legend(frameon=False, loc='lower right', fontsize=14)

# write the figure to disk
fig.savefig('../Notebooks/Figures/{}_ROC.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')