# Evaluate ERA5 downscaling

Define the predictand and the model to evaluate.

In [None]:
# name of the evaluated model and of the predicted variable; both are used to
# build the file names of the predictions and observations loaded below
MODEL = 'USegNet'
PREDICTAND = 'tasmin'

### Imports

In [None]:
# builtins
import datetime

# externals
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

# locals
from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH

### Load datasets

In [None]:
# NetCDF files of the model predictions and of the observations (1980-2018),
# both stored in a sub-directory named after the predictand
y_pred = TARGET_PATH.joinpath(PREDICTAND, f'{MODEL}_{PREDICTAND}.nc')
y_true = OBS_PATH.joinpath(PREDICTAND, f'OBS_{PREDICTAND}_1980_2018.nc')

In [None]:
# open the model predictions first: their time axis defines the period over
# which the observations are evaluated
y_pred = xr.open_dataset(y_pred)

# open the observations and keep only the time steps covered by the predictions
y_true = xr.open_dataset(y_true)
y_true = y_true.sel(time=y_pred.time)

In [None]:
# give every data variable of the observations the predictand's name so the
# observed and predicted datasets can be indexed with the same key
y_true = y_true.rename(dict.fromkeys(y_true.data_vars, PREDICTAND))

In [None]:
# align datasets and mask missing values in model predictions:
# join='override' replaces y_pred's index labels with those of y_true (the
# indexes must have equal sizes), then every point where the observations are
# missing is set to NaN in the predictions too, so both share the same gaps
y_true, y_pred = xr.align(y_true, y_pred, join='override')
y_pred = y_pred.where(y_true.notnull())

## Model validation

### Overall bias

Calculate the average bias over the entire reference period:

In [None]:
# time means over the reference period and their difference (predicted - observed)
y_true_avg = y_true.mean(dim='time')
y_pred_avg = y_pred.mean(dim='time')
bias = y_pred_avg - y_true_avg

# report the bias averaged (unweighted) over all remaining dimensions
print(f'Overall average bias: {bias[PREDICTAND].mean().item():.2f}')

In [None]:
# plot the time-averaged observations, predictions, and their difference.
# The observed and predicted panels share one colour scale so their colours
# can be compared directly (with independent scales a systematic offset is
# invisible); with both limits set xarray uses a sequential colormap.
# The difference panel keeps xarray's automatic scaling.
vmin = min(y_true_avg[PREDICTAND].min().item(), y_pred_avg[PREDICTAND].min().item())
vmax = max(y_true_avg[PREDICTAND].max().item(), y_pred_avg[PREDICTAND].max().item())

fig, axes = plt.subplots(1, 3, figsize=(24, 6))
panels = [
    (y_true_avg, 'Observed', {'vmin': vmin, 'vmax': vmax}),
    (y_pred_avg, 'Predicted', {'vmin': vmin, 'vmax': vmax}),
    (bias, 'Difference', {}),
]
for ax, (ds, title, limits) in zip(axes, panels):
    ds[PREDICTAND].plot(ax=ax, **limits)
    ax.set_title(title)

### Seasonal bias

Calculate the seasonal bias:

In [None]:
# seasonal means (the 'time.season' grouping yields DJF, MAM, JJA, SON) of the
# observations and predictions, and their difference (predicted - observed)
y_true_snl, y_pred_snl = (
    ds.groupby('time.season').mean(dim='time') for ds in (y_true, y_pred)
)
bias_snl = y_pred_snl - y_true_snl

Plot seasonal differences, adapted from the [xarray documentation](https://xarray.pydata.org/en/stable/examples/monthly-means.html).

In [None]:
# plot seasonal means of observations and predictions and their difference:
# one row per season; columns are observed, predicted, and difference
fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(14, 12))
for i, season in enumerate(('DJF', 'MAM', 'JJA', 'SON')):
    obs = y_true_snl[PREDICTAND].sel(season=season)
    pred = y_pred_snl[PREDICTAND].sel(season=season)

    # share one colour scale between the observed and predicted panels of a
    # season; otherwise each panel is scaled to its own range and systematic
    # offsets between the two are invisible
    vmin = min(obs.min().item(), pred.min().item())
    vmax = max(obs.max().item(), pred.max().item())
    obs.plot.pcolormesh(ax=axes[i, 0], vmin=vmin, vmax=vmax,
                        add_colorbar=True, extend='both')
    pred.plot.pcolormesh(ax=axes[i, 1], vmin=vmin, vmax=vmax,
                         add_colorbar=True, extend='both')

    # difference panel: fixed range [-1, 1] in the predictand's units. xarray
    # does not use a diverging colormap when both vmin and vmax are given, so
    # request one explicitly to make the sign of the bias visible
    bias_snl[PREDICTAND].sel(season=season).plot.pcolormesh(
        ax=axes[i, 2], vmin=-1, vmax=1, cmap='RdBu_r', add_colorbar=True,
        extend='both')

    axes[i, 0].set_ylabel(season)
    axes[i, 1].set_ylabel('')
    axes[i, 2].set_ylabel('')

# strip tick labels and x-axis labels to declutter the grid of maps
for ax in axes.flat:
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])
    ax.axis('tight')
    ax.set_xlabel('')

axes[0, 0].set_title('Observed')
axes[0, 1].set_title('Predicted')
axes[0, 2].set_title('Difference')

plt.tight_layout()