# Evaluate bootstrapped model results

## Imports and constants

In [None]:
# builtins
import pathlib
import warnings

# externals
import numpy as np
import xarray as xr
import pandas as pd
from sklearn.metrics import r2_score, auc, roc_curve

# locals
from climax.core.dataset import ERA5Dataset
from climax.main.io import OBS_PATH, ERA5_PATH
from climax.main.config import VALID_PERIOD
from pysegcnn.core.utils import search_files

In [None]:
# long variable name of each supported predictand; used to rename the
# precipitation variable of the observations/predictions to its short name
NAMES = {
    'tasmin': 'minimum temperature',
    'tasmax': 'maximum temperature',
    'pr': 'precipitation',
}

In [None]:
# root directory of the bootstrapped model results (one subdirectory per predictand)
RESULTS = pathlib.Path('/mnt/CEPH_PROJECTS/FACT_CLIMAX') / 'ERA5_PRED' / 'bootstrap'

## Search model configurations

In [None]:
# predictand to evaluate; expected to be one of the keys of NAMES
# ('tasmin', 'tasmax' or 'pr')
PREDICTAND = 'tasmin'

In [None]:
# whether only precipitation was used as predictor; only consulted when
# PREDICTAND == 'pr', where it selects the 'USegNet_pr_pr...' models and the
# '*_pr-only.csv' metric files
PR_ONLY = False

In [None]:
# loss functions to evaluate: the BernoulliGammaLoss is only evaluated for precipitation
LOSS = ['L1Loss', 'MSELoss']
if PREDICTAND == 'pr':
    LOSS = LOSS + ['BernoulliGammaLoss']

# optimizer tag contained in the model names
OPTIM = 'Adam'

In [None]:
# model names to evaluate, one per loss function:
#   USegNet_<predictand>_<predictors>[_1mm]_<loss>_<optimizer>
# the '_1mm' tag is only part of the name of the BernoulliGammaLoss models
if PREDICTAND == 'pr' and PR_ONLY:
    # precipitation-only predictor set; every placeholder gets exactly one
    # argument (previously PREDICTAND was inserted twice and OPTIM dropped,
    # producing names that match no result file)
    models = ['USegNet_{}_pr_1mm_{}_{}'.format(PREDICTAND, loss, OPTIM) if loss == 'BernoulliGammaLoss' else
              'USegNet_{}_pr_{}_{}'.format(PREDICTAND, loss, OPTIM) for loss in LOSS]
else:
    models = ['USegNet_{}_ztuvq_500_850_p_dem_doy_1mm_{}_{}'.format(PREDICTAND, loss, OPTIM) if loss == 'BernoulliGammaLoss' else
              'USegNet_{}_ztuvq_500_850_p_dem_doy_{}_{}'.format(PREDICTAND, loss, OPTIM) for loss in LOSS]

In [None]:
# bootstrapped prediction files of each model, keyed by loss function and
# ordered by the integer member index that terminates each file stem
models = dict(
    (loss, sorted(search_files(RESULTS.joinpath(PREDICTAND), '{}(.*).nc$'.format(model)),
                  key=lambda path: int(path.stem.split('_')[-1])))
    for loss, model in zip(LOSS, models))
models

## Load datasets

### Load observations

In [None]:
# observations, opened lazily in chunks of 365 time steps and restricted to
# the time period covered by the predictions
y_true = xr.open_dataset(
    OBS_PATH.joinpath(PREDICTAND, 'OBS_{}_1980_2018.nc'.format(PREDICTAND)),
    chunks={'time': 365},
).sel(time=VALID_PERIOD)
if PREDICTAND == 'pr':
    # observed precipitation is stored under its long name
    y_true = y_true.rename({NAMES[PREDICTAND]: PREDICTAND})

In [None]:
# mask of missing values: True where the observations are NaN; used below to
# blank out the same pixels/days in the reference datasets and predictions
missing = np.isnan(y_true[PREDICTAND])

### Load reference data

In [None]:
# ERA-5 reference dataset
if PREDICTAND == 'pr':
    y_refe = xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', 'total_precipitation'), '.nc$').pop(),
                             chunks={'time': 365})
    y_refe = y_refe.rename({'tp': 'pr'})
else:
    # strip the 'tas' prefix ('tasmin' -> 'min', 'tasmax' -> 'max') by slicing;
    # str.lstrip('tas') would strip any leading 't', 'a' or 's' characters
    # rather than the prefix and only works here by coincidence
    y_refe = xr.open_dataset(
        search_files(ERA5_PATH.joinpath('ERA5', '2m_{}_temperature'.format(PREDICTAND[len('tas'):])), '.nc$').pop(),
        chunks={'time': 365})
    y_refe = y_refe - 273.15  # convert from K to °C
    y_refe = y_refe.rename({'t2m': PREDICTAND})

In [None]:
# restrict to the time period covered by the predictions, drop the
# 'lambert_azimuthal_equal_area' variable (presumably the grid-mapping / CRS
# variable) and reorder the dimensions to (time, y, x)
y_refe = (y_refe
          .sel(time=VALID_PERIOD)
          .drop_vars('lambert_azimuthal_equal_area')
          .transpose('time', 'y', 'x'))

### Load QM-adjusted reference data

In [None]:
# quantile-mapping (QM) adjusted ERA-5 data, opened lazily in chunks of 365
# time steps, with dimensions reordered to (time, y, x)
y_refe_qm = xr.open_dataset(
    ERA5_PATH.joinpath('QM_ERA5_{}_day_19912010.nc'.format(PREDICTAND)),
    chunks={'time': 365},
).transpose('time', 'y', 'x')

In [None]:
# floor every time stamp to its calendar day, i.e. move it from 12:00:00 to
# 00:00:00; the cast is vectorized instead of looping over the time steps in
# Python (which also keeps the datetime dtype for an empty time axis)
y_refe_qm['time'] = y_refe_qm.time.values.astype('datetime64[D]')

In [None]:
# restrict to the time period covered by the predictions and drop the
# 'lambert_azimuthal_equal_area' variable (presumably the grid-mapping / CRS variable)
y_refe_qm = (y_refe_qm
             .sel(time=VALID_PERIOD)
             .drop_vars('lambert_azimuthal_equal_area'))

In [None]:
# select the predictand of each dataset and align them with join='override'
# (coordinates are taken from the observations), then blank out the days and
# pixels where the observations are missing
y_true, y_refe, y_refe_qm = xr.align(*(ds[PREDICTAND] for ds in (y_true, y_refe, y_refe_qm)),
                                     join='override')
y_refe, y_refe_qm = (da.where(~missing, other=np.nan) for da in (y_refe, y_refe_qm))

### Load model predictions

In [None]:
# open every bootstrapped member lazily, in chunks of 365 time steps
y_pred_raw = {loss: [xr.open_dataset(path, chunks={'time': 365}) for path in paths]
              for loss, paths in models.items()}
if PREDICTAND == 'pr':
    # BernoulliGammaLoss members store precipitation under its long name; the
    # other members get a rename onto the same name (presumably kept as a
    # sanity check). All members are reordered to (time, y, x).
    y_pred_raw = {
        loss: [(ds.rename({NAMES[PREDICTAND]: PREDICTAND}) if loss == 'BernoulliGammaLoss'
                else ds.rename({PREDICTAND: PREDICTAND})).transpose('time', 'y', 'x')
               for ds in members]
        for loss, members in y_pred_raw.items()}

In [None]:
# align datasets and mask missing values
# y_pred[loss] / y_prob[loss]: lists with one DataArray per bootstrapped member;
# y_prob[loss] stays empty for members without a probability variable
y_prob = {}
y_pred = {}
for loss, sim in y_pred_raw.items():
    y_pred[loss], y_prob[loss] = [], []
    for y_p in sim:
        # check whether evaluating precipitation or temperatures:
        # members with more than one data variable also carry a 'prob' variable
        if len(y_p.data_vars) > 1:
            # join='override' takes the coordinates of the first object (y_true)
            _, _, y_p, y_p_prob = xr.align(y_true, y_refe, y_p[PREDICTAND], y_p.prob, join='override')
            y_p_prob = y_p_prob.where(~missing, other=np.nan)  # mask missing values
            y_prob[loss].append(y_p_prob)
        else:
            _, _, y_p = xr.align(y_true, y_refe, y_p[PREDICTAND], join='override')

        # mask missing values
        y_p = y_p.where(~missing, other=np.nan)
        y_pred[loss].append(y_p)

## Ensemble predictions

In [None]:
# stack the members of each loss function along a new 'members' dimension
# (labelled 'Member-0', 'Member-1', ...); losses without members are skipped
ensemble = {
    k: xr.Dataset(dict(('Member-{}'.format(i), member) for i, member in enumerate(members))).to_array('members')
    for k, members in y_pred.items()
    if members
}

In [None]:
# ensemble mean prediction and spread (standard deviation across all members) per loss
ensemble_mean_full = {loss: members.mean(dim='members') for loss, members in ensemble.items()}
ensemble_std_full = {loss: members.std(dim='members') for loss, members in ensemble.items()}

# Model validation

In [None]:
# extreme quantile of interest: the lower tail for minimum temperature, the
# upper tail for all other predictands
if PREDICTAND == 'tasmin':
    quantile = 0.02
else:
    quantile = 0.98

In [None]:
def r2(y_pred, y_true, precipitation=False):
    """Print and return coefficients of determination of y_pred vs. y_true.

    Both scores only use positions where prediction and observation are valid
    (non-NaN):

    - monthly values: for precipitation the sum over each month of the record
      (a month containing any NaN is dropped), otherwise the mean per month of
      the year (one value per calendar month);
    - daily anomalies with respect to monthly mean values, as returned by
      ``ERA5Dataset.anomalies(..., timescale='time.month')``.

    Parameters
    ----------
    y_pred, y_true : xarray.DataArray
        Predicted and observed daily fields with a ``time`` coordinate.
    precipitation : bool
        Aggregate monthly values as sums instead of means.

    Returns
    -------
    tuple of float
        ``(r2_mm, r2_anom)``: R2 on monthly values and on daily anomalies.
    """
    def _valid_pairs(pred, true):
        # flatten both arrays and keep the positions where neither is NaN
        pred, true = pred.values.flatten(), true.values.flatten()
        valid = ~np.isnan(pred) & ~np.isnan(true)
        return pred[valid], true[valid]

    def _monthly(da):
        # monthly sums of the record for precipitation, means per calendar month otherwise
        if precipitation:
            return da.resample(time='1M').sum(skipna=False)
        return da.groupby('time.month').mean(dim=('time'))

    # predicted and observed daily anomalies wrt. monthly mean values
    y_pred_av, y_true_av = _valid_pairs(ERA5Dataset.anomalies(y_pred, timescale='time.month'),
                                        ERA5Dataset.anomalies(y_true, timescale='time.month'))

    # predicted and observed monthly sums/means
    y_pred_mv, y_true_mv = _valid_pairs(_monthly(y_pred), _monthly(y_true))

    r2_mm = r2_score(y_true_mv, y_pred_mv)
    print('R2 on monthly means: {:.2f}'.format(r2_mm))

    r2_anom = r2_score(y_true_av, y_pred_av)
    print('R2 on daily anomalies: {:.2f}'.format(r2_anom))

    return r2_mm, r2_anom

In [None]:
def bias(y_pred, y_true, relative=False):
    """Return the mean error of y_pred with respect to y_true as a scalar.

    With ``relative=True`` each error is first expressed in percent of the
    corresponding ``y_true`` value before averaging.
    """
    error = y_pred - y_true
    if relative:
        error = error / y_true * 100
    return error.mean().values.item()

In [None]:
def mae(y_pred, y_true):
    """Return the mean absolute error of y_pred with respect to y_true as a scalar."""
    absolute_error = np.abs(y_pred - y_true)
    return absolute_error.mean().values.item()

In [None]:
def rmse(y_pred, y_true):
    """Return the root mean squared error of y_pred with respect to y_true."""
    mean_squared_error = ((y_pred - y_true) ** 2).mean().values.item()
    return np.sqrt(mean_squared_error)

## R2, Bias, MAE, and RMSE for reference data

### Metrics for mean values

In [None]:
# yearly average values over the validation period for ERA-5, QM-adjusted
# ERA-5 and the observations
y_refe_yearly_avg, y_refe_qm_yearly_avg, y_true_yearly_avg = (
    da.groupby('time.year').mean(dim='time') for da in (y_refe, y_refe_qm, y_true))

In [None]:
# r2 scores, and yearly average bias, mae, and rmse for ERA-5;
# r2 aggregates precipitation to monthly sums, the same way the model
# predictions are scored, so that the r2_mm values are comparable
r2_refe_mm, r2_refe_anom = r2(y_refe, y_true, precipitation=PREDICTAND == 'pr')
bias_refe = bias(y_refe_yearly_avg, y_true_yearly_avg, relative=PREDICTAND == 'pr')
mae_refe = mae(y_refe_yearly_avg, y_true_yearly_avg)
rmse_refe = rmse(y_refe_yearly_avg, y_true_yearly_avg)

In [None]:
# r2 scores, and yearly average bias, mae, and rmse for QM-adjusted ERA-5;
# r2 aggregates precipitation to monthly sums, the same way the model
# predictions are scored, so that the r2_mm values are comparable
r2_refe_qm_mm, r2_refe_qm_anom = r2(y_refe_qm, y_true, precipitation=PREDICTAND == 'pr')
bias_refe_qm = bias(y_refe_qm_yearly_avg, y_true_yearly_avg, relative=PREDICTAND == 'pr')
mae_refe_qm = mae(y_refe_qm_yearly_avg, y_true_yearly_avg)
rmse_refe_qm = rmse(y_refe_qm_yearly_avg, y_true_yearly_avg)

### Metrics for extreme values

In [None]:
# extreme quantile of each year for the observations, ERA-5 and QM-adjusted
# ERA-5; the time axis is rechunked into a single chunk before the quantile
with warnings.catch_warnings():
    # silence RuntimeWarnings raised while computing the quantiles
    warnings.simplefilter('ignore', category=RuntimeWarning)
    y_true_ex, y_refe_ex, y_refe_qm_ex = (
        da.chunk({'time': -1}).groupby('time.year').quantile(quantile, dim='time')
        for da in (y_true, y_refe, y_refe_qm))

In [None]:
# bias in the extreme quantile (relative, in percent, for precipitation)
bias_ex_refe, bias_ex_refe_qm = (bias(extremes, y_true_ex, relative=PREDICTAND == 'pr')
                                 for extremes in (y_refe_ex, y_refe_qm_ex))

In [None]:
# mean absolute error in the extreme quantile for ERA-5 and QM-adjusted ERA-5
mae_ex_refe, mae_ex_refe_qm = (mae(extremes, y_true_ex) for extremes in (y_refe_ex, y_refe_qm_ex))

In [None]:
# root mean squared error in the extreme quantile for ERA-5 and QM-adjusted ERA-5
rmse_ex_refe, rmse_ex_refe_qm = (rmse(extremes, y_true_ex) for extremes in (y_refe_ex, y_refe_qm_ex))

In [None]:
# compute validation metrics for reference datasets
filename = RESULTS.joinpath(PREDICTAND, 'reference.csv')
if filename.exists():
    # check if validation metrics for reference already exist
    df_refe = pd.read_csv(filename)
else:
    # one row of validation metrics per reference product; the frame is built
    # in one go because DataFrame.append was removed in pandas 2.0
    df_refe = pd.DataFrame(
        [[r2_refe_mm, r2_refe_anom, bias_refe, mae_refe, rmse_refe,
          bias_ex_refe, mae_ex_refe, rmse_ex_refe, 'Era-5'],
         [r2_refe_qm_mm, r2_refe_qm_anom, bias_refe_qm, mae_refe_qm, rmse_refe_qm,
          bias_ex_refe_qm, mae_ex_refe_qm, rmse_ex_refe_qm, 'Era-5 QM']],
        columns=['r2_mm', 'r2_anom', 'bias', 'mae', 'rmse', 'bias_ex', 'mae_ex', 'rmse_ex', 'product'])

    # save metrics to disk
    df_refe.to_csv(filename, index=False)

## R2, Bias, MAE, and RMSE for model predictions

### Metrics for mean values

In [None]:
# yearly means of each ensemble member and their bias, mae, and rmse with
# respect to the observed yearly means (one list entry per member)
y_pred_yearly_avg = {loss: members.groupby('time.year').mean(dim='time') for loss, members in ensemble.items()}
bias_pred, mae_pred, rmse_pred = (
    {loss: [metric(avg[i], y_true_yearly_avg) for i in range(len(ensemble[loss]))]
     for loss, avg in y_pred_yearly_avg.items()}
    for metric in (lambda p, t: bias(p, t, relative=PREDICTAND == 'pr'), mae, rmse))

### Metrics for extreme values

In [None]:
# extreme quantile of each year for every ensemble member; the time axis is
# rechunked into a single chunk before the quantile
with warnings.catch_warnings():
    # silence RuntimeWarnings raised while computing the quantiles
    warnings.simplefilter('ignore', category=RuntimeWarning)
    y_pred_ex = {
        loss: members.chunk({'time': -1})
                     .groupby('time.year')
                     .quantile(quantile, dim='time')
        for loss, members in ensemble.items()}

In [None]:
# bias, mae, and rmse of each member's yearly extreme quantile with respect to
# the observed yearly extreme quantile (one list entry per member)
bias_pred_ex, mae_pred_ex, rmse_pred_ex = (
    {loss: [metric(extremes[i], y_true_ex) for i in range(len(ensemble[loss]))]
     for loss, extremes in y_pred_ex.items()}
    for metric in (lambda p, t: bias(p, t, relative=PREDICTAND == 'pr'), mae, rmse))

In [None]:
# compute validation metrics for model predictions
filename = (RESULTS.joinpath(PREDICTAND, 'prediction_pr-only.csv') if PREDICTAND == 'pr' and PR_ONLY else
            RESULTS.joinpath(PREDICTAND, 'prediction.csv'))
if filename.exists():
    # check if validation metrics for predictions already exist
    df_pred = pd.read_csv(filename)
else:
    # rows are collected in a list and turned into a DataFrame once at the
    # end, because DataFrame.append was removed in pandas 2.0
    rows_pred = []

    # validation metrics for each ensemble member
    for k in y_pred_yearly_avg.keys():
        for i in range(len(ensemble[k])):
            # r2 scores on monthly values and daily anomalies
            r2_mm, r2_anom = r2(ensemble[k][i], y_true, precipitation=PREDICTAND == 'pr')
            rows_pred.append([r2_mm, r2_anom, bias_pred[k][i], mae_pred[k][i], rmse_pred[k][i],
                              bias_pred_ex[k][i], mae_pred_ex[k][i], rmse_pred_ex[k][i],
                              'Member-{:d}'.format(i), k])

    # validation metrics for ensemble mean
    for k, v in ensemble_mean_full.items():
        # metrics for mean values
        means = v.groupby('time.year').mean(dim='time')
        bias_mean = bias(means, y_true_yearly_avg, relative=PREDICTAND == 'pr')
        mae_mean = mae(means, y_true_yearly_avg)
        rmse_mean = rmse(means, y_true_yearly_avg)

        # metrics for extreme values
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=RuntimeWarning)
            extremes = v.chunk(dict(time=-1)).groupby('time.year').quantile(quantile, dim='time')
            bias_ex = bias(extremes, y_true_ex, relative=PREDICTAND == 'pr')
            mae_ex = mae(extremes, y_true_ex)
            rmse_ex = rmse(extremes, y_true_ex)

        # r2 scores
        r2_mm, r2_anom = r2(v, y_true, precipitation=PREDICTAND == 'pr')
        rows_pred.append([r2_mm, r2_anom, bias_mean, mae_mean, rmse_mean, bias_ex, mae_ex, rmse_ex,
                          'Ensemble-{:d}'.format(len(ensemble[k])), k])

    df_pred = pd.DataFrame(rows_pred, columns=['r2_mm', 'r2_anom', 'bias', 'mae', 'rmse', 'bias_ex', 'mae_ex',
                                               'rmse_ex', 'product', 'loss'])

    # save metrics to disk
    df_pred.to_csv(filename, index=False)

### AUC and ROC skill score (ROCSS) for precipitation

In [None]:
def auc_rocss(p_pred, y_true, wet_day_threshold=1):
    """Return the area under the ROC curve and the ROC skill score for wet days.

    A day counts as observed wet where ``y_true >= wet_day_threshold``. Only
    positions where both the observation and the predicted probability are
    valid (non-NaN) enter the ROC curve. Missing observations are therefore
    excluded instead of being counted as dry days.

    Parameters
    ----------
    p_pred : xarray.DataArray
        Predicted probability of a wet day; flattened before use.
    y_true : xarray.DataArray
        Observations with the same number of elements as ``p_pred``.
    wet_day_threshold : float
        Observed value from which a day counts as wet.

    Returns
    -------
    tuple of float
        ``(area, rocss)`` with ``rocss = 2 * area - 1``.
    """
    # flatten observations and predicted probabilities
    obs = y_true.values.flatten()
    p_pred = p_pred.values.flatten()

    # apply mask of valid pixels; the NaN check is done on the raw observations
    # because thresholding would turn NaN into False, i.e. into a "dry" day
    mask = ~np.isnan(obs) & ~np.isnan(p_pred)
    p_pred = p_pred[mask]
    p_true = (obs[mask] >= float(wet_day_threshold)).astype(float)

    # calculate ROC: false positive rate vs. true positive rate
    fpr, tpr, _ = roc_curve(p_true, p_pred)
    area = auc(fpr, tpr)  # area under ROC curve
    rocss = 2 * area - 1  # ROC skill score (cf. https://journals.ametsoc.org/view/journals/clim/16/24/1520-0442_2003_016_4145_otrsop_2.0.co_2.xml)

    return area, rocss

In [None]:
if PREDICTAND == 'pr':
    # precipitation threshold to consider as wet day
    WET_DAY_THRESHOLD = 1

    # ensemble prediction for precipitation probability
    ensemble_prob = xr.Dataset({'Member-{}'.format(i): member for i, member in
                                enumerate(y_prob['BernoulliGammaLoss'])}).to_array('members')
    ensemble_mean_prob = ensemble_prob.mean(dim='members')

    # filename for probability metrics (already inside the precipitation branch)
    filename = (RESULTS.joinpath(PREDICTAND, 'probability_pr-only.csv') if PR_ONLY else
                RESULTS.joinpath(PREDICTAND, 'probability.csv'))
    if filename.exists():
        # check if validation metrics for probabilities already exist
        df_prob = pd.read_csv(filename)
    else:
        # rows are collected in a list and turned into a DataFrame once at the
        # end, because DataFrame.append was removed in pandas 2.0
        rows_prob = []

        # AUC and ROCSS for each ensemble member
        for i in range(len(ensemble_prob)):
            auc_score, rocss = auc_rocss(ensemble_prob[i], y_true, wet_day_threshold=WET_DAY_THRESHOLD)
            rows_prob.append([auc_score, rocss, ensemble_prob[i].members.item(), 'BernoulliGammaLoss'])

        # AUC and ROCSS for ensemble mean
        auc_score, rocss = auc_rocss(ensemble_mean_prob, y_true, wet_day_threshold=WET_DAY_THRESHOLD)
        rows_prob.append([auc_score, rocss, 'Ensemble-{:d}'.format(len(ensemble_prob)), 'BernoulliGammaLoss'])

        df_prob = pd.DataFrame(rows_prob, columns=['auc', 'rocss', 'product', 'loss'])

        # save metrics to disk
        df_prob.to_csv(filename, index=False)