# Evaluate bootstrapped model results

### Imports and constants

In [None]:
# builtins
import pathlib

# externals
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# locals
from climax.main.io import OBS_PATH, ERA5_PATH
from climax.main.config import VALID_PERIOD
from pysegcnn.core.utils import search_files

In [None]:
# long-form variable name stored in the data files, keyed by predictand code
NAMES = dict(
    tasmin='minimum temperature',
    tasmax='maximum temperature',
    pr='precipitation',
)

In [None]:
# root directory containing the netCDF output of the bootstrapped models
RESULTS = pathlib.Path('/mnt/CEPH_PROJECTS/FACT_CLIMAX') / 'ERA5_PRED' / 'bootstrap'

## Search model configuration

In [None]:
# predictand (presumably one of the keys of NAMES), loss function and optimizer
# that together identify the model run under evaluation
PREDICTAND, LOSS, OPTIM = 'tasmin', 'L1Loss', 'Adam'

In [None]:
# file-name prefix of the model: architecture, predictand, a fixed middle part
# (presumably the predictor configuration), loss function and optimizer
model = '_'.join(['USegNet', PREDICTAND, 'ztuvq_500_850_p_dem_doy', LOSS, OPTIM])

In [None]:
# netCDF files of all bootstrapped realisations of the model, ordered by the
# integer that terminates each file stem ('..._<n>.nc' sorts as n)
models = search_files(RESULTS.joinpath(PREDICTAND), model + '(.*).nc$')
models = sorted(models, key=lambda path: int(path.stem.rsplit('_', 1)[-1]))
models

### Load observations

In [None]:
# observed predictand, opened lazily in chunks of 365 time steps and restricted
# to the period covered by the model predictions
y_true = xr.open_dataset(
    OBS_PATH.joinpath(PREDICTAND, 'OBS_{}_1980_2018.nc'.format(PREDICTAND)),
    chunks={'time': 365},
).sel(time=VALID_PERIOD)

# precipitation is stored under its long name: rename it to the short code
if PREDICTAND == 'pr':
    y_true = y_true.rename({NAMES[PREDICTAND]: PREDICTAND})

In [None]:
# boolean mask that is True wherever the observations are NaN; used below to
# blank out the same cells in every other dataset
missing = y_true[PREDICTAND].pipe(np.isnan)

### Load reference data

In [None]:
# ERA-5 reference dataset; search_files(...).pop() takes one matching file
# (presumably exactly one .nc file exists per variable directory - verify)
if PREDICTAND == 'pr':
    y_refe = xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', 'total_precipitation'), '.nc$').pop(),
                             chunks={'time': 365})
    y_refe = y_refe.rename({'tp': 'pr'})
else:
    # 'tasmin' -> 'min', 'tasmax' -> 'max': slice off the 'tas' prefix, because
    # str.lstrip('tas') strips any run of leading 't'/'a'/'s' characters instead
    extreme = PREDICTAND[len('tas'):]
    y_refe = xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', '2m_{}_temperature'.format(extreme)), '.nc$').pop(),
                             chunks={'time': 365})
    # convert only the temperature field from K to °C; subtracting from the whole
    # Dataset would also shift every other data variable stored in the file
    y_refe['t2m'] = y_refe['t2m'] - 273.15
    y_refe = y_refe.rename({'t2m': PREDICTAND})

In [None]:
# keep the prediction period, drop the 'lambert_azimuthal_equal_area' variable
# and put the dimensions in (time, y, x) order
y_refe = (y_refe
          .sel(time=VALID_PERIOD)
          .drop_vars('lambert_azimuthal_equal_area')
          .transpose('time', 'y', 'x'))

### Load QM-adjusted reference data

In [None]:
# quantile-mapped ERA-5 for the predictand, opened lazily in chunks of 365
# time steps, with dimensions reordered to (time, y, x)
y_refe_qm = xr.open_dataset(
    ERA5_PATH.joinpath('QM_ERA5_{}_day_19912010.nc'.format(PREDICTAND)),
    chunks={'time': 365},
).transpose('time', 'y', 'x')

In [None]:
# center hours at 00:00:00 rather than 12:00:00: truncate every time stamp to
# day resolution with a single vectorized cast instead of a per-element loop
y_refe_qm['time'] = y_refe_qm.time.values.astype('datetime64[D]')

In [None]:
# restrict to the prediction period and discard the
# 'lambert_azimuthal_equal_area' variable
y_refe_qm = y_refe_qm.sel(time=VALID_PERIOD)
y_refe_qm = y_refe_qm.drop_vars('lambert_azimuthal_equal_area')

In [None]:
# put observations, ERA-5 and QM-adjusted ERA-5 on identical coordinates
# (join='override' copies the indexes of the first object onto the others),
# then blank out the cells where the observations are missing
y_true, y_refe, y_refe_qm = xr.align(
    *(ds[PREDICTAND] for ds in (y_true, y_refe, y_refe_qm)), join='override')
y_refe, y_refe_qm = (da.where(~missing, other=np.nan) for da in (y_refe, y_refe_qm))

### Load model predictions

In [None]:
# open every bootstrapped realisation lazily in chunks of 365 time steps
y_pred = list(map(lambda path: xr.open_dataset(path, chunks={'time': 365}), models))

# precipitation predictions use the long variable name: map it to the short code
if PREDICTAND == 'pr':
    y_pred = [ds.rename({NAMES['pr']: 'pr'}) for ds in y_pred]

In [None]:
# Align every realisation with the observations and blank out the cells where
# observations are missing. Realisations with more than one data variable are
# treated as precipitation runs carrying an extra 'prob' field, which is aligned
# and masked the same way and collected in y_prob.
y_prob = []
for i, y_p in enumerate(y_pred):
    with_prob = len(y_p.data_vars) > 1
    fields = [y_p[PREDICTAND], y_p.prob] if with_prob else [y_p[PREDICTAND]]

    # the aligned copies of y_true and y_refe (first two results) are not needed
    aligned = xr.align(y_true, y_refe, *fields, join='override')[2:]
    aligned = [field.where(~missing, other=np.nan) for field in aligned]

    if with_prob:
        y_prob.append(aligned[1])
    y_pred[i] = aligned[0]

## Mean time series

In [None]:
# True: centred 365-step rolling mean; False: resample to the frequency `scale`
ROLLING: bool = False

In [None]:
# resampling frequency of the mean time series: '1Y' aggregates yearly,
# '1M' would aggregate monthly
scale: str = '1Y'

In [None]:
# Domain-mean time series over the validation period for every dataset.
def _domain_mean(da):
    """Average ``da`` over the y/x grid and compute the result eagerly.

    With ROLLING, a centred 365-step rolling mean is applied first and time
    steps whose grid mean is NaN (e.g. the incomplete window edges) are
    dropped; otherwise ``da`` is resampled to the frequency ``scale``.
    """
    if ROLLING:
        smoothed = da.rolling(time=365, center=True).mean()
        return smoothed.mean(dim=('y', 'x')).dropna('time').compute()
    return da.resample(time=scale).mean(dim=('time', 'y', 'x')).compute()


y_pred_ts = [_domain_mean(y_p) for y_p in y_pred]
y_true_ts, y_refe_ts, y_refe_qm_ts = map(_domain_mean, (y_true, y_refe, y_refe_qm))

In [None]:
# stack the per-model mean time series into one (n_models, n_times) numpy
# array; np.stack makes the member axis explicit and fails loudly if the
# series have different lengths
y_pred_ts = np.stack([y_p.values for y_p in y_pred_ts], axis=0)

In [None]:
# 25th, 50th and 75th percentile across the bootstrapped ensemble (axis 0),
# giving one row per percentile
y_pred_q = np.percentile(y_pred_ts, [25, 50, 75], axis=0)

In [None]:
# Domain-mean time series: observations, ERA-5 and QM-adjusted ERA-5 as lines,
# the bootstrapped ensemble as its median line plus interquartile band.
palette = sns.color_palette('viridis', 3)
fig, ax = plt.subplots(1, 1, figsize=(16, 9))

# x values: the rolling-mean time coordinate as is, or the resampled time
# stamps truncated to the unit of `scale` (e.g. '1Y' -> 'datetime64[Y]')
if ROLLING:
    time = y_true_ts.time
else:
    time = [t.astype('datetime64[{}]'.format(scale.lstrip('1'))) for t in y_true_ts.time.values]

# one tick per year: the labels of a yearly resampling of the time axis
xticks = [t.astype('datetime64[Y]') for t in y_true_ts.time.resample(time='1Y').groups]

# reference series
references = [(y_true_ts, 'Observed', 'k'),
              (y_refe_ts, 'ERA-5', palette[0]),
              (y_refe_qm_ts, 'ERA-5 QM-adjusted', palette[1])]
for series, label, color in references:
    ax.plot(time, series, label=label, ls='-', color=color)

# ensemble: the rows of y_pred_q are the 25th, 50th and 75th percentiles
q25, q50, q75 = y_pred_q
ax.plot(time, q50, label='Prediction: Median', color=palette[-1])
ax.fill_between(x=time, y1=q25, y2=q75, alpha=0.3, label='Prediction: IQR', color=palette[-1])

# legend, ticks and tick labels
ax.legend(frameon=False, loc='lower right', fontsize=12)
ax.set_xticks(xticks)
ax.set_xticklabels(xticks)
ax.tick_params(axis='both', labelsize=12)

# write the figure; the file-name suffix is 'rolling' or the resampling frequency
suffix = 'rolling' if ROLLING else scale
fig.savefig('./Figures/{}_{}_{}_bootstrap_time_series_{}.png'.format(PREDICTAND, LOSS, OPTIM, suffix),
            bbox_inches='tight', dpi=300)

### Bias, MAE, and RMSE

Calculate the bias, mean absolute error (MAE), and root mean squared error (RMSE) of the yearly averages over the entire validation period for the model predictions, ERA-5, and QM-adjusted ERA-5.

In [None]:
# Yearly means over the validation period for every dataset.
def _yearly_mean(da):
    """Return the mean of ``da`` for each calendar year of its time axis."""
    return da.groupby('time.year').mean(dim='time')


y_pred_yearly_avg = [_yearly_mean(y_p) for y_p in y_pred]
y_refe_yearly_avg = _yearly_mean(y_refe)
y_refe_qm_yearly_avg = _yearly_mean(y_refe_qm)
y_true_yearly_avg = _yearly_mean(y_true)

In [None]:
# per-model errors of the yearly means against the observations; note that the
# *_rmse lists hold squared errors - the root is only taken after averaging
bias_pred = [y_p - y_true_yearly_avg for y_p in y_pred_yearly_avg]
mae_pred = [np.abs(err) for err in bias_pred]
rmse_pred = [err ** 2 for err in bias_pred]

In [None]:
# ERA-5: errors of the yearly means against the observations
# (rmse_refe holds squared errors; the root is taken after averaging)
bias_refe = y_refe_yearly_avg - y_true_yearly_avg
mae_refe = np.abs(bias_refe)
rmse_refe = bias_refe ** 2

In [None]:
# QM-adjusted ERA-5: errors of the yearly means against the observations
# (rmse_refe_qm holds squared errors; the root is taken after averaging)
bias_refe_qm = y_refe_qm_yearly_avg - y_true_yearly_avg
mae_refe_qm = np.abs(bias_refe_qm)
rmse_refe_qm = bias_refe_qm ** 2

#### Calculate absolute values

In [None]:
# empty summary table: one row per product with its mean bias, MAE and RMSE
df = pd.DataFrame(data=[], columns=['bias', 'mae', 'rmse', 'product'])

In [None]:
# absolute values for the reference datasets: mean bias, mean absolute error
# and the root of the mean squared error over all years and grid points
rows = []
for product, (bias_err, abs_err, sq_err) in zip(
        ['Era-5', 'Era-5 QM'],
        [[bias_refe, mae_refe, rmse_refe], [bias_refe_qm, mae_refe_qm, rmse_refe_qm]]):
    rows.append([bias_err.mean().values.item(),
                 abs_err.mean().values.item(),
                 np.sqrt(sq_err.mean().values.item()),
                 product])

# DataFrame.append was removed in pandas 2.0: build the rows once and concatenate;
# skip the concat when df is still empty so the numeric dtypes are kept
new_rows = pd.DataFrame(rows, columns=df.columns)
df = new_rows if df.empty else pd.concat([df, new_rows], ignore_index=True)

In [None]:
# absolute values for the model predictions: one row per bootstrapped model
# with its mean bias, mean absolute error and root of the mean squared error.
# The frame is built in one go because DataFrame.append was removed in pandas 2.0.
df_pred = pd.DataFrame(
    [[bias_err.mean().values.item(),
      abs_err.mean().values.item(),
      np.sqrt(sq_err.mean().values.item()),
      'Prediction']
     for bias_err, abs_err, sq_err in zip(bias_pred, mae_pred, rmse_pred)],
    columns=['bias', 'mae', 'rmse', 'product'])

In [None]:
# mean and standard deviation of the metrics over the bootstrapped models;
# numeric_only=True skips the string 'product' column, which pandas >= 2.0
# would otherwise try to aggregate and fail on
mean = df_pred.mean(axis=0, numeric_only=True).to_frame().transpose()
std = df_pred.std(axis=0, numeric_only=True).to_frame().transpose()
mean['product'] = 'Prediction'

In [None]:
# add the ensemble-mean row to the reference table (pd.concat replaces
# DataFrame.append, which was removed in pandas 2.0)
df = pd.concat([df, mean], ignore_index=True)

In [None]:
# display the table of mean bias, MAE and RMSE per product
df

#### Plot spatial distributions

In [None]:
# per grid point: median across the bootstrapped models of each model's bias
# averaged over all years
pred = np.median([err.mean(dim='year').values for err in bias_pred], axis=0)

In [None]:
# Maps of the bias averaged over all years for ERA-5, QM-adjusted ERA-5 and the
# ensemble median of the predictions, on a shared diverging colour scale.
vmin, vmax = -1, 1
fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)

era5 = bias_refe.mean(dim='year')
era5_qm = bias_refe_qm.mean(dim='year')
panels = [(era5.values, 'Era-5'),
          (era5_qm.values, 'Era-5: QM-adjusted'),
          (pred, 'Predictions: Median')]
im1, im2, im3 = [ax.imshow(data, origin='lower', cmap='RdBu_r', vmin=vmin, vmax=vmax)
                 for ax, (data, _) in zip(axes, panels)]
for ax, (_, title) in zip(axes, panels):
    ax.set_title(title, fontsize=14, pad=10)

# strip ticks, tick labels, axis labels and frames from every panel
for ax in axes.flat:
    ax.set_xticklabels([])
    ax.set_xticks([])
    ax.set_yticklabels([])
    ax.set_yticks([])
    ax.axis('tight')
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_axis_off()

# pack the panels together
fig.subplots_adjust(hspace=0, wspace=0, top=0.85)

# colour bar just right of the last panel, spanning its full height
axes = axes.flatten()
last = axes[-1].get_position()
cbar_ax_bias = fig.add_axes([last.x1 + 0.01, last.y0, 0.01, last.y1 - last.y0])
cbar_bias = fig.colorbar(im3, cax=cbar_ax_bias)
cbar_bias.set_label(label='Bias (°C)', fontsize=14)
cbar_bias.ax.tick_params(labelsize=14, pad=10)

# write the figure
fig.savefig('../Notebooks/Figures/{}_{}_{}_bootstrap_bias.png'.format(PREDICTAND, LOSS, OPTIM), dpi=300, bbox_inches='tight')