diff --git a/Notebooks/eval_bootstrap.ipynb b/Notebooks/eval_bootstrap.ipynb
index bf829d047184ffb5d16ed812743135d7c405b371..46067ee95337db14480ec08509a93019c3dd6ff7 100644
--- a/Notebooks/eval_bootstrap.ipynb
+++ b/Notebooks/eval_bootstrap.ipynb
@@ -33,6 +33,7 @@
    "import xarray as xr\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
+   "from matplotlib import gridspec\n",
    "import seaborn as sns\n",
    "\n",
    "# locals\n",
@@ -70,19 +71,29 @@
   "tags": []
  },
  "source": [
-   "## Search model configuration"
+   "## Search model configurations"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
-  "id": "3e856f80-14fd-405f-a44e-cc77863f8e5b",
+  "id": "3b83c9f3-7081-4cec-8f23-c4de007839d7",
  "metadata": {},
  "outputs": [],
  "source": [
   "# predictand to evaluate\n",
-   "PREDICTAND = 'tasmin'\n",
-   "LOSS = 'L1Loss'\n",
+   "PREDICTAND = 'tasmin'"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3e856f80-14fd-405f-a44e-cc77863f8e5b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "# loss function and optimizer\n",
+  "LOSS = ['L1Loss', 'MSELoss', 'BernoulliGammaLoss'] if PREDICTAND == 'pr' else ['L1Loss', 'MSELoss']\n",
   "OPTIM = 'Adam'"
  ]
 },
@@ -94,7 +105,8 @@
  "outputs": [],
  "source": [
   "# model to evaluate\n",
-   "model = 'USegNet_{}_ztuvq_500_850_p_dem_doy_{}_{}'.format(PREDICTAND, LOSS, OPTIM)"
+   "models = ['USegNet_{}_ztuvq_500_850_p_dem_doy_1mm_{}_{}'.format(PREDICTAND, loss, OPTIM) if loss == 'BernoulliGammaLoss' else\n",
+   "          'USegNet_{}_ztuvq_500_850_p_dem_doy_{}_{}'.format(PREDICTAND, loss, OPTIM) for loss in LOSS]"
  ]
 },
@@ -105,8 +117,8 @@
  "outputs": [],
  "source": [
   "# get bootstrapped models\n",
-   "models = sorted(search_files(RESULTS.joinpath(PREDICTAND), model + '(.*).nc$'),\n",
-   "                key=lambda x: int(x.stem.split('_')[-1]))\n",
+   "models = {loss: sorted(search_files(RESULTS.joinpath(PREDICTAND), model + '(.*).nc$'),\n",
+   "                       key=lambda x: int(x.stem.split('_')[-1])) for loss, model in zip(LOSS, models)}\n",
   "models"
  ]
 },
@@ -257,38 +269,40 @@
 {
  "cell_type": "code",
  "execution_count": null,
-  "id": "ccaf0118-da51-43b0-a2b6-56ba4b252999",
+  "id": "eb889059-17e4-4d8c-b796-e8b1e2d0bf8c",
  "metadata": {},
  "outputs": [],
  "source": [
-   "y_pred = [xr.open_dataset(sim, chunks={'time': 365}) for sim in models]\n",
+   "y_pred_raw = {k: [xr.open_dataset(v, chunks={'time': 365}) for v in models[k]] for k in models.keys()}\n",
   "if PREDICTAND == 'pr':\n",
-   "    y_pred = [y_p.rename({NAMES[PREDICTAND]: PREDICTAND}) for y_p in y_pred]\n",
-   "    y_pred = [y_p.transpose('time', 'y', 'x') for y_p in y_pred]"
+   "    y_pred_raw = {k: [v.rename({NAMES[PREDICTAND]: PREDICTAND}) for v in y_pred_raw[k]] for k in y_pred_raw.keys()}\n",
+   "    y_pred_raw = {k: [v.transpose('time', 'y', 'x') for v in y_pred_raw[k]] for k in y_pred_raw.keys()}"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
-  "id": "df3f018e-4723-4878-b72a-0586b68e6dfd",
+  "id": "534e020d-96b2-403c-b8e4-86de98fbbe3b",
  "metadata": {},
  "outputs": [],
  "source": [
   "# align datasets and mask missing values\n",
-   "y_prob = []\n",
-   "for i, y_p in enumerate(y_pred):\n",
-   "    \n",
-   "    # check whether evaluating precipitation or temperatures\n",
-   "    if len(y_p.data_vars) > 1:\n",
-   "        _, _, y_p, y_p_prob = xr.align(y_true, y_refe, y_p[PREDICTAND], y_p.prob, join='override')\n",
-   "        y_p_prob = y_p_prob.where(~missing, other=np.nan)  # mask missing values\n",
-   "        y_prob.append(y_p_prob)\n",
-   "    else:\n",
-   "        _, _, y_p = xr.align(y_true, y_refe, y_p[PREDICTAND], join='override')\n",
-   "    \n",
-   "    # mask missing values\n",
-   "    y_p = y_p.where(~missing, other=np.nan)\n",
-   "    y_pred[i] = y_p"
+   "y_prob = {}\n",
+   "y_pred = {}\n",
+   "for loss, sims in y_pred_raw.items():\n",
+   "    y_pred[loss], y_prob[loss] = [], []\n",
+   "    for y_p in sims:\n",
+   "        # check whether evaluating precipitation or temperatures\n",
+   "        if len(y_p.data_vars) > 1:\n",
+   "            _, _, y_p, y_p_prob = xr.align(y_true, y_refe, y_p[PREDICTAND], y_p.prob, join='override')\n",
+   "            y_p_prob = y_p_prob.where(~missing, other=np.nan)  # mask missing values\n",
+   "            y_prob[loss].append(y_p_prob)\n",
+   "        else:\n",
+   "            _, _, y_p = xr.align(y_true, y_refe, y_p[PREDICTAND], join='override')\n",
+   "\n",
+   "        # mask missing values\n",
+   "        y_p = y_p.where(~missing, other=np.nan)\n",
+   "        y_pred[loss].append(y_p)"
  ]
 },
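+{
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-cell-1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "# reviewer sketch (hypothetical cell, not in the original notebook): after\n",
+  "# alignment, all members of a loss function should share the same grid\n",
+  "for loss, sims in y_pred.items():\n",
+  "    if sims:\n",
+  "        assert all(s.shape == sims[0].shape for s in sims), loss"
+ ]
+},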
 {
@@ -309,7 +323,8 @@
  "outputs": [],
  "source": [
   "# create ensemble dataset\n",
-   "ensemble = xr.Dataset({'member_{}'.format(i): member for i, member in enumerate(y_pred)}).to_array('members')"
+   "ensemble = {k: xr.Dataset({'Member-{}'.format(i): member for i, member in enumerate(y_pred[k])}).to_array('members')\n",
+   "            for k in y_pred.keys() if y_pred[k]}"
  ]
 },
 {
@@ -320,8 +335,8 @@
  "outputs": [],
  "source": [
   "# full ensemble mean prediction and standard deviation\n",
-   "ensemble_mean_full = ensemble.mean(dim='members')\n",
-   "ensemble_std_full = ensemble.std(dim='members')"
+   "ensemble_mean_full = {k: v.mean(dim='members') for k, v in ensemble.items()}\n",
+   "ensemble_std_full = {k: v.std(dim='members') for k, v in ensemble.items()}"
  ]
 },
 {
@@ -332,9 +347,9 @@
  "outputs": [],
  "source": [
   "# ensemble mean prediction using three random members\n",
-   "ensemble_3 = np.random.randint(0, len(ensemble.members), size=3)\n",
-   "ensemble_mean_3 = ensemble[ensemble_3, ...].mean(dim='members')\n",
-   "ensemble_std_3 = ensemble[ensemble_3, ...].std(dim='members')"
+   "ensemble_3 = np.random.choice(len(ensemble['L1Loss'].members), size=3, replace=False)\n",
+   "ensemble_mean_3 = {k: v[ensemble_3, ...].mean(dim='members') for k, v in ensemble.items()}\n",
+   "ensemble_std_3 = {k: v[ensemble_3, ...].std(dim='members') for k, v in ensemble.items()}"
  ]
 },
 {
@@ -345,9 +360,9 @@
  "outputs": [],
  "source": [
   "# ensemble mean prediction using five random members\n",
-   "ensemble_5 = np.random.randint(0, len(ensemble.members), size=5)\n",
-   "ensemble_mean_5 = ensemble[ensemble_5, ...].mean(dim='members')\n",
-   "ensemble_std_5 = ensemble[ensemble_5, ...].std(dim='members')"
+   "ensemble_5 = np.random.choice(len(ensemble['L1Loss'].members), size=5, replace=False)\n",
+   "ensemble_mean_5 = {k: v[ensemble_5, ...].mean(dim='members') for k, v in ensemble.items()}\n",
+   "ensemble_std_5 = {k: v[ensemble_5, ...].std(dim='members') for k, v in ensemble.items()}"
  ]
 },
 {
@@ -417,17 +432,6 @@
   "rmse_refe_qm = (y_refe_qm_yearly_avg - y_true_yearly_avg) ** 2"
  ]
 },
-{
- "cell_type": "code",
- "execution_count": null,
- "id": "5b49ff5b-b4f5-48f9-8cd7-49bb0f2af7da",
- "metadata": {},
- "outputs": [],
- "source": [
-  "# create dataframe for mean bias, mae, and rmse\n",
-  "df_ref = pd.DataFrame([], columns=['bias', 'mae', 'rmse', 'product'])"
- ]
-},
 {
  "cell_type": "code",
  "execution_count": null,
@@ -435,11 +439,21 @@
  "metadata": {},
  "outputs": [],
  "source": [
-   "# absolute values for the reference datasets\n",
-   "for product, metrics in zip(['Era-5', 'Era-5 QM'], [[bias_refe, mae_refe, rmse_refe], [bias_refe_qm, mae_refe_qm, rmse_refe_qm]]):\n",
-   "    values = pd.DataFrame([[np.sqrt(m.mean().values.item()) if name == 'rmse' else m.mean().values.item() for name, m in zip(['bias', 'mae', 'rmse'], metrics)] + [product]],\n",
-   "                          columns=df_ref.columns)\n",
-   "    df_ref = df_ref.append(values, ignore_index=True)"
+   "# compute validation metrics for reference datasets\n",
+   "filename = RESULTS.joinpath(PREDICTAND, 'reference.csv')\n",
+   "if filename.exists():\n",
+   "    # validation metrics for the reference data already exist\n",
+   "    df_ref = pd.read_csv(filename)\n",
+   "else:\n",
+   "    # compute validation metrics\n",
+   "    df_ref = pd.DataFrame([], columns=['bias', 'mae', 'rmse', 'product'])\n",
+   "    for product, metrics in zip(['Era-5', 'Era-5 QM'], [[bias_refe, mae_refe, rmse_refe], [bias_refe_qm, mae_refe_qm, rmse_refe_qm]]):\n",
+   "        values = pd.DataFrame([[np.sqrt(m.mean().values.item()) if name == 'rmse' else m.mean().values.item() for\n",
+   "                                name, m in zip(['bias', 'mae', 'rmse'], metrics)] + [product]], columns=df_ref.columns)\n",
+   "        df_ref = pd.concat([df_ref, values], ignore_index=True)\n",
+   "\n",
+   "    # save metrics to disk\n",
+   "    df_ref.to_csv(filename, index=False)"
  ]
 },
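+{
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-cell-2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "# reviewer sketch (hypothetical cell): inspect the cached reference metrics;\n",
+  "# delete reference.csv to force recomputation in the cell above\n",
+  "print(filename)\n",
+  "df_ref"
+ ]
+},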
 {
@@ -454,10 +468,10 @@
 },
 {
  "cell_type": "markdown",
-  "id": "3f8d75ad-1da2-4f80-aef9-996c0463d1a2",
+  "id": "630ce1c5-b018-437f-a7cf-8c8d99cd8f84",
  "metadata": {},
  "source": [
-   "### Absolute values for each ensemble member"
+   "Calculate the yearly average bias, MAE, and RMSE over the entire reference period for the model predictions."
  ]
 },
 {
@@ -468,10 +482,10 @@
  "outputs": [],
  "source": [
   "# yearly average bias, mae, and rmse for each ensemble member\n",
-   "y_pred_yearly_avg = ensemble.groupby('time.year').mean(dim='time')\n",
-   "bias_pred = y_pred_yearly_avg - y_true_yearly_avg\n",
-   "mae_pred = np.abs(y_pred_yearly_avg - y_true_yearly_avg)\n",
-   "rmse_pred = (y_pred_yearly_avg - y_true_yearly_avg) ** 2"
+   "y_pred_yearly_avg = {k: v.groupby('time.year').mean(dim='time') for k, v in ensemble.items()}\n",
+   "bias_pred = {k: v - y_true_yearly_avg for k, v in y_pred_yearly_avg.items()}\n",
+   "mae_pred = {k: np.abs(v - y_true_yearly_avg) for k, v in y_pred_yearly_avg.items()}\n",
+   "rmse_pred = {k: (v - y_true_yearly_avg) ** 2 for k, v in y_pred_yearly_avg.items()}"
  ]
 },
 {
@@ -483,99 +497,213 @@
  },
  "outputs": [],
  "source": [
-   "# absolute values for each ensemble member\n",
-   "df_pred = pd.DataFrame([], columns=['bias', 'mae', 'rmse', 'product'])\n",
-   "for i in range(len(bias_pred)):\n",
-   "    values = pd.DataFrame([[np.sqrt(m.mean().values.item()) if name == 'rmse' else m.mean().values.item()\n",
-   "                            for name, m in zip(['bias', 'mae', 'rmse'], [bias_pred[i], mae_pred[i], rmse_pred[i]])] + [bias_pred[i].members.item()]],\n",
-   "                          columns=df_pred.columns)\n",
-   "    df_pred = df_pred.append(values, ignore_index=True)"
+   "# compute validation metrics for model predictions\n",
+   "filename = RESULTS.joinpath(PREDICTAND, 'prediction.csv')\n",
+   "if filename.exists():\n",
+   "    # validation metrics for the predictions already exist\n",
+   "    df_pred = pd.read_csv(filename)\n",
+   "else:\n",
+   "    # validation metrics for each ensemble member\n",
+   "    df_pred = pd.DataFrame([], columns=['bias', 'mae', 'rmse', 'product', 'loss'])\n",
+   "    for k in y_pred_yearly_avg.keys():\n",
+   "        for i in range(len(bias_pred[k])):\n",
+   "            values = pd.DataFrame([[np.sqrt(m.mean().values.item()) if name == 'rmse' else m.mean().values.item()\n",
+   "                                    for name, m in zip(['bias', 'mae', 'rmse'], [bias_pred[k][i], mae_pred[k][i], rmse_pred[k][i]])] +\n",
+   "                                   [bias_pred[k][i].members.item()] + [k]],\n",
+   "                                  columns=df_pred.columns)\n",
+   "            df_pred = pd.concat([df_pred, values], ignore_index=True)\n",
+   "\n",
+   "    # validation metrics for ensembles\n",
+   "    for name, ens in zip(['Ensemble-3', 'Ensemble-5', 'Ensemble-{:d}'.format(len(ensemble['L1Loss']))],\n",
+   "                         [ensemble_mean_3, ensemble_mean_5, ensemble_mean_full]):\n",
+   "        for k, v in ens.items():\n",
+   "            yearly_avg = v.groupby('time.year').mean(dim='time')\n",
+   "            bias = (yearly_avg - y_true_yearly_avg).mean().values.item()\n",
+   "            mae = np.abs(yearly_avg - y_true_yearly_avg).mean().values.item()\n",
+   "            rmse = np.sqrt(((yearly_avg - y_true_yearly_avg) ** 2).mean().values.item())\n",
+   "            values = pd.DataFrame([[bias, mae, rmse, name, k]], columns=df_pred.columns)\n",
+   "            df_pred = pd.concat([df_pred, values], ignore_index=True)\n",
+   "\n",
+   "    # save metrics to disk\n",
+   "    df_pred.to_csv(filename, index=False)"
  ]
 },
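+{
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-cell-3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "# reviewer sketch (hypothetical cell): average metrics per loss function,\n",
+  "# a quick numeric comparison before the plots below\n",
+  "df_pred.groupby('loss')[['bias', 'mae', 'rmse']].mean()"
+ ]
+},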
 {
  "cell_type": "markdown",
-  "id": "a1fdb3bd-84cd-4ebf-80f4-a0275e372315",
+  "id": "902e299c-a927-41b1-b2ae-987c30dee8cf",
  "metadata": {},
  "source": [
-   "### Absolute values for ensemble predictions"
+   "## Plot results"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
-  "id": "5a5755de-c000-43aa-9d16-637f021691ae",
+  "id": "bdca9b54-3e05-49c8-b1b2-b8c782017306",
  "metadata": {},
  "outputs": [],
  "source": [
-   "for name, ens in zip(['Ensemble-3', 'Ensemble-5', 'Ensemble-{:d}'.format(len(ensemble))], [ensemble_mean_3, ensemble_mean_5, ensemble_mean_full]):\n",
-   "    yearly_avg = ens.groupby('time.year').mean(dim='time')\n",
-   "    bias = (yearly_avg - y_true_yearly_avg).mean().values.item()\n",
-   "    mae = np.abs(yearly_avg - y_true_yearly_avg).mean().values.item()\n",
-   "    rmse = np.sqrt(((yearly_avg - y_true_yearly_avg) ** 2).mean().values.item())\n",
-   "    values = pd.DataFrame([[bias, mae, rmse, name]], columns=df_pred.columns)\n",
-   "    df_pred = df_pred.append(values, ignore_index=True)"
+   "# create a sequential colormap: one color per loss function\n",
+   "# palette = sns.color_palette('YlOrRd_r', 10) + sns.color_palette('Greens', 3)\n",
+   "palette = sns.color_palette('Blues', len(LOSS))"
  ]
 },
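+{
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-cell-4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "# reviewer sketch (hypothetical cell): make sure the output directory used by\n",
+  "# fig.savefig in the cells below exists\n",
+  "from pathlib import Path\n",
+  "Path('./Figures').mkdir(exist_ok=True)"
+ ]
+},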
ensemble" ] }, { - "cell_type": "markdown", - "id": "3cfcd2de-cd37-42d5-b53d-e8abfd21e242", + "cell_type": "code", + "execution_count": null, + "id": "48751f7f-9c26-471d-a75e-b7bb2fcb71be", "metadata": {}, + "outputs": [], "source": [ - "### Absolute values" + "# dataframe of single members and ensembles only\n", + "members = df_pred[~np.isin(df_pred['product'], ['Ensemble-{}'.format(i) for i in [3, 5, 10]])]\n", + "ensemble = df_pred[~np.isin(df_pred['product'], ['Member-{}'.format(i) for i in range(10)])]" ] }, { "cell_type": "code", "execution_count": null, - "id": "bdca9b54-3e05-49c8-b1b2-b8c782017306", + "id": "d7c8e987-0257-4263-ac4b-718a614c458f", "metadata": {}, "outputs": [], "source": [ - "# create a sequential colormap: for reference data, single ensemble members, and ensemble mean predictions\n", - "palette = sns.color_palette('YlOrRd_r', 10) + sns.color_palette('Greens', 3)" + "# initialize figure\n", + "fig = plt.figure(figsize=(16, 5))\n", + "\n", + "# create grid for different subplots\n", + "grid = gridspec.GridSpec(ncols=5, nrows=1, width_ratios=[3, 1, 1, 3, 1], wspace=0.05, hspace=0)\n", + "\n", + "# add subplots\n", + "ax1 = fig.add_subplot(grid[0])\n", + "ax2 = fig.add_subplot(grid[1], sharey=ax1)\n", + "ax3 = fig.add_subplot(grid[3])\n", + "ax4 = fig.add_subplot(grid[4], sharey=ax3)\n", + "axes = [ax1, ax2, ax3, ax4]\n", + "\n", + "# plot bias: single members vs. ensemble\n", + "sns.barplot(x='product', y='bias', hue='loss', data=members, palette=palette, ax=ax1);\n", + "sns.barplot(x='product', y='bias', hue='loss', data=ensemble, palette=palette, ax=ax2);\n", + "\n", + "# plot mae: single members vs. ensemble\n", + "sns.barplot(x='product', y='mae', hue='loss', data=members, palette=palette, ax=ax3);\n", + "sns.barplot(x='product', y='mae', hue='loss', data=ensemble, palette=palette, ax=ax4);\n", + "\n", + "# axes limits and ticks\n", + "y_lim_bias = (-50, 50) if PREDICTAND == 'pr' else (-1, 1)\n", + "y_lim_mae = (0, 2) if PREDICTAND == 'pr' else (0, 1)\n", + "y_ticks_bias = (np.arange(y_lim_bias[0], y_lim_bias[1] + 10, 10) if PREDICTAND == 'pr' else\n", + " np.arange(y_lim_bias[0], y_lim_bias[1] + 0.2, 0.2))\n", + "y_ticks_mae = (np.arange(y_lim_mae[0], y_lim_mae[1] + 10, 10) if PREDICTAND == 'pr' else\n", + " np.arange(y_lim_mae[0], y_lim_mae[1] + 0.2, 0.2))\n", + "\n", + "# axis for bias\n", + "ax1.set_ylabel('Bias (%)' if PREDICTAND == 'pr' else 'Bias (°C)')\n", + "ax1.set_ylim(y_lim_bias)\n", + "ax1.set_yticks(y_ticks_bias)\n", + "\n", + "# axis for mae\n", + "ax3.set_ylabel('Mean absolute error (mm)' if PREDICTAND == 'pr' else 'Mean absolute error (°C)')\n", + "ax3.set_ylim(y_lim_mae)\n", + "ax3.set_yticks(y_ticks_mae)\n", + "\n", + "# adjust axis for ensemble predictions\n", + "for ax in [ax2, ax4]:\n", + " ax.yaxis.tick_right()\n", + " ax.set_ylabel('')\n", + "\n", + "# axis fontsize and legend\n", + "for ax in axes:\n", + " ax.tick_params('both', labelsize=14)\n", + " ax.set_xticklabels(ax.get_xticklabels(), rotation=90)\n", + " ax.yaxis.label.set_size(14)\n", + " ax.set_xlabel('')\n", + " \n", + " # adjust legend\n", + " h, _ = ax.get_legend_handles_labels()\n", + " ax.get_legend().remove()\n", + "\n", + "# show single legend\n", + "ax4.legend(bbox_to_anchor=(1.3, 1.05), loc=2, frameon=False, fontsize=14);\n", + "\n", + "# save figure\n", + "fig.savefig('./Figures/{}_members_vs_ensemble.pdf'.format(PREDICTAND), bbox_inches='tight')" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "32f1a652-a4a2-4738-ba6f-93fe3ac43658", + "cell_type": 
"markdown", + "id": "590ffbaf-0e8d-4b63-9264-ad86078d50c9", "metadata": {}, - "outputs": [], "source": [ - "# absolute values for metrics for both reference and model predictions\n", - "df = pd.concat([df_ref, df_pred], ignore_index=True)\n", - "df" + "### Absolute values: ensemble vs. reference" ] }, { "cell_type": "code", "execution_count": null, - "id": "d7c8e987-0257-4263-ac4b-718a614c458f", + "id": "b1a9b1b7-9cd7-4998-afbb-11e64e91b333", "metadata": {}, "outputs": [], "source": [ "# initialize figure\n", - "fig, axes = plt.subplots(1, 2, figsize=(16, 5))\n", - "sns.barplot(x='product', y='bias', data=df_pred, palette=palette, ax=axes[0])\n", - "sns.barplot(x='product', y='mae', data=df_pred, palette=palette, ax=axes[1])\n", - "# sns.barplot(x='product', y='rmse', data=df, palette=palette, ax=axes[2])\n", + "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n", "\n", - "# axes limits and ticks\n", - "axes[0].set_ylim(-0.5, 0.5)\n", - "axes[1].set_ylim(0, 1)\n", - "# axes[2].set_ylim(0, 1)\n", + "# plot bias: ensemble predictions vs. reference\n", + "sns.barplot(x='product', y='bias', hue='loss', data=ensemble, palette=palette, ax=axes[0]);\n", + "\n", + "# plot mae: ensemble predictions vs. reference\n", + "sns.barplot(x='product', y='mae', hue='loss', data=ensemble, palette=palette, ax=axes[1]);\n", + "\n", + "# plot rmse: ensemble predictions vs. reference\n", + "sns.barplot(x='product', y='rmse', hue='loss', data=ensemble, palette=palette, ax=axes[2]);\n", + "\n", + "# add metrics for reference\n", + "for ax, metric in zip(axes, ['bias', 'mae', 'rmse']):\n", + " for product, ls in zip(df_ref['product'], ['-', '--']):\n", + " ax.hlines(df_ref[metric][df_ref['product'] == product].item(), xmin=-0.5, xmax=2.5,\n", + " color='k', ls=ls, label=product)\n", + "\n", + "# axis for bias\n", + "axes[0].set_ylabel('Bias (%)' if PREDICTAND == 'pr' else 'Bias (°C)')\n", + "axes[0].set_ylim(y_lim_bias)\n", + "axes[0].set_yticks(y_ticks_bias)\n", + "\n", + "# axis for mae\n", + "axes[1].set_ylabel('Mean absolute error (mm)' if PREDICTAND == 'pr' else 'Mean absolute error (°C)')\n", + "axes[1].set_ylim(y_lim_mae)\n", + "axes[1].set_yticks(y_ticks_mae)\n", + "\n", + "# axis for rmse\n", + "axes[2].set_ylabel('RMSE (mm)' if PREDICTAND == 'pr' else 'RMSE (°C)')\n", + "axes[2].set_ylim(y_lim_mae)\n", + "axes[2].set_yticks(y_ticks_mae)\n", + "\n", + "# axis fontsize and legend\n", "for ax in axes:\n", " ax.tick_params('both', labelsize=14)\n", " ax.set_xticklabels(ax.get_xticklabels(), rotation=90)\n", - " ax.yaxis.label.set_size(14)" + " ax.yaxis.label.set_size(14)\n", + " ax.set_xlabel('')\n", + " \n", + " # adjust legend\n", + " h, _ = ax.get_legend_handles_labels()\n", + " ax.get_legend().remove()\n", + "\n", + "# show single legend\n", + "axes[-1].legend(bbox_to_anchor=(1.05, 1.05), loc=2, frameon=False, fontsize=14);\n", + "\n", + "# save figure\n", + "fig.subplots_adjust(wspace=0.25)\n", + "fig.savefig('./Figures/{}_ensemble_vs_reference.pdf'.format(PREDICTAND), bbox_inches='tight')" ] }, {