{
"cells": [
{
"cell_type": "markdown",
"id": "fde8874d-299f-4f48-a10a-9fb6a00b43b9",
"metadata": {},
"source": [
"# Evaluate bootstrapped model results"
]
},
{
"cell_type": "markdown",
"id": "969d063b-5262-4324-901f-0a48630c4f27",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"id": "8af00ae4-4aeb-4ff8-a46a-65966b28c440",
"metadata": {},
"outputs": [],
"source": [
"# builtins\n",
"import pathlib\n",
"\n",
"# externals\n",
"import numpy as np\n",
"import xarray as xr\n",
"import pandas as pd\n",
"from sklearn.metrics import r2_score, auc, roc_curve\n",
"\n",
"# locals\n",
"from climax.main.io import OBS_PATH, ERA5_PATH\n",
"from climax.main.config import VALID_PERIOD\n",
"from pysegcnn.core.utils import search_files"
]
},
{
"cell_type": "code",
"id": "5bc74835-dc59-46ed-849b-3ff614e53eee",
"metadata": {},
"outputs": [],
"source": [
"# mapping from predictands to variable names\n",
"NAMES = {'tasmin': 'minimum temperature', 'tasmax': 'maximum temperature', 'pr': 'precipitation'}"
]
},
{
"cell_type": "code",
"id": "c8a63ef3-35ef-4ffa-b1f3-5c2986eb7eb1",
"metadata": {},
"outputs": [],
"source": [
"# path to bootstrapped model results\n",
"RESULTS = pathlib.Path('/mnt/CEPH_PROJECTS/FACT_CLIMAX/ERA5_PRED/bootstrap')"
]
},
{
"cell_type": "markdown",
"id": "7eae545b-4d8a-4689-a6c0-4aba2cb9104e",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"# predictand to evaluate: one of 'tasmin', 'tasmax' or 'pr'\n",
"PREDICTAND = 'tasmin'"
]
},
{
"cell_type": "code",
"id": "3e856f80-14fd-405f-a44e-cc77863f8e5b",
"metadata": {},
"outputs": [],
"source": [
"# loss function and optimizer\n",
"LOSS = ['L1Loss', 'MSELoss', 'BernoulliGammaLoss'] if PREDICTAND == 'pr' else ['L1Loss', 'MSELoss']\n",
"OPTIM = 'Adam'"
]
},
{
"cell_type": "code",
"id": "011b792d-7349-44ad-997d-11f236472a11",
"metadata": {},
"outputs": [],
"source": [
"# model to evaluate\n",
"models = ['USegNet_{}_ztuvq_500_850_p_dem_doy_1mm_{}_{}'.format(PREDICTAND, loss, OPTIM) if loss == 'BernoulliGammaLoss' else\n",
" 'USegNet_{}_ztuvq_500_850_p_dem_doy_{}_{}'.format(PREDICTAND, loss, OPTIM) for loss in LOSS]"
]
},
{
"cell_type": "code",
"id": "dc4ca6f0-5490-4522-8661-e36bd1be11b7",
"metadata": {},
"source": [
"# get bootstrapped models\n",
"models = {loss: sorted(search_files(RESULTS.joinpath(PREDICTAND), model + '(.*).nc$'),\n",
" key=lambda x: int(x.stem.split('_')[-1])) for loss, model in zip(LOSS, models)}\n",
"models"
]
},
{
"cell_type": "markdown",
"id": "5a64795a-6e5c-409a-8b3b-c738a96fa255",
"metadata": {
"tags": []
},
"source": [
"## Load datasets"
]
},
{
"cell_type": "markdown",
"id": "e790ed9f-451c-4368-849d-06d9c50f797c",
"metadata": {},
"source": [
"### Load observations"
]
},
{
"cell_type": "code",
"id": "0862e0c8-06df-45d6-bc1b-002ffb6e9915",
"metadata": {},
"outputs": [],
"source": [
"# load observations\n",
"y_true = xr.open_dataset(OBS_PATH.joinpath(PREDICTAND, 'OBS_{}_1980_2018.nc'.format(PREDICTAND)),\n",
" chunks={'time': 365})\n",
"y_true = y_true.sel(time=VALID_PERIOD) # subset to time period covered by predictions\n",
"y_true = y_true.rename({NAMES[PREDICTAND]: PREDICTAND}) if PREDICTAND == 'pr' else y_true"
]
},
{
"cell_type": "code",
"id": "aba38642-85d1-404a-81f3-65d23985fb7a",
"metadata": {},
"outputs": [],
"source": [
"# mask of missing values\n",
"missing = np.isnan(y_true[PREDICTAND])"
]
},
{
"cell_type": "markdown",
"id": "d4512ed2-d503-4bc1-ae76-84560c101a14",
"metadata": {},
"source": [
"### Load reference data"
]
},
{
"cell_type": "code",
"id": "f90f6abf-5fd6-49c0-a1ad-f62242b3d3a0",
"metadata": {},
"outputs": [],
"source": [
"# ERA-5 reference dataset\n",
"if PREDICTAND == 'pr':\n",
" y_refe = xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', 'total_precipitation'), '.nc$').pop(),\n",
" chunks={'time': 365})\n",
" y_refe = y_refe.rename({'tp': 'pr'})\n",
"else:\n",
" y_refe = xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', '2m_{}_temperature'.format(PREDICTAND.lstrip('tas'))), '.nc$').pop(),\n",
" chunks={'time': 365})\n",
" y_refe = y_refe - 273.15 # convert to °C\n",
" y_refe = y_refe.rename({'t2m': PREDICTAND})"
]
},
{
"cell_type": "code",
"id": "ea6d5f56-4f39-4e9a-976d-00ff28fce95c",
"metadata": {},
"outputs": [],
"source": [
"# subset to time period covered by predictions\n",
"y_refe = y_refe.sel(time=VALID_PERIOD).drop_vars('lambert_azimuthal_equal_area')\n",
"y_refe = y_refe.transpose('time', 'y', 'x') # change order of dimensions"
]
},
{
"cell_type": "markdown",
"id": "d37702de-da5f-4306-acc1-e569471c1f12",
"metadata": {},
"source": [
"### Load QM-adjusted reference data"
]
},
{
"cell_type": "code",
"id": "fffbd267-d08b-44f4-869c-7056c4f19c28",
"metadata": {},
"outputs": [],
"source": [
"y_refe_qm = xr.open_dataset(ERA5_PATH.joinpath('QM_ERA5_{}_day_19912010.nc'.format(PREDICTAND)), chunks={'time': 365})\n",
"y_refe_qm = y_refe_qm.transpose('time', 'y', 'x') # change order of dimensions"
]
},
{
"cell_type": "code",
"id": "16fa580e-27a7-4758-9164-7f607df7179d",
"metadata": {},
"outputs": [],
"source": [
"# center hours at 00:00:00 rather than 12:00:00\n",
"y_refe_qm['time'] = np.asarray([t.astype('datetime64[D]') for t in y_refe_qm.time.values])"
]
},
{
"cell_type": "code",
"id": "6789791f-006b-49b3-aa04-34e4ed8e1571",
"metadata": {},
"outputs": [],
"source": [
"# subset to time period covered by predictions\n",
"y_refe_qm = y_refe_qm.sel(time=VALID_PERIOD).drop_vars('lambert_azimuthal_equal_area')"
]
},
{
"cell_type": "code",
"id": "b51cfb3f-caa8-413e-a12d-47bbafcef1df",
"metadata": {},
"outputs": [],
"source": [
"# align datasets and mask missing values\n",
"y_true, y_refe, y_refe_qm = xr.align(y_true[PREDICTAND], y_refe[PREDICTAND], y_refe_qm[PREDICTAND], join='override')\n",
"y_refe = y_refe.where(~missing, other=np.nan)\n",
"y_refe_qm = y_refe_qm.where(~missing, other=np.nan)"
]
},
{
"cell_type": "markdown",
"id": "b4a6c286-6b88-487d-866c-3cb633686dac",
"metadata": {},
"source": [
"### Load model predictions"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"y_pred_raw = {k: [xr.open_dataset(v, chunks={'time': 365}) for v in models[k]] for k in models.keys()}\n",
"if PREDICTAND == 'pr':\n",
" y_pred_raw = {k: [v.rename({NAMES[PREDICTAND]: PREDICTAND}) if k == 'BernoulliGammaLoss' else v.rename({PREDICTAND: PREDICTAND}) for v in y_pred_raw[k]] for k in y_pred_raw.keys()}\n",
" y_pred_raw = {k: [v.transpose('time', 'y', 'x') for v in y_pred_raw[k]] for k in y_pred_raw.keys()}"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"# align datasets and mask missing values\n",
" # check whether evaluating precipitation or temperatures\n",
" if len(y_p.data_vars) > 1:\n",
" _, _, y_p, y_p_prob = xr.align(y_true, y_refe, y_p[PREDICTAND], y_p.prob, join='override')\n",
" y_p_prob = y_p_prob.where(~missing, other=np.nan) # mask missing values\n",
" y_prob[loss].append(y_p_prob)\n",
" else:\n",
" _, _, y_p = xr.align(y_true, y_refe, y_p[PREDICTAND], join='override')\n",
"\n",
" # mask missing values\n",
" y_p = y_p.where(~missing, other=np.nan)\n",
" y_pred[loss].append(y_p)"
{
"cell_type": "markdown",
"id": "6a718ea3-54d3-400a-8c89-76d04347de2d",
"metadata": {
"tags": []
},
"source": [
"## Ensemble predictions"
]
},
{
"cell_type": "code",
"id": "5a6c0bfe-c1d2-4e43-9f8e-35c63c46bb10",
"metadata": {},
"outputs": [],
"source": [
"ensemble = {k: xr.Dataset({'Member-{}'.format(i): member for i, member in enumerate(y_pred[k])}).to_array('members')\n",
" for k in y_pred.keys() if y_pred[k]}"
]
},
{
"cell_type": "code",
"id": "0e526227-cd4c-4a1c-ab72-51b72a4f821f",
"metadata": {},
"outputs": [],
"source": [
"# full ensemble mean prediction and standard deviation\n",
"ensemble_mean_full = {k: v.mean(dim='members') for k, v in ensemble.items()}\n",
"ensemble_std_full = {k: v.std(dim='members') for k, v in ensemble.items()}"
]
},
{
"cell_type": "markdown",
"id": "f8b31e39-d4b9-4347-953f-87af04c0dd7a",
"metadata": {
"tags": []
},
"source": [
"## Model validation"
]
},
{
"cell_type": "markdown",
"id": "c4114092-81ef-4547-9485-f54a12ac2a16",
"metadata": {},
"source": [
"### Coefficient of determination"
]
},
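{
"cell_type": "markdown",
"metadata": {},
"source": [
"The helper below scores predictions with scikit-learn's `r2_score`, i.e. the standard coefficient of determination\n",
"\n",
"$$R^2 = 1 - \\frac{\\sum_i (y_i - \\hat{y}_i)^2}{\\sum_i (y_i - \\bar{y})^2},$$\n",
"\n",
"evaluated twice: on the monthly aggregates (monthly sums for precipitation, monthly mean values for temperature) and on the daily anomalies with respect to the monthly mean values."
]
},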
"execution_count": null,
"id": "8aa8d57d-8e41-4c6e-a43f-063650ac8e4b",
"def r2(y_pred, y_true, precipitation=False):\n",
" # compute daily anomalies wrt. monthly mean values\n",
" anom_pred = ERA5Dataset.anomalies(y_pred, timescale='time.month')\n",
" anom_true = ERA5Dataset.anomalies(y_true, timescale='time.month')\n",
" \n",
" # get predicted and observed daily anomalies\n",
" y_pred_av = anom_pred.values.flatten()\n",
" y_true_av = anom_true.values.flatten()\n",
"\n",
" # apply mask of valid pixels\n",
" mask = (~np.isnan(y_pred_av) & ~np.isnan(y_true_av))\n",
" y_pred_av = y_pred_av[mask]\n",
" y_true_av = y_true_av[mask]\n",
"\n",
" # get predicted and observed monthly sums/means\n",
" if precipitation:\n",
" y_pred_mv = y_pred.resample(time='1M').sum(skipna=False).values.flatten()\n",
" y_true_mv = y_true.resample(time='1M').sum(skipna=False).values.flatten()\n",
" else:\n",
" y_pred_mv = y_pred.groupby('time.month').mean(dim=('time')).values.flatten()\n",
" y_true_mv = y_true.groupby('time.month').mean(dim=('time')).values.flatten()\n",
"\n",
" # apply mask of valid pixels\n",
" mask = (~np.isnan(y_pred_mv) & ~np.isnan(y_true_mv))\n",
" y_pred_mv = y_pred_mv[mask]\n",
" y_true_mv = y_true_mv[mask]\n",
"\n",
" # calculate coefficient of determination on monthly sums/means\n",
" r2_mm = r2_score(y_true_mv, y_pred_mv)\n",
" print('R2 on monthly means: {:.2f}'.format(r2_mm))\n",
"\n",
" # calculate coefficient of determination on daily anomalies\n",
" r2_anom = r2_score(y_true_av, y_pred_av)\n",
" print('R2 on daily anomalies: {:.2f}'.format(r2_anom))\n",
" \n",
" return r2_mm, r2_anom"
{
"cell_type": "markdown",
"id": "3e6ecc98-f32f-42f7-9971-64b270aa5453",
"metadata": {
"tags": []
},
"source": [
"### Bias, MAE, and RMSE"
]
},
{
"cell_type": "markdown",
"id": "671cd3c0-8d6c-41c1-bf8e-93f5943bf9aa",
"metadata": {},
"source": [
"Calculate yearly average bias, MAE, and RMSE over entire reference period for model predictions, ERA-5, and QM-adjusted ERA-5."
]
},
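{
"cell_type": "markdown",
"metadata": {},
"source": [
"With $y$ and $\\hat{y}$ denoting the yearly averages of observations and predictions, the metrics below are aggregated over all years and grid points as\n",
"\n",
"$$\\mathrm{Bias} = \\overline{\\hat{y} - y}, \\qquad \\mathrm{MAE} = \\overline{|\\hat{y} - y|}, \\qquad \\mathrm{RMSE} = \\sqrt{\\overline{(\\hat{y} - y)^2}}.$$\n",
"\n",
"For precipitation, the bias is reported as a relative bias in percent, $100 \\cdot \\overline{(\\hat{y} - y) / y}$."
]
},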
{
"cell_type": "code",
"id": "7939a4d2-4eff-4507-86f8-dba7c0b635df",
"metadata": {},
"outputs": [],
"source": [
"# yearly average values over validation period\n",
"y_refe_yearly_avg = y_refe.groupby('time.year').mean(dim='time')\n",
"y_refe_qm_yearly_avg = y_refe_qm.groupby('time.year').mean(dim='time')\n",
"y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time')"
]
},
{
"cell_type": "code",
"id": "64e29db7-998d-4952-84b0-1c79016ab9a9",
"metadata": {},
"outputs": [],
"source": [
"# yearly average r2, bias, mae, and rmse for ERA-5\n",
"r2_refe_mm, r2_refe_anom = r2(y_refe, y_true)\n",
"bias_refe = ((y_refe_yearly_avg - y_true_yearly_avg) / y_true_yearly_avg) * 100 if PREDICTAND == 'pr' else y_refe_yearly_avg - y_true_yearly_avg\n",
"mae_refe = np.abs(y_refe_yearly_avg - y_true_yearly_avg)\n",
"rmse_refe = (y_refe_yearly_avg - y_true_yearly_avg) ** 2"
]
},
{
"cell_type": "code",
"id": "d0d4c974-876f-45e6-85cc-df91501ead20",
"metadata": {},
"outputs": [],
"source": [
"# yearly average r2, bias, mae, and rmse for QM-Adjusted ERA-5\n",
"r2_refe_qm_mm, r2_refe_qm_anom = r2(y_refe_qm, y_true)\n",
"bias_refe_qm = ((y_refe_qm_yearly_avg - y_true_yearly_avg) / y_true_yearly_avg) * 100 if PREDICTAND == 'pr' else y_refe_qm_yearly_avg - y_true_yearly_avg\n",
"mae_refe_qm = np.abs(y_refe_qm_yearly_avg - y_true_yearly_avg)\n",
"rmse_refe_qm = (y_refe_qm_yearly_avg - y_true_yearly_avg) ** 2"
]
},
{
"cell_type": "code",
"id": "d6efe5b9-3a6d-41ea-9f26-295b167cf0af",
"metadata": {},
"outputs": [],
"source": [
"# compute validation metrics for reference datasets\n",
"filename = RESULTS.joinpath(PREDICTAND, 'reference.csv')\n",
"if filename.exists():\n",
" # check if validation metrics for reference already exist\n",
" df_refe = pd.DataFrame([], columns=['r2_mm', 'r2_anom', 'bias', 'mae', 'rmse', 'product'])\n",
" for product, metrics in zip(['Era-5', 'Era-5 QM'], [[bias_refe, mae_refe, rmse_refe], [bias_refe_qm, mae_refe_qm, rmse_refe_qm]]):\n",
" values = pd.DataFrame([[np.sqrt(m.mean().values.item()) if name == 'rmse' else m.mean().values.item() for\n",
" name, m in zip(['bias', 'mae', 'rmse'], metrics)] + [product]], columns=df_refe.columns[2:])\n",
" values['r2_mm'] = r2_refe_mm if product == 'Era-5' else r2_refe_qm_mm\n",
" values['r2_anom'] = r2_refe_anom if product == 'Era-5' else r2_refe_qm_anom\n",
" df_refe = df_refe.append(values, ignore_index=True)\n",
]
},
{
"cell_type": "markdown",
"id": "258cb3c6-c2fc-457d-885e-28eaf48f1d5b",
"metadata": {
"tags": []
},
"source": [
"Calculate yearly average bias, MAE, and RMSE over entire reference period for model predictions."
]
},
{
"cell_type": "code",
"id": "6980833a-3848-43ca-bcca-d759b4fd9f69",
"metadata": {},
"outputs": [],
"source": [
"# yearly average bias, mae, and rmse for each ensemble member\n",
"y_pred_yearly_avg = {k: v.groupby('time.year').mean(dim='time') for k, v in ensemble.items()}\n",
"bias_pred = {k: ((v - y_true_yearly_avg) / y_true_yearly_avg) * 100 if PREDICTAND == 'pr'\n",
" else v - y_true_yearly_avg for k, v in y_pred_yearly_avg.items()}\n",
"mae_pred = {k: np.abs(v - y_true_yearly_avg) for k, v in y_pred_yearly_avg.items()}\n",
"rmse_pred = {k: (v - y_true_yearly_avg) ** 2 for k, v in y_pred_yearly_avg.items()}"
"id": "64f7a0b9-a772-4a03-9160-7839a48e56cd",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# compute validation metrics for model predictions\n",
"filename = RESULTS.joinpath(PREDICTAND, 'prediction.csv')\n",
"if filename.exists():\n",
" # check if validation metrics for predictions already exist\n",
" df_pred = pd.read_csv(filename)\n",
"else:\n",
" # validation metrics for each ensemble member\n",
" df_pred = pd.DataFrame([], columns=['r2_mm', 'r2_anom', 'bias', 'mae', 'rmse', 'product', 'loss'])\n",
" for i in range(len(ensemble[k])):\n",
" # bias, mae, and rmse\n",
" values = pd.DataFrame([[np.sqrt(m.mean().values.item()) if name == 'rmse' else m.mean().values.item()\n",
" for name, m in zip(['bias', 'mae', 'rmse'], [bias_pred[k][i], mae_pred[k][i], rmse_pred[k][i]])] +\n",
" [bias_pred[k][i].members.item()] + [k]],\n",
" columns=df_pred.columns[2:])\n",
" \n",
" # r2 scores\n",
" values['r2_mm'], values['r2_anom'] = r2(ensemble[k][i], y_true, precipitation=True if PREDICTAND == 'pr' else False)\n",
" df_pred = df_pred.append(values, ignore_index=True)\n",
" \n",
" # validation metrics for ensemble\n",
" for k, v in ensemble_mean_full.items():\n",
" yearly_avg = v.groupby('time.year').mean(dim='time')\n",
" bias = ((((yearly_avg - y_true_yearly_avg) / y_true_yearly_avg) * 100).mean().values.item() if PREDICTAND == 'pr' else \n",
" (yearly_avg - y_true_yearly_avg).mean().values.item())\n",
" mae = np.abs(yearly_avg - y_true_yearly_avg).mean().values.item()\n",
" rmse = np.sqrt(((yearly_avg - y_true_yearly_avg) ** 2).mean().values.item())\n",
" r2_mm, r2_anom = r2(v, y_true, precipitation=True if PREDICTAND == 'pr' else False)\n",
" values = pd.DataFrame([[r2_mm, r2_anom, bias, mae, rmse, 'Ensemble-{:d}'.format(len(ensemble[k])), k]],\n",
" columns=df_pred.columns)\n",
" df_pred = df_pred.append(values, ignore_index=True)\n",
"\n",
" # save metrics to disk\n",
" df_pred.to_csv(filename, index=False)"
]
},
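{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Precipitation probability: AUC and ROCSS\n",
"\n",
"For precipitation, the predicted probability of exceeding the wet day threshold is evaluated against the observed wet day occurrence using the area under the ROC curve (AUC) and the ROC skill score\n",
"\n",
"$$\\mathrm{ROCSS} = 2 \\cdot \\mathrm{AUC} - 1,$$\n",
"\n",
"which is 0 for a no-skill (random) forecast with $\\mathrm{AUC} = 0.5$ and 1 for a perfect forecast."
]
},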
"metadata": {},
"execution_count": null,
"id": "0b7a824b-418a-4499-a3c0-627190e00941",
"def auc_rocss(p_pred, y_true, wet_day_threshold=1):\n",
" # true and predicted probability of precipitation\n",
" p_true = (y_true >= float(wet_day_threshold)).values.flatten()\n",
" p_pred = p_pred.values.flatten()\n",
" # apply mask of valid pixels\n",
" mask = (~np.isnan(p_true) & ~np.isnan(p_pred))\n",
" p_pred = p_pred[mask]\n",
" p_true = p_true[mask].astype(float)\n",
" # calculate ROC: false positive rate vs. true positive rate\n",
" fpr, tpr, _ = roc_curve(p_true, p_pred)\n",
" area = auc(fpr, tpr) # area under ROC curve\n",
" rocss = 2 * area - 1 # ROC skill score (cf. https://journals.ametsoc.org/view/journals/clim/16/24/1520-0442_2003_016_4145_otrsop_2.0.co_2.xml)\n",
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "43138ade-148a-4d8d-be48-f1280d40e5b0",
"metadata": {},
"outputs": [],
"source": [
"if PREDICTAND == 'pr':\n",
" # precipitation threshold to consider as wet day\n",
" WET_DAY_THRESHOLD = 1\n",
" \n",
" # ensemble prediction for precipitation probability\n",
" ensemble_prob = xr.Dataset({'Member-{}'.format(i): member for i, member in\n",
" enumerate(y_prob['BernoulliGammaLoss'])}).to_array('members')\n",
" ensemble_mean_prob = ensemble_prob.mean(dim='members')\n",
" \n",
" # filename for probability metrics\n",
" filename = RESULTS.joinpath(PREDICTAND, 'probability.csv')\n",
" \n",
" # AUC and ROCSS for each ensemble member\n",
" df_prob = pd.DataFrame([], columns=['auc', 'rocss', 'product', 'loss'])\n",
" for i in range(len(ensemble_prob)):\n",
" auc_score, rocss = auc_rocss(ensemble_prob[i], y_true, wet_day_threshold=WET_DAY_THRESHOLD)\n",
" df_prob = df_prob.append(pd.DataFrame([[auc_score, rocss, ensemble_prob[i].members.item(), 'BernoulliGammaLoss']],\n",
" columns=df_prob.columns), ignore_index=True)\n",
" \n",
" # AUC and ROCSS for ensemble mean\n",
" auc_score, rocss = auc_rocss(ensemble_mean_prob, y_true, wet_day_threshold=WET_DAY_THRESHOLD)\n",
" df_prob = df_prob.append(pd.DataFrame([[auc_score, rocss, 'Ensemble-{:d}'.format(len(ensemble_prob)), 'BernoulliGammaLoss']],\n",
" columns=df_prob.columns), ignore_index=True)\n",
" \n",
" # save metrics to disk\n",
" df_prob.to_csv(filename, index=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
}
},
"nbformat": 4,
"nbformat_minor": 5
}