Newer
Older
{
"cells": [
{
"cell_type": "markdown",
"id": "fde8874d-299f-4f48-a10a-9fb6a00b43b9",
"metadata": {},
"source": [
"# Evaluate bootstrapped model results"
]
},
{
"cell_type": "markdown",
"id": "969d063b-5262-4324-901f-0a48630c4f27",
]
},
{
"cell_type": "code",
"id": "8af00ae4-4aeb-4ff8-a46a-65966b28c440",
"metadata": {},
"outputs": [],
"source": [
"# builtins\n",
"import pathlib\n",
"\n",
"# externals\n",
"import numpy as np\n",
"import xarray as xr\n",
"import pandas as pd\n",
"from sklearn.metrics import r2_score, auc, roc_curve\n",
"\n",
"# locals\n",
"from climax.main.io import OBS_PATH, ERA5_PATH\n",
"from climax.main.config import VALID_PERIOD\n",
"from pysegcnn.core.utils import search_files"
]
},
{
"cell_type": "code",
"id": "5bc74835-dc59-46ed-849b-3ff614e53eee",
"metadata": {},
"outputs": [],
"source": [
"# mapping from predictands to variable names\n",
"NAMES = {'tasmin': 'minimum temperature', 'tasmax': 'maximum temperature', 'pr': 'precipitation'}"
]
},
{
"cell_type": "code",
"id": "c8a63ef3-35ef-4ffa-b1f3-5c2986eb7eb1",
"metadata": {},
"outputs": [],
"source": [
"# path to bootstrapped model results\n",
"RESULTS = pathlib.Path('/mnt/CEPH_PROJECTS/FACT_CLIMAX/ERA5_PRED/bootstrap')"
]
},
{
"cell_type": "markdown",
"id": "7eae545b-4d8a-4689-a6c0-4aba2cb9104e",
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"# predictand to evaluate\n",
"PREDICTAND = 'tasmin'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49e03dc8-e709-4877-922a-4914e61d7636",
"metadata": {},
"outputs": [],
"source": [
"# whether only precipitation was used as predictor\n",
"PR_ONLY = False"
"id": "3e856f80-14fd-405f-a44e-cc77863f8e5b",
"metadata": {},
"outputs": [],
"source": [
"# loss function and optimizer\n",
"LOSS = ['L1Loss', 'MSELoss', 'BernoulliGammaLoss'] if PREDICTAND == 'pr' else ['L1Loss', 'MSELoss']\n",
"OPTIM = 'Adam'"
]
},
{
"cell_type": "code",
"id": "011b792d-7349-44ad-997d-11f236472a11",
"metadata": {},
"outputs": [],
"source": [
"# model names to evaluate\n",
"# NOTE: str.format silently ignores surplus positional arguments, so every template\n",
"# needs exactly one placeholder per argument, otherwise the optimizer is dropped\n",
"if PREDICTAND == 'pr' and PR_ONLY:\n",
"    # precipitation is the predictand and the only predictor\n",
"    models = ['USegNet_pr_pr_1mm_{}_{}'.format(loss, OPTIM) if loss == 'BernoulliGammaLoss' else\n",
"              'USegNet_pr_pr_{}_{}'.format(loss, OPTIM) for loss in LOSS]\n",
"else:\n",
"    models = ['USegNet_{}_ztuvq_500_850_p_dem_doy_1mm_{}_{}'.format(PREDICTAND, loss, OPTIM) if loss == 'BernoulliGammaLoss' else\n",
"              'USegNet_{}_ztuvq_500_850_p_dem_doy_{}_{}'.format(PREDICTAND, loss, OPTIM) for loss in LOSS]"
]
},
{
"cell_type": "code",
"id": "dc4ca6f0-5490-4522-8661-e36bd1be11b7",
"metadata": {},
"source": [
"# get bootstrapped models\n",
"models = {loss: sorted(search_files(RESULTS.joinpath(PREDICTAND), model + '(.*).nc$'),\n",
" key=lambda x: int(x.stem.split('_')[-1])) for loss, model in zip(LOSS, models)}\n",
"models"
]
},
{
"cell_type": "markdown",
"id": "5a64795a-6e5c-409a-8b3b-c738a96fa255",
"metadata": {
"tags": []
},
"source": [
"## Load datasets"
]
},
{
"cell_type": "markdown",
"id": "e790ed9f-451c-4368-849d-06d9c50f797c",
"metadata": {},
"source": [
"### Load observations"
]
},
{
"cell_type": "code",
"id": "0862e0c8-06df-45d6-bc1b-002ffb6e9915",
"metadata": {},
"outputs": [],
"source": [
"# load observations\n",
"y_true = xr.open_dataset(OBS_PATH.joinpath(PREDICTAND, 'OBS_{}_1980_2018.nc'.format(PREDICTAND)),\n",
" chunks={'time': 365})\n",
"y_true = y_true.sel(time=VALID_PERIOD) # subset to time period covered by predictions\n",
"y_true = y_true.rename({NAMES[PREDICTAND]: PREDICTAND}) if PREDICTAND == 'pr' else y_true"
]
},
{
"cell_type": "code",
"id": "aba38642-85d1-404a-81f3-65d23985fb7a",
"metadata": {},
"outputs": [],
"source": [
"# mask of missing values\n",
"missing = np.isnan(y_true[PREDICTAND])"
]
},
{
"cell_type": "markdown",
"id": "d4512ed2-d503-4bc1-ae76-84560c101a14",
"metadata": {},
"source": [
"### Load reference data"
]
},
{
"cell_type": "code",
"id": "f90f6abf-5fd6-49c0-a1ad-f62242b3d3a0",
"metadata": {},
"outputs": [],
"source": [
"# ERA-5 reference dataset\n",
"if PREDICTAND == 'pr':\n",
"    y_refe = xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', 'total_precipitation'), '.nc$').pop(),\n",
"                             chunks={'time': 365})\n",
"    y_refe = y_refe.rename({'tp': 'pr'})\n",
"else:\n",
"    # remove the literal 'tas' prefix: 'tasmin' -> 'min', 'tasmax' -> 'max'\n",
"    # (str.lstrip('tas') would strip any leading 't', 'a' or 's' characters instead)\n",
"    statistic = PREDICTAND[len('tas'):] if PREDICTAND.startswith('tas') else PREDICTAND\n",
"    y_refe = xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', '2m_{}_temperature'.format(statistic)), '.nc$').pop(),\n",
"                             chunks={'time': 365})\n",
"    y_refe = y_refe - 273.15  # convert to °C\n",
"    y_refe = y_refe.rename({'t2m': PREDICTAND})"
]
},
{
"cell_type": "code",
"id": "ea6d5f56-4f39-4e9a-976d-00ff28fce95c",
"metadata": {},
"outputs": [],
"source": [
"# subset to time period covered by predictions\n",
"y_refe = y_refe.sel(time=VALID_PERIOD).drop_vars('lambert_azimuthal_equal_area')\n",
"y_refe = y_refe.transpose('time', 'y', 'x') # change order of dimensions"
]
},
{
"cell_type": "markdown",
"id": "d37702de-da5f-4306-acc1-e569471c1f12",
"metadata": {},
"source": [
"### Load QM-adjusted reference data"
]
},
{
"cell_type": "code",
"id": "fffbd267-d08b-44f4-869c-7056c4f19c28",
"metadata": {},
"outputs": [],
"source": [
"y_refe_qm = xr.open_dataset(ERA5_PATH.joinpath('QM_ERA5_{}_day_19912010.nc'.format(PREDICTAND)), chunks={'time': 365})\n",
"y_refe_qm = y_refe_qm.transpose('time', 'y', 'x') # change order of dimensions"
]
},
{
"cell_type": "code",
"id": "16fa580e-27a7-4758-9164-7f607df7179d",
"metadata": {},
"outputs": [],
"source": [
"# center hours at 00:00:00 rather than 12:00:00\n",
"y_refe_qm['time'] = np.asarray([t.astype('datetime64[D]') for t in y_refe_qm.time.values])"
]
},
{
"cell_type": "code",
"id": "6789791f-006b-49b3-aa04-34e4ed8e1571",
"metadata": {},
"outputs": [],
"source": [
"# subset to time period covered by predictions\n",
"y_refe_qm = y_refe_qm.sel(time=VALID_PERIOD).drop_vars('lambert_azimuthal_equal_area')"
]
},
{
"cell_type": "code",
"id": "b51cfb3f-caa8-413e-a12d-47bbafcef1df",
"metadata": {},
"outputs": [],
"source": [
"# align datasets and mask missing values\n",
"y_true, y_refe, y_refe_qm = xr.align(y_true[PREDICTAND], y_refe[PREDICTAND], y_refe_qm[PREDICTAND], join='override')\n",
"y_refe = y_refe.where(~missing, other=np.nan)\n",
"y_refe_qm = y_refe_qm.where(~missing, other=np.nan)"
]
},
{
"cell_type": "markdown",
"id": "b4a6c286-6b88-487d-866c-3cb633686dac",
"metadata": {},
"source": [
"### Load model predictions"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"y_pred_raw = {k: [xr.open_dataset(v, chunks={'time': 365}) for v in models[k]] for k in models.keys()}\n",
"if PREDICTAND == 'pr':\n",
" y_pred_raw = {k: [v.rename({NAMES[PREDICTAND]: PREDICTAND}) if k == 'BernoulliGammaLoss' else v.rename({PREDICTAND: PREDICTAND}) for v in y_pred_raw[k]] for k in y_pred_raw.keys()}\n",
" y_pred_raw = {k: [v.transpose('time', 'y', 'x') for v in y_pred_raw[k]] for k in y_pred_raw.keys()}"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"# align datasets and mask missing values\n",
" # check whether evaluating precipitation or temperatures\n",
" if len(y_p.data_vars) > 1:\n",
" _, _, y_p, y_p_prob = xr.align(y_true, y_refe, y_p[PREDICTAND], y_p.prob, join='override')\n",
" y_p_prob = y_p_prob.where(~missing, other=np.nan) # mask missing values\n",
" y_prob[loss].append(y_p_prob)\n",
" else:\n",
" _, _, y_p = xr.align(y_true, y_refe, y_p[PREDICTAND], join='override')\n",
"\n",
" # mask missing values\n",
" y_p = y_p.where(~missing, other=np.nan)\n",
" y_pred[loss].append(y_p)"
{
"cell_type": "markdown",
"id": "6a718ea3-54d3-400a-8c89-76d04347de2d",
"metadata": {
"tags": []
},
"source": [
]
},
{
"cell_type": "code",
"id": "5a6c0bfe-c1d2-4e43-9f8e-35c63c46bb10",
"metadata": {},
"outputs": [],
"source": [
"ensemble = {k: xr.Dataset({'Member-{}'.format(i): member for i, member in enumerate(y_pred[k])}).to_array('members')\n",
" for k in y_pred.keys() if y_pred[k]}"
]
},
{
"cell_type": "code",
"id": "0e526227-cd4c-4a1c-ab72-51b72a4f821f",
"metadata": {},
"outputs": [],
"source": [
"# full ensemble mean prediction and standard deviation\n",
"ensemble_mean_full = {k: v.mean(dim='members') for k, v in ensemble.items()}\n",
"ensemble_std_full = {k: v.std(dim='members') for k, v in ensemble.items()}"
]
},
{
"cell_type": "markdown",
"id": "f8b31e39-d4b9-4347-953f-87af04c0dd7a",
"metadata": {
"tags": []
},
"source": [
]
},
"cell_type": "code",
"execution_count": null,
"id": "e8adcb5e-c7b4-4156-85b3-4751020160e6",
"# extreme quantile of interest\n",
"quantile = 0.02 if PREDICTAND == 'tasmin' else 0.98"
"execution_count": null,
"id": "8aa8d57d-8e41-4c6e-a43f-063650ac8e4b",
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
"def r2(y_pred, y_true, precipitation=False):\n",
"    \"\"\"Compute R2 scores on monthly aggregates and on daily anomalies.\n",
"\n",
"    Daily anomalies are obtained from ``ERA5Dataset.anomalies`` with\n",
"    ``timescale='time.month'``. The monthly aggregate is the sum of each\n",
"    resampled month if ``precipitation`` is True, otherwise the mean of\n",
"    each ``time.month`` group. Positions where either array is NaN are\n",
"    excluded. Both scores are printed and returned as ``(r2_mm, r2_anom)``.\n",
"    \"\"\"\n",
"    def paired_valid(pred, true):\n",
"        # flatten both arrays and keep positions that are valid in both\n",
"        pred, true = pred.values.flatten(), true.values.flatten()\n",
"        valid = ~(np.isnan(pred) | np.isnan(true))\n",
"        return pred[valid], true[valid]\n",
"\n",
"    # predicted and observed daily anomalies wrt. monthly mean values\n",
"    y_pred_av, y_true_av = paired_valid(ERA5Dataset.anomalies(y_pred, timescale='time.month'),\n",
"                                        ERA5Dataset.anomalies(y_true, timescale='time.month'))\n",
"\n",
"    # predicted and observed monthly sums/means\n",
"    if precipitation:\n",
"        monthly_pred = y_pred.resample(time='1M').sum(skipna=False)\n",
"        monthly_true = y_true.resample(time='1M').sum(skipna=False)\n",
"    else:\n",
"        monthly_pred = y_pred.groupby('time.month').mean(dim=('time'))\n",
"        monthly_true = y_true.groupby('time.month').mean(dim=('time'))\n",
"    y_pred_mv, y_true_mv = paired_valid(monthly_pred, monthly_true)\n",
"\n",
"    # coefficient of determination on monthly sums/means\n",
"    r2_mm = r2_score(y_true_mv, y_pred_mv)\n",
"    print('R2 on monthly means: {:.2f}'.format(r2_mm))\n",
"\n",
"    # coefficient of determination on daily anomalies\n",
"    r2_anom = r2_score(y_true_av, y_pred_av)\n",
"    print('R2 on daily anomalies: {:.2f}'.format(r2_anom))\n",
"\n",
"    return r2_mm, r2_anom"
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
{
"cell_type": "code",
"execution_count": null,
"id": "074d7405-a01b-4368-b98b-06d8d46f1ce6",
"metadata": {},
"outputs": [],
"source": [
"def bias(y_pred, y_true, relative=False):\n",
"    \"\"\"Mean of ``y_pred - y_true``, in percent of ``y_true`` if ``relative`` is True.\"\"\"\n",
"    error = y_pred - y_true\n",
"    if relative:\n",
"        error = (error / y_true) * 100\n",
"    return error.mean().values.item()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2fc13939-e517-47bd-aa7d-0addd6715538",
"metadata": {},
"outputs": [],
"source": [
"def mae(y_pred, y_true):\n",
"    \"\"\"Mean absolute difference between ``y_pred`` and ``y_true`` as a Python scalar.\"\"\"\n",
"    absolute_error = np.abs(y_pred - y_true)\n",
"    return absolute_error.mean().values.item()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c93f497e-760e-4484-aeb6-ce54f561a7f6",
"metadata": {},
"outputs": [],
"source": [
"def rmse(y_pred, y_true):\n",
"    \"\"\"Square root of the mean squared difference between ``y_pred`` and ``y_true``.\"\"\"\n",
"    mean_squared_error = ((y_pred - y_true) ** 2).mean().values.item()\n",
"    return np.sqrt(mean_squared_error)"
]
},
{
"cell_type": "markdown",
"id": "3e6ecc98-f32f-42f7-9971-64b270aa5453",
"metadata": {
"tags": []
},
]
},
{
"cell_type": "markdown",
"id": "671cd3c0-8d6c-41c1-bf8e-93f5943bf9aa",
"metadata": {},
"source": [
]
},
{
"cell_type": "code",
"id": "7939a4d2-4eff-4507-86f8-dba7c0b635df",
"metadata": {},
"outputs": [],
"source": [
"# yearly average values over validation period\n",
"y_refe_yearly_avg = y_refe.groupby('time.year').mean(dim='time')\n",
"y_refe_qm_yearly_avg = y_refe_qm.groupby('time.year').mean(dim='time')\n",
"y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time')"
]
},
{
"cell_type": "code",
"id": "64e29db7-998d-4952-84b0-1c79016ab9a9",
"metadata": {},
"outputs": [],
"source": [
"# yearly average r2, bias, mae, and rmse for ERA-5\n",
"r2_refe_mm, r2_refe_anom = r2(y_refe, y_true)\n",
"bias_refe = bias(y_refe_yearly_avg, y_true_yearly_avg, relative=True if PREDICTAND == 'pr' else False)\n",
"mae_refe = mae(y_refe_yearly_avg, y_true_yearly_avg)\n",
"rmse_refe = rmse(y_refe_yearly_avg, y_true_yearly_avg)"
]
},
{
"cell_type": "code",
"id": "d0d4c974-876f-45e6-85cc-df91501ead20",
"metadata": {},
"outputs": [],
"source": [
"# yearly average r2, bias, mae, and rmse for QM-Adjusted ERA-5\n",
"r2_refe_qm_mm, r2_refe_qm_anom = r2(y_refe_qm, y_true)\n",
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
"bias_refe_qm = bias(y_refe_qm_yearly_avg, y_true_yearly_avg, relative=True if PREDICTAND == 'pr' else False)\n",
"mae_refe_qm = mae(y_refe_qm_yearly_avg, y_true_yearly_avg)\n",
"rmse_refe_qm = rmse(y_refe_qm_yearly_avg, y_true_yearly_avg)"
]
},
{
"cell_type": "markdown",
"id": "c07684d1-76c0-4088-bdd7-7e1a6ccc4716",
"metadata": {},
"source": [
"### Metrics for extreme values"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "343aad59-4b0a-4eec-9ac3-86e5f9d06fc6",
"metadata": {},
"outputs": [],
"source": [
"# calculate extreme quantile for each year\n",
"with warnings.catch_warnings():\n",
" warnings.simplefilter('ignore', category=RuntimeWarning)\n",
" y_true_ex = y_true.chunk(dict(time=-1)).groupby('time.year').quantile(quantile, dim='time')\n",
" y_refe_ex = y_refe.chunk(dict(time=-1)).groupby('time.year').quantile(quantile, dim='time')\n",
" y_refe_qm_ex = y_refe_qm.chunk(dict(time=-1)).groupby('time.year').quantile(quantile, dim='time')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fbbc648a-82a7-4137-b8ea-5dccb56a65c7",
"metadata": {},
"outputs": [],
"source": [
"# bias in extreme quantile\n",
"bias_ex_refe = bias(y_refe_ex, y_true_ex, relative=True if PREDICTAND == 'pr' else False)\n",
"bias_ex_refe_qm = bias(y_refe_qm_ex, y_true_ex, relative=True if PREDICTAND == 'pr' else False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "44a3a0e7-ca39-49ce-b569-51b0022161ed",
"metadata": {},
"outputs": [],
"source": [
"# mean absolute error in extreme quantile\n",
"mae_ex_refe = mae(y_refe_ex, y_true_ex)\n",
"mae_ex_refe_qm = mae(y_refe_qm_ex, y_true_ex)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a90ce1dc-cf94-4081-9add-1d26195f2302",
"metadata": {},
"outputs": [],
"source": [
"# root mean squared error in extreme quantile\n",
"rmse_ex_refe = rmse(y_refe_ex, y_true_ex)\n",
"rmse_ex_refe_qm = rmse(y_refe_qm_ex, y_true_ex)"
]
},
{
"cell_type": "code",
"id": "d6efe5b9-3a6d-41ea-9f26-295b167cf0af",
"metadata": {},
"outputs": [],
"source": [
"# compute validation metrics for reference datasets\n",
"filename = RESULTS.joinpath(PREDICTAND, 'reference.csv')\n",
"if filename.exists():\n",
"    # check if validation metrics for reference already exist\n",
"    df_refe = pd.read_csv(filename)\n",
"else:\n",
"    # one row of validation metrics per reference product, built in a single\n",
"    # step since DataFrame.append is no longer available in recent pandas\n",
"    refe_columns = ['r2_mm', 'r2_anom', 'bias', 'mae', 'rmse', 'bias_ex', 'mae_ex', 'rmse_ex', 'product']\n",
"    refe_rows = [\n",
"        [r2_refe_mm, r2_refe_anom, bias_refe, mae_refe, rmse_refe, bias_ex_refe, mae_ex_refe, rmse_ex_refe, 'Era-5'],\n",
"        [r2_refe_qm_mm, r2_refe_qm_anom, bias_refe_qm, mae_refe_qm, rmse_refe_qm, bias_ex_refe_qm, mae_ex_refe_qm,\n",
"         rmse_ex_refe_qm, 'Era-5 QM']]\n",
"    df_refe = pd.DataFrame(refe_rows, columns=refe_columns)\n",
"\n",
"    # save metrics to disk\n",
"    df_refe.to_csv(filename, index=False)"
]
},
{
"cell_type": "markdown",
"id": "258cb3c6-c2fc-457d-885e-28eaf48f1d5b",
"metadata": {
"tags": []
},
"source": [
]
},
{
"cell_type": "code",
"id": "6980833a-3848-43ca-bcca-d759b4fd9f69",
"metadata": {},
"outputs": [],
"source": [
"# yearly average bias, mae, and rmse for each ensemble member\n",
"y_pred_yearly_avg = {k: v.groupby('time.year').mean(dim='time') for k, v in ensemble.items()}\n",
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
"bias_pred = {k: [bias(v[i], y_true_yearly_avg, relative=True if PREDICTAND == 'pr' else False) for i in range(len(ensemble[k]))] for k, v in y_pred_yearly_avg.items()}\n",
"mae_pred = {k: [mae(v[i], y_true_yearly_avg) for i in range(len(ensemble[k]))] for k, v in y_pred_yearly_avg.items()}\n",
"rmse_pred = {k: [rmse(v[i], y_true_yearly_avg) for i in range(len(ensemble[k]))] for k, v in y_pred_yearly_avg.items()}"
]
},
{
"cell_type": "markdown",
"id": "122e84f8-211d-4816-9b03-2e1abc24eb9e",
"metadata": {},
"source": [
"### Metrics for extreme values"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b48a065-1a0b-4457-a97a-642c26d56c51",
"metadata": {},
"outputs": [],
"source": [
"# calculate extreme quantile for each year\n",
"with warnings.catch_warnings():\n",
" warnings.simplefilter('ignore', category=RuntimeWarning)\n",
" y_pred_ex = {k: v.chunk(dict(time=-1)).groupby('time.year').quantile(quantile, dim='time') for k, v in ensemble.items()}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e6893da-271b-4b0e-bbc1-46b5eb9ecee3",
"metadata": {},
"outputs": [],
"source": [
"# yearly average bias, mae, and rmse for each ensemble member\n",
"bias_pred_ex = {k: [bias(v[i], y_true_ex, relative=True if PREDICTAND == 'pr' else False) for i in range(len(ensemble[k]))] for k, v in y_pred_ex.items()}\n",
"mae_pred_ex = {k: [mae(v[i], y_true_ex) for i in range(len(ensemble[k]))] for k, v in y_pred_ex.items()}\n",
"rmse_pred_ex = {k: [rmse(v[i], y_true_ex) for i in range(len(ensemble[k]))] for k, v in y_pred_ex.items()}"
"id": "64f7a0b9-a772-4a03-9160-7839a48e56cd",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# compute validation metrics for model predictions\n",
"filename = (RESULTS.joinpath(PREDICTAND, 'prediction_pr-only.csv') if PREDICTAND == 'pr' and PR_ONLY else\n",
"            RESULTS.joinpath(PREDICTAND, 'prediction.csv'))\n",
"if filename.exists():\n",
"    # check if validation metrics for predictions already exist\n",
"    df_pred = pd.read_csv(filename)\n",
"else:\n",
"    # rows are collected in a list and converted once, since DataFrame.append\n",
"    # is no longer available in recent pandas\n",
"    pred_columns = ['r2_mm', 'r2_anom', 'bias', 'mae', 'rmse', 'bias_ex', 'mae_ex', 'rmse_ex', 'product', 'loss']\n",
"    pred_rows = []\n",
"    is_precipitation = True if PREDICTAND == 'pr' else False\n",
"\n",
"    # validation metrics for each ensemble member of each loss function\n",
"    for k in ensemble:\n",
"        for i in range(len(ensemble[k])):\n",
"            # r2 scores\n",
"            r2_mm, r2_anom = r2(ensemble[k][i], y_true, precipitation=is_precipitation)\n",
"\n",
"            # bias, mae, and rmse\n",
"            pred_rows.append([r2_mm, r2_anom, bias_pred[k][i], mae_pred[k][i], rmse_pred[k][i], bias_pred_ex[k][i],\n",
"                              mae_pred_ex[k][i], rmse_pred_ex[k][i], 'Member-{:d}'.format(i), k])\n",
"\n",
"    # validation metrics for ensemble\n",
"    for k, v in ensemble_mean_full.items():\n",
"        # metrics for mean values\n",
"        means = v.groupby('time.year').mean(dim='time')\n",
"        bias_mean = bias(means, y_true_yearly_avg, relative=is_precipitation)\n",
"        mae_mean = mae(means, y_true_yearly_avg)\n",
"        rmse_mean = rmse(means, y_true_yearly_avg)\n",
"\n",
"        # metrics for extreme values\n",
"        with warnings.catch_warnings():\n",
"            warnings.simplefilter('ignore', category=RuntimeWarning)\n",
"            extremes = v.chunk(dict(time=-1)).groupby('time.year').quantile(quantile, dim='time')\n",
"            bias_ex = bias(extremes, y_true_ex, relative=is_precipitation)\n",
"            mae_ex = mae(extremes, y_true_ex)\n",
"            rmse_ex = rmse(extremes, y_true_ex)\n",
"\n",
"        # r2 scores\n",
"        r2_mm, r2_anom = r2(v, y_true, precipitation=is_precipitation)\n",
"        pred_rows.append([r2_mm, r2_anom, bias_mean, mae_mean, rmse_mean, bias_ex, mae_ex, rmse_ex,\n",
"                          'Ensemble-{:d}'.format(len(ensemble[k])), k])\n",
"\n",
"    df_pred = pd.DataFrame(pred_rows, columns=pred_columns)\n",
"\n",
"    # save metrics to disk\n",
"    df_pred.to_csv(filename, index=False)"
]
},
"metadata": {},
"execution_count": null,
"id": "0b7a824b-418a-4499-a3c0-627190e00941",
"def auc_rocss(p_pred, y_true, wet_day_threshold=1):\n",
"    \"\"\"Area under the ROC curve and ROC skill score for wet-day probabilities.\n",
"\n",
"    ``p_pred`` holds the predicted probabilities and ``y_true`` the observed\n",
"    amounts; an observation counts as a wet day if it is greater than or equal\n",
"    to ``wet_day_threshold``. Positions where either input is NaN are excluded.\n",
"    Returns the tuple ``(area, rocss)`` with ``rocss = 2 * area - 1``.\n",
"    \"\"\"\n",
"    # observed amounts and predicted probability of precipitation\n",
"    observed = y_true.values.flatten()\n",
"    p_pred = p_pred.values.flatten()\n",
"\n",
"    # apply mask of valid pixels: the mask is built from the observed amounts,\n",
"    # since NaN >= threshold evaluates to False and a boolean array never\n",
"    # contains NaN, i.e. missing observations would count as dry days\n",
"    mask = (~np.isnan(observed) & ~np.isnan(p_pred))\n",
"    p_pred = p_pred[mask]\n",
"    p_true = (observed[mask] >= float(wet_day_threshold)).astype(float)\n",
"\n",
"    # calculate ROC: false positive rate vs. true positive rate\n",
"    fpr, tpr, _ = roc_curve(p_true, p_pred)\n",
"    area = auc(fpr, tpr)  # area under ROC curve\n",
"    rocss = 2 * area - 1  # ROC skill score (cf. https://journals.ametsoc.org/view/journals/clim/16/24/1520-0442_2003_016_4145_otrsop_2.0.co_2.xml)\n",
"\n",
"    # callers unpack both scores\n",
"    return area, rocss"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "43138ade-148a-4d8d-be48-f1280d40e5b0",
"metadata": {},
"outputs": [],
"source": [
"if PREDICTAND == 'pr':\n",
"    # precipitation threshold to consider as wet day\n",
"    WET_DAY_THRESHOLD = 1\n",
"\n",
"    # ensemble prediction for precipitation probability\n",
"    ensemble_prob = xr.Dataset({'Member-{}'.format(i): member for i, member in\n",
"                                enumerate(y_prob['BernoulliGammaLoss'])}).to_array('members')\n",
"    ensemble_mean_prob = ensemble_prob.mean(dim='members')\n",
"\n",
"    # filename for probability metrics\n",
"    filename = (RESULTS.joinpath(PREDICTAND, 'probability_pr-only.csv') if PREDICTAND == 'pr' and PR_ONLY else\n",
"                RESULTS.joinpath(PREDICTAND, 'probability.csv'))\n",
"    if filename.exists():\n",
"        # check if validation metrics for probabilities already exist\n",
"        df_prob = pd.read_csv(filename)\n",
"    else:\n",
"        # rows are collected in a list and converted once, since DataFrame.append\n",
"        # is no longer available in recent pandas\n",
"        prob_rows = []\n",
"\n",
"        # AUC and ROCSS for each ensemble member\n",
"        for i in range(len(ensemble_prob)):\n",
"            auc_score, rocss = auc_rocss(ensemble_prob[i], y_true, wet_day_threshold=WET_DAY_THRESHOLD)\n",
"            prob_rows.append([auc_score, rocss, ensemble_prob[i].members.item(), 'BernoulliGammaLoss'])\n",
"\n",
"        # AUC and ROCSS for ensemble mean\n",
"        auc_score, rocss = auc_rocss(ensemble_mean_prob, y_true, wet_day_threshold=WET_DAY_THRESHOLD)\n",
"        prob_rows.append([auc_score, rocss, 'Ensemble-{:d}'.format(len(ensemble_prob)), 'BernoulliGammaLoss'])\n",
"\n",
"        df_prob = pd.DataFrame(prob_rows, columns=['auc', 'rocss', 'product', 'loss'])\n",
"\n",
"        # save metrics to disk\n",
"        df_prob.to_csv(filename, index=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
}
},
"nbformat": 4,
"nbformat_minor": 5
}