{ "cells": [ { "cell_type": "markdown", "id": "fde8874d-299f-4f48-a10a-9fb6a00b43b9", "metadata": {}, "source": [ "# Evaluate bootstrapped model results" ] }, { "cell_type": "markdown", "id": "969d063b-5262-4324-901f-0a48630c4f27", "metadata": { "tags": [] }, "source": [ "## Imports and constants" ] }, { "cell_type": "code", "execution_count": null, "id": "8af00ae4-4aeb-4ff8-a46a-65966b28c440", "metadata": {}, "outputs": [], "source": [ "# builtins\n", "import pathlib\n", "import warnings\n", "\n", "# externals\n", "import numpy as np\n", "import xarray as xr\n", "import pandas as pd\n", "from sklearn.metrics import r2_score, auc, roc_curve\n", "\n", "# locals\n", "from climax.core.dataset import ERA5Dataset\n", "from climax.main.io import OBS_PATH, ERA5_PATH\n", "from climax.main.config import VALID_PERIOD\n", "from pysegcnn.core.utils import search_files" ] }, { "cell_type": "code", "execution_count": null, "id": "5bc74835-dc59-46ed-849b-3ff614e53eee", "metadata": {}, "outputs": [], "source": [ "# mapping from predictands to variable names\n", "NAMES = {'tasmin': 'minimum temperature', 'tasmax': 'maximum temperature', 'pr': 'precipitation'}" ] }, { "cell_type": "code", "execution_count": null, "id": "c8a63ef3-35ef-4ffa-b1f3-5c2986eb7eb1", "metadata": {}, "outputs": [], "source": [ "# path to bootstrapped model results\n", "RESULTS = pathlib.Path('/mnt/CEPH_PROJECTS/FACT_CLIMAX/ERA5_PRED/bootstrap')" ] }, { "cell_type": "markdown", "id": "7eae545b-4d8a-4689-a6c0-4aba2cb9104e", "metadata": { "tags": [] }, "source": [ "## Search model configurations" ] }, { "cell_type": "code", "execution_count": null, "id": "3b83c9f3-7081-4cec-8f23-c4de007839d7", "metadata": {}, "outputs": [], "source": [ "# predictand to evaluate\n", "PREDICTAND = 'tasmin'" ] }, { "cell_type": "code", "execution_count": null, "id": "49e03dc8-e709-4877-922a-4914e61d7636", "metadata": {}, "outputs": [], "source": [ "# whether only precipitation was used as predictor\n", "PR_ONLY = False" ] }, { "cell_type": "code", "execution_count": null, "id": "3e856f80-14fd-405f-a44e-cc77863f8e5b", "metadata": {}, "outputs": [], "source": [ "# loss function and optimizer\n", "LOSS = ['L1Loss', 'MSELoss', 'BernoulliGammaLoss'] if PREDICTAND == 'pr' else ['L1Loss', 'MSELoss']\n", "OPTIM = 'Adam'" ] }, { "cell_type": "code", "execution_count": null, "id": "011b792d-7349-44ad-997d-11f236472a11", "metadata": {}, "outputs": [], "source": [ "# model to evaluate\n", "if PREDICTAND == 'pr' and PR_ONLY:\n", " models = ['USegNet_pr_pr_1mm_{}_{}'.format(PREDICTAND, loss, OPTIM) if loss == 'BernoulliGammaLoss' else\n", " 'USegNet_pr_pr_{}_{}'.format(PREDICTAND, loss, OPTIM) for loss in LOSS]\n", "else:\n", " models = ['USegNet_{}_ztuvq_500_850_p_dem_doy_1mm_{}_{}'.format(PREDICTAND, loss, OPTIM) if loss == 'BernoulliGammaLoss' else\n", " 'USegNet_{}_ztuvq_500_850_p_dem_doy_{}_{}'.format(PREDICTAND, loss, OPTIM) for loss in LOSS]" ] }, { "cell_type": "code", "execution_count": null, "id": "dc4ca6f0-5490-4522-8661-e36bd1be11b7", "metadata": {}, "outputs": [], "source": [ "# get bootstrapped models\n", "models = {loss: sorted(search_files(RESULTS.joinpath(PREDICTAND), model + '(.*).nc$'),\n", " key=lambda x: int(x.stem.split('_')[-1])) for loss, model in zip(LOSS, models)}\n", "models" ] }, { "cell_type": "markdown", "id": "5a64795a-6e5c-409a-8b3b-c738a96fa255", "metadata": { "tags": [] }, "source": [ "## Load datasets" ] }, { "cell_type": "markdown", "id": "e790ed9f-451c-4368-849d-06d9c50f797c", "metadata": {}, "source": [ "### Load observations" ] }, { "cell_type": "code", "execution_count": null, "id": "0862e0c8-06df-45d6-bc1b-002ffb6e9915", "metadata": {}, "outputs": [], "source": [ "# load observations\n", "y_true = xr.open_dataset(OBS_PATH.joinpath(PREDICTAND, 'OBS_{}_1980_2018.nc'.format(PREDICTAND)),\n", " chunks={'time': 365})\n", "y_true = y_true.sel(time=VALID_PERIOD) # subset to time period covered by predictions\n", "y_true = y_true.rename({NAMES[PREDICTAND]: PREDICTAND}) if PREDICTAND == 'pr' else y_true" ] }, { "cell_type": "code", "execution_count": null, "id": "aba38642-85d1-404a-81f3-65d23985fb7a", "metadata": {}, "outputs": [], "source": [ "# mask of missing values\n", "missing = np.isnan(y_true[PREDICTAND])" ] }, { "cell_type": "markdown", "id": "d4512ed2-d503-4bc1-ae76-84560c101a14", "metadata": {}, "source": [ "### Load reference data" ] }, { "cell_type": "code", "execution_count": null, "id": "f90f6abf-5fd6-49c0-a1ad-f62242b3d3a0", "metadata": {}, "outputs": [], "source": [ "# ERA-5 reference dataset\n", "if PREDICTAND == 'pr':\n", " y_refe = xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', 'total_precipitation'), '.nc$').pop(),\n", " chunks={'time': 365})\n", " y_refe = y_refe.rename({'tp': 'pr'})\n", "else:\n", " y_refe = xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', '2m_{}_temperature'.format(PREDICTAND.lstrip('tas'))), '.nc$').pop(),\n", " chunks={'time': 365})\n", " y_refe = y_refe - 273.15 # convert to °C\n", " y_refe = y_refe.rename({'t2m': PREDICTAND})" ] }, { "cell_type": "code", "execution_count": null, "id": "ea6d5f56-4f39-4e9a-976d-00ff28fce95c", "metadata": {}, "outputs": [], "source": [ "# subset to time period covered by predictions\n", "y_refe = y_refe.sel(time=VALID_PERIOD).drop_vars('lambert_azimuthal_equal_area')\n", "y_refe = y_refe.transpose('time', 'y', 'x') # change order of dimensions" ] }, { "cell_type": "markdown", "id": "d37702de-da5f-4306-acc1-e569471c1f12", "metadata": {}, "source": [ "### Load QM-adjusted reference data" ] }, { "cell_type": "code", "execution_count": null, "id": "fffbd267-d08b-44f4-869c-7056c4f19c28", "metadata": {}, "outputs": [], "source": [ "y_refe_qm = xr.open_dataset(ERA5_PATH.joinpath('QM_ERA5_{}_day_19912010.nc'.format(PREDICTAND)), chunks={'time': 365})\n", "y_refe_qm = y_refe_qm.transpose('time', 'y', 'x') # change order of dimensions" ] }, { "cell_type": "code", "execution_count": null, "id": "16fa580e-27a7-4758-9164-7f607df7179d", "metadata": {}, "outputs": [], "source": [ "# center hours at 00:00:00 rather than 12:00:00\n", "y_refe_qm['time'] = np.asarray([t.astype('datetime64[D]') for t in y_refe_qm.time.values])" ] }, { "cell_type": "code", "execution_count": null, "id": "6789791f-006b-49b3-aa04-34e4ed8e1571", "metadata": {}, "outputs": [], "source": [ "# subset to time period covered by predictions\n", "y_refe_qm = y_refe_qm.sel(time=VALID_PERIOD).drop_vars('lambert_azimuthal_equal_area')" ] }, { "cell_type": "code", "execution_count": null, "id": "b51cfb3f-caa8-413e-a12d-47bbafcef1df", "metadata": {}, "outputs": [], "source": [ "# align datasets and mask missing values\n", "y_true, y_refe, y_refe_qm = xr.align(y_true[PREDICTAND], y_refe[PREDICTAND], y_refe_qm[PREDICTAND], join='override')\n", "y_refe = y_refe.where(~missing, other=np.nan)\n", "y_refe_qm = y_refe_qm.where(~missing, other=np.nan)" ] }, { "cell_type": "markdown", "id": "b4a6c286-6b88-487d-866c-3cb633686dac", "metadata": {}, "source": [ "### Load model predictions" ] }, { "cell_type": "code", "execution_count": null, "id": "eb889059-17e4-4d8c-b796-e8b1e2d0bf8c", "metadata": {}, "outputs": [], "source": [ "y_pred_raw = {k: [xr.open_dataset(v, chunks={'time': 365}) for v in models[k]] for k in models.keys()}\n", "if PREDICTAND == 'pr':\n", " y_pred_raw = {k: [v.rename({NAMES[PREDICTAND]: PREDICTAND}) if k == 'BernoulliGammaLoss' else v.rename({PREDICTAND: PREDICTAND}) for v in y_pred_raw[k]] for k in y_pred_raw.keys()}\n", " y_pred_raw = {k: [v.transpose('time', 'y', 'x') for v in y_pred_raw[k]] for k in y_pred_raw.keys()}" ] }, { "cell_type": "code", "execution_count": null, "id": "534e020d-96b2-403c-b8e4-86de98fbbe3b", "metadata": {}, "outputs": [], "source": [ "# align datasets and mask missing values\n", "y_prob = {}\n", "y_pred = {}\n", "for loss, sim in y_pred_raw.items():\n", " y_pred[loss], y_prob[loss] = [], []\n", " for y_p in sim:\n", " # check whether evaluating precipitation or temperatures\n", " if len(y_p.data_vars) > 1:\n", " _, _, y_p, y_p_prob = xr.align(y_true, y_refe, y_p[PREDICTAND], y_p.prob, join='override')\n", " y_p_prob = y_p_prob.where(~missing, other=np.nan) # mask missing values\n", " y_prob[loss].append(y_p_prob)\n", " else:\n", " _, _, y_p = xr.align(y_true, y_refe, y_p[PREDICTAND], join='override')\n", "\n", " # mask missing values\n", " y_p = y_p.where(~missing, other=np.nan)\n", " y_pred[loss].append(y_p)" ] }, { "cell_type": "markdown", "id": "6a718ea3-54d3-400a-8c89-76d04347de2d", "metadata": { "tags": [] }, "source": [ "## Ensemble predictions" ] }, { "cell_type": "code", "execution_count": null, "id": "5a6c0bfe-c1d2-4e43-9f8e-35c63c46bb10", "metadata": {}, "outputs": [], "source": [ "# create and save ensemble dataset\n", "ensemble = {k: xr.Dataset({'Member-{}'.format(i): member for i, member in enumerate(y_pred[k])}).to_array('members')\n", " for k in y_pred.keys() if y_pred[k]}" ] }, { "cell_type": "code", "execution_count": null, "id": "0e526227-cd4c-4a1c-ab72-51b72a4f821f", "metadata": {}, "outputs": [], "source": [ "# full ensemble mean prediction and standard deviation\n", "ensemble_mean_full = {k: v.mean(dim='members') for k, v in ensemble.items()}\n", "ensemble_std_full = {k: v.std(dim='members') for k, v in ensemble.items()}" ] }, { "cell_type": "markdown", "id": "f8b31e39-d4b9-4347-953f-87af04c0dd7a", "metadata": { "tags": [] }, "source": [ "# Model validation" ] }, { "cell_type": "code", "execution_count": null, "id": "e8adcb5e-c7b4-4156-85b3-4751020160e6", "metadata": {}, "outputs": [], "source": [ "# extreme quantile of interest\n", "quantile = 0.02 if PREDICTAND == 'tasmin' else 0.98" ] }, { "cell_type": "code", "execution_count": null, "id": "8aa8d57d-8e41-4c6e-a43f-063650ac8e4b", "metadata": {}, "outputs": [], "source": [ "def r2(y_pred, y_true, precipitation=False):\n", " # compute daily anomalies wrt. monthly mean values\n", " anom_pred = ERA5Dataset.anomalies(y_pred, timescale='time.month')\n", " anom_true = ERA5Dataset.anomalies(y_true, timescale='time.month')\n", " \n", " # get predicted and observed daily anomalies\n", " y_pred_av = anom_pred.values.flatten()\n", " y_true_av = anom_true.values.flatten()\n", "\n", " # apply mask of valid pixels\n", " mask = (~np.isnan(y_pred_av) & ~np.isnan(y_true_av))\n", " y_pred_av = y_pred_av[mask]\n", " y_true_av = y_true_av[mask]\n", "\n", " # get predicted and observed monthly sums/means\n", " if precipitation:\n", " y_pred_mv = y_pred.resample(time='1M').sum(skipna=False).values.flatten()\n", " y_true_mv = y_true.resample(time='1M').sum(skipna=False).values.flatten()\n", " else:\n", " y_pred_mv = y_pred.groupby('time.month').mean(dim=('time')).values.flatten()\n", " y_true_mv = y_true.groupby('time.month').mean(dim=('time')).values.flatten()\n", "\n", " # apply mask of valid pixels\n", " mask = (~np.isnan(y_pred_mv) & ~np.isnan(y_true_mv))\n", " y_pred_mv = y_pred_mv[mask]\n", " y_true_mv = y_true_mv[mask]\n", "\n", " # calculate coefficient of determination on monthly sums/means\n", " r2_mm = r2_score(y_true_mv, y_pred_mv)\n", " print('R2 on monthly means: {:.2f}'.format(r2_mm))\n", "\n", " # calculate coefficient of determination on daily anomalies\n", " r2_anom = r2_score(y_true_av, y_pred_av)\n", " print('R2 on daily anomalies: {:.2f}'.format(r2_anom))\n", " \n", " return r2_mm, r2_anom" ] }, { "cell_type": "code", "execution_count": null, "id": "074d7405-a01b-4368-b98b-06d8d46f1ce6", "metadata": {}, "outputs": [], "source": [ "def bias(y_pred, y_true, relative=False):\n", " return (((y_pred - y_true) / y_true) * 100).mean().values.item() if relative else (y_pred - y_true).mean().values.item()" ] }, { "cell_type": "code", "execution_count": null, "id": "2fc13939-e517-47bd-aa7d-0addd6715538", "metadata": {}, "outputs": [], "source": [ "def mae(y_pred, y_true):\n", " return np.abs(y_pred - y_true).mean().values.item()" ] }, { "cell_type": "code", "execution_count": null, "id": "c93f497e-760e-4484-aeb6-ce54f561a7f6", "metadata": {}, "outputs": [], "source": [ "def rmse(y_pred, y_true):\n", " return np.sqrt(((y_pred - y_true) ** 2).mean().values.item())" ] }, { "cell_type": "markdown", "id": "3e6ecc98-f32f-42f7-9971-64b270aa5453", "metadata": { "tags": [] }, "source": [ "## R2, Bias, MAE, and RMSE for reference data" ] }, { "cell_type": "markdown", "id": "671cd3c0-8d6c-41c1-bf8e-93f5943bf9aa", "metadata": {}, "source": [ "### Metrics for mean values" ] }, { "cell_type": "code", "execution_count": null, "id": "7939a4d2-4eff-4507-86f8-dba7c0b635df", "metadata": {}, "outputs": [], "source": [ "# yearly average values over validation period\n", "y_refe_yearly_avg = y_refe.groupby('time.year').mean(dim='time')\n", "y_refe_qm_yearly_avg = y_refe_qm.groupby('time.year').mean(dim='time')\n", "y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time')" ] }, { "cell_type": "code", "execution_count": null, "id": "64e29db7-998d-4952-84b0-1c79016ab9a9", "metadata": {}, "outputs": [], "source": [ "# yearly average r2, bias, mae, and rmse for ERA-5\n", "r2_refe_mm, r2_refe_anom = r2(y_refe, y_true)\n", "bias_refe = bias(y_refe_yearly_avg, y_true_yearly_avg, relative=True if PREDICTAND == 'pr' else False)\n", "mae_refe = mae(y_refe_yearly_avg, y_true_yearly_avg)\n", "rmse_refe = rmse(y_refe_yearly_avg, y_true_yearly_avg)" ] }, { "cell_type": "code", "execution_count": null, "id": "d0d4c974-876f-45e6-85cc-df91501ead20", "metadata": {}, "outputs": [], "source": [ "# yearly average r2, bias, mae, and rmse for QM-Adjusted ERA-5\n", "r2_refe_qm_mm, r2_refe_qm_anom = r2(y_refe_qm, y_true)\n", "bias_refe_qm = bias(y_refe_qm_yearly_avg, y_true_yearly_avg, relative=True if PREDICTAND == 'pr' else False)\n", "mae_refe_qm = mae(y_refe_qm_yearly_avg, y_true_yearly_avg)\n", "rmse_refe_qm = rmse(y_refe_qm_yearly_avg, y_true_yearly_avg)" ] }, { "cell_type": "markdown", "id": "c07684d1-76c0-4088-bdd7-7e1a6ccc4716", "metadata": {}, "source": [ "### Metrics for extreme values" ] }, { "cell_type": "code", "execution_count": null, "id": "343aad59-4b0a-4eec-9ac3-86e5f9d06fc6", "metadata": {}, "outputs": [], "source": [ "# calculate extreme quantile for each year\n", "with warnings.catch_warnings():\n", " warnings.simplefilter('ignore', category=RuntimeWarning)\n", " y_true_ex = y_true.chunk(dict(time=-1)).groupby('time.year').quantile(quantile, dim='time')\n", " y_refe_ex = y_refe.chunk(dict(time=-1)).groupby('time.year').quantile(quantile, dim='time')\n", " y_refe_qm_ex = y_refe_qm.chunk(dict(time=-1)).groupby('time.year').quantile(quantile, dim='time')" ] }, { "cell_type": "code", "execution_count": null, "id": "fbbc648a-82a7-4137-b8ea-5dccb56a65c7", "metadata": {}, "outputs": [], "source": [ "# bias in extreme quantile\n", "bias_ex_refe = bias(y_refe_ex, y_true_ex, relative=True if PREDICTAND == 'pr' else False)\n", "bias_ex_refe_qm = bias(y_refe_qm_ex, y_true_ex, relative=True if PREDICTAND == 'pr' else False)" ] }, { "cell_type": "code", "execution_count": null, "id": "44a3a0e7-ca39-49ce-b569-51b0022161ed", "metadata": {}, "outputs": [], "source": [ "# mean absolute error in extreme quantile\n", "mae_ex_refe = mae(y_refe_ex, y_true_ex)\n", "mae_ex_refe_qm = mae(y_refe_qm_ex, y_true_ex)" ] }, { "cell_type": "code", "execution_count": null, "id": "a90ce1dc-cf94-4081-9add-1d26195f2302", "metadata": {}, "outputs": [], "source": [ "# root mean squared error in extreme quantile\n", "rmse_ex_refe = rmse(y_refe_ex, y_true_ex)\n", "rmse_ex_refe_qm = rmse(y_refe_qm_ex, y_true_ex)" ] }, { "cell_type": "code", "execution_count": null, "id": "d6efe5b9-3a6d-41ea-9f26-295b167cf0af", "metadata": {}, "outputs": [], "source": [ "# compute validation metrics for reference datasets\n", "filename = RESULTS.joinpath(PREDICTAND, 'reference.csv')\n", "if filename.exists():\n", " # check if validation metrics for reference already exist\n", " df_refe = pd.read_csv(filename)\n", "else:\n", " # compute validation metrics\n", " df_refe = pd.DataFrame([], columns=['r2_mm', 'r2_anom', 'bias', 'mae', 'rmse', 'bias_ex', 'mae_ex', 'rmse_ex', 'product'])\n", " for product, metrics in zip(['Era-5', 'Era-5 QM'],\n", " [[r2_refe_mm, r2_refe_anom, bias_refe, mae_refe, rmse_refe, bias_ex_refe, mae_ex_refe, rmse_ex_refe],\n", " [r2_refe_qm_mm, r2_refe_qm_anom, bias_refe_qm, mae_refe_qm, rmse_refe_qm, bias_ex_refe_qm, mae_ex_refe_qm,\n", " rmse_ex_refe_qm]]):\n", " df_refe = df_refe.append(pd.DataFrame([metrics + [product]], columns=df_refe.columns), ignore_index=True)\n", "\n", " # save metrics to disk\n", " df_refe.to_csv(filename, index=False)" ] }, { "cell_type": "markdown", "id": "258cb3c6-c2fc-457d-885e-28eaf48f1d5b", "metadata": { "tags": [] }, "source": [ "## R2, Bias, MAE, and RMSE for model predictions" ] }, { "cell_type": "markdown", "id": "630ce1c5-b018-437f-a7cf-8c8d99cd8f84", "metadata": {}, "source": [ "### Metrics for mean values" ] }, { "cell_type": "code", "execution_count": null, "id": "6980833a-3848-43ca-bcca-d759b4fd9f69", "metadata": {}, "outputs": [], "source": [ "# yearly average bias, mae, and rmse for each ensemble member\n", "y_pred_yearly_avg = {k: v.groupby('time.year').mean(dim='time') for k, v in ensemble.items()}\n", "bias_pred = {k: [bias(v[i], y_true_yearly_avg, relative=True if PREDICTAND == 'pr' else False) for i in range(len(ensemble[k]))] for k, v in y_pred_yearly_avg.items()}\n", "mae_pred = {k: [mae(v[i], y_true_yearly_avg) for i in range(len(ensemble[k]))] for k, v in y_pred_yearly_avg.items()}\n", "rmse_pred = {k: [rmse(v[i], y_true_yearly_avg) for i in range(len(ensemble[k]))] for k, v in y_pred_yearly_avg.items()}" ] }, { "cell_type": "markdown", "id": "122e84f8-211d-4816-9b03-2e1abc24eb9e", "metadata": {}, "source": [ "### Metrics for extreme values" ] }, { "cell_type": "code", "execution_count": null, "id": "8b48a065-1a0b-4457-a97a-642c26d56c51", "metadata": {}, "outputs": [], "source": [ "# calculate extreme quantile for each year\n", "with warnings.catch_warnings():\n", " warnings.simplefilter('ignore', category=RuntimeWarning)\n", " y_pred_ex = {k: v.chunk(dict(time=-1)).groupby('time.year').quantile(quantile, dim='time') for k, v in ensemble.items()}" ] }, { "cell_type": "code", "execution_count": null, "id": "2e6893da-271b-4b0e-bbc1-46b5eb9ecee3", "metadata": {}, "outputs": [], "source": [ "# yearly average bias, mae, and rmse for each ensemble member\n", "bias_pred_ex = {k: [bias(v[i], y_true_ex, relative=True if PREDICTAND == 'pr' else False) for i in range(len(ensemble[k]))] for k, v in y_pred_ex.items()}\n", "mae_pred_ex = {k: [mae(v[i], y_true_ex) for i in range(len(ensemble[k]))] for k, v in y_pred_ex.items()}\n", "rmse_pred_ex = {k: [rmse(v[i], y_true_ex) for i in range(len(ensemble[k]))] for k, v in y_pred_ex.items()}" ] }, { "cell_type": "code", "execution_count": null, "id": "64f7a0b9-a772-4a03-9160-7839a48e56cd", "metadata": { "tags": [] }, "outputs": [], "source": [ "# compute validation metrics for model predictions\n", "filename = (RESULTS.joinpath(PREDICTAND, 'prediction_pr-only.csv') if PREDICTAND == 'pr' and PR_ONLY else\n", " RESULTS.joinpath(PREDICTAND, 'prediction.csv'))\n", "if filename.exists():\n", " # check if validation metrics for predictions already exist\n", " df_pred = pd.read_csv(filename)\n", "else:\n", " # validation metrics for each ensemble member\n", " df_pred = pd.DataFrame([], columns=['r2_mm', 'r2_anom', 'bias', 'mae', 'rmse', 'bias_ex', 'mae_ex', 'rmse_ex', 'product', 'loss'])\n", " for k in y_pred_yearly_avg.keys():\n", " for i in range(len(ensemble[k])):\n", " # bias, mae, and rmse\n", " values = pd.DataFrame([[bias_pred[k][i], mae_pred[k][i], rmse_pred[k][i], bias_pred_ex[k][i],\n", " mae_pred_ex[k][i], rmse_pred_ex[k][i], 'Member-{:d}'.format(i), k]],\n", " columns=df_pred.columns[2:])\n", " \n", " # r2 scores\n", " values['r2_mm'], values['r2_anom'] = r2(ensemble[k][i], y_true, precipitation=True if PREDICTAND == 'pr' else False)\n", " df_pred = df_pred.append(values, ignore_index=True)\n", " \n", " # validation metrics for ensemble\n", " for k, v in ensemble_mean_full.items():\n", " # metrics for mean values\n", " means = v.groupby('time.year').mean(dim='time')\n", " bias_mean = bias(means, y_true_yearly_avg, relative=True if PREDICTAND == 'pr' else False)\n", " mae_mean = mae(means, y_true_yearly_avg)\n", " rmse_mean = rmse(means, y_true_yearly_avg)\n", " \n", " # metrics for extreme values\n", " with warnings.catch_warnings():\n", " warnings.simplefilter('ignore', category=RuntimeWarning)\n", " extremes = v.chunk(dict(time=-1)).groupby('time.year').quantile(quantile, dim='time')\n", " bias_ex = bias(extremes, y_true_ex, relative=True if PREDICTAND == 'pr' else False)\n", " mae_ex = mae(extremes, y_true_ex)\n", " rmse_ex = rmse(extremes, y_true_ex)\n", " \n", " # r2 scores\n", " r2_mm, r2_anom = r2(v, y_true, precipitation=True if PREDICTAND == 'pr' else False)\n", " df_pred = df_pred.append(pd.DataFrame([[r2_mm, r2_anom, bias_mean, mae_mean, rmse_mean, bias_ex, mae_ex, rmse_ex, 'Ensemble-{:d}'.format(len(ensemble[k])), k]],\n", " columns=df_pred.columns), ignore_index=True)\n", "\n", " # save metrics to disk\n", " df_pred.to_csv(filename, index=False)" ] }, { "cell_type": "markdown", "id": "da948e96-4a8c-4a56-9177-846851fe8ef8", "metadata": {}, "source": [ "### AUC and ROCSS for precipitation" ] }, { "cell_type": "code", "execution_count": null, "id": "0b7a824b-418a-4499-a3c0-627190e00941", "metadata": {}, "outputs": [], "source": [ "def auc_rocss(p_pred, y_true, wet_day_threshold=1):\n", " # true and predicted probability of precipitation\n", " p_true = (y_true >= float(wet_day_threshold)).values.flatten()\n", " p_pred = p_pred.values.flatten()\n", " \n", " # apply mask of valid pixels\n", " mask = (~np.isnan(p_true) & ~np.isnan(p_pred))\n", " p_pred = p_pred[mask]\n", " p_true = p_true[mask].astype(float)\n", " \n", " # calculate ROC: false positive rate vs. true positive rate\n", " fpr, tpr, _ = roc_curve(p_true, p_pred)\n", " area = auc(fpr, tpr) # area under ROC curve\n", " rocss = 2 * area - 1 # ROC skill score (cf. https://journals.ametsoc.org/view/journals/clim/16/24/1520-0442_2003_016_4145_otrsop_2.0.co_2.xml)\n", " \n", " return area, rocss" ] }, { "cell_type": "code", "execution_count": null, "id": "43138ade-148a-4d8d-be48-f1280d40e5b0", "metadata": {}, "outputs": [], "source": [ "if PREDICTAND == 'pr':\n", " # precipitation threshold to consider as wet day\n", " WET_DAY_THRESHOLD = 1\n", " \n", " # ensemble prediction for precipitation probability\n", " ensemble_prob = xr.Dataset({'Member-{}'.format(i): member for i, member in\n", " enumerate(y_prob['BernoulliGammaLoss'])}).to_array('members')\n", " ensemble_mean_prob = ensemble_prob.mean(dim='members')\n", " \n", " # filename for probability metrics\n", " filename = (RESULTS.joinpath(PREDICTAND, 'probability_pr-only.csv') if PREDICTAND == 'pr' and PR_ONLY else\n", " RESULTS.joinpath(PREDICTAND, 'probability.csv'))\n", " if filename.exists():\n", " # check if validation metrics for probabilities already exist\n", " df_prob = pd.read_csv(filename)\n", " else:\n", " # AUC and ROCSS for each ensemble member\n", " df_prob = pd.DataFrame([], columns=['auc', 'rocss', 'product', 'loss'])\n", " for i in range(len(ensemble_prob)):\n", " auc_score, rocss = auc_rocss(ensemble_prob[i], y_true, wet_day_threshold=WET_DAY_THRESHOLD)\n", " df_prob = df_prob.append(pd.DataFrame([[auc_score, rocss, ensemble_prob[i].members.item(), 'BernoulliGammaLoss']],\n", " columns=df_prob.columns), ignore_index=True)\n", "\n", " # AUC and ROCSS for ensemble mean\n", " auc_score, rocss = auc_rocss(ensemble_mean_prob, y_true, wet_day_threshold=WET_DAY_THRESHOLD)\n", " df_prob = df_prob.append(pd.DataFrame([[auc_score, rocss, 'Ensemble-{:d}'.format(len(ensemble_prob)), 'BernoulliGammaLoss']],\n", " columns=df_prob.columns), ignore_index=True)\n", "\n", " # save metrics to disk\n", " df_prob.to_csv(filename, index=False)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" } }, "nbformat": 4, "nbformat_minor": 5 }