From 52ab32d630b29f1bb098efa20a64bcc3c2d80336 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Fri, 26 Nov 2021 14:45:42 +0000
Subject: [PATCH] Improved plots.

---
 Notebooks/eval_bootstrap.ipynb | 328 +++++++++++++++++++++++----------
 1 file changed, 228 insertions(+), 100 deletions(-)

diff --git a/Notebooks/eval_bootstrap.ipynb b/Notebooks/eval_bootstrap.ipynb
index bf829d0..46067ee 100644
--- a/Notebooks/eval_bootstrap.ipynb
+++ b/Notebooks/eval_bootstrap.ipynb
@@ -33,6 +33,7 @@
     "import xarray as xr\n",
     "import pandas as pd\n",
     "import matplotlib.pyplot as plt\n",
+    "from matplotlib import gridspec\n",
     "import seaborn as sns\n",
     "\n",
     "# locals\n",
@@ -70,19 +71,29 @@
     "tags": []
    },
    "source": [
-    "## Search model configuration"
+    "## Search model configurations"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "3e856f80-14fd-405f-a44e-cc77863f8e5b",
+   "id": "3b83c9f3-7081-4cec-8f23-c4de007839d7",
    "metadata": {},
    "outputs": [],
    "source": [
     "# predictand to evaluate\n",
-    "PREDICTAND = 'tasmin'\n",
-    "LOSS = 'L1Loss'\n",
+    "PREDICTAND = 'tasmin'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3e856f80-14fd-405f-a44e-cc77863f8e5b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# loss function and optimizer\n",
+    "LOSS = ['L1Loss', 'MSELoss', 'BernoulliGammaLoss'] if PREDICTAND == 'pr' else ['L1Loss', 'MSELoss']\n",
     "OPTIM = 'Adam'"
    ]
   },
@@ -94,7 +105,8 @@
    "outputs": [],
    "source": [
     "# model to evaluate\n",
-    "model = 'USegNet_{}_ztuvq_500_850_p_dem_doy_{}_{}'.format(PREDICTAND, LOSS, OPTIM)"
+    "models = ['USegNet_{}_ztuvq_500_850_p_dem_doy_1mm_{}_{}'.format(PREDICTAND, loss, OPTIM) if loss == 'BernoulliGammaLoss' else\n",
+    "          'USegNet_{}_ztuvq_500_850_p_dem_doy_{}_{}'.format(PREDICTAND, loss, OPTIM) for loss in LOSS]"
    ]
   },
   {
@@ -105,8 +117,8 @@
    "outputs": [],
    "source": [
     "# get bootstrapped models\n",
-    "models = sorted(search_files(RESULTS.joinpath(PREDICTAND), model + '(.*).nc$'),\n",
-    "                key=lambda x: int(x.stem.split('_')[-1]))\n",
+    "models = {loss: sorted(search_files(RESULTS.joinpath(PREDICTAND), model + '(.*).nc$'),\n",
+    "                       key=lambda x: int(x.stem.split('_')[-1])) for loss, model in zip(LOSS, models)}\n",
     "models"
    ]
   },
@@ -257,38 +269,40 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ccaf0118-da51-43b0-a2b6-56ba4b252999",
+   "id": "eb889059-17e4-4d8c-b796-e8b1e2d0bf8c",
    "metadata": {},
    "outputs": [],
    "source": [
-    "y_pred = [xr.open_dataset(sim, chunks={'time': 365}) for sim in models]\n",
+    "y_pred_raw = {k: [xr.open_dataset(v, chunks={'time': 365}) for v in models[k]] for k in models.keys()}\n",
     "if PREDICTAND == 'pr':\n",
-    "    y_pred = [y_p.rename({NAMES[PREDICTAND]: PREDICTAND}) for y_p in y_pred]\n",
-    "    y_pred = [y_p.transpose('time', 'y', 'x') for y_p in y_pred]"
+    "    y_pred_raw = {k: [v.rename({NAMES[PREDICTAND]: PREDICTAND}) for v in y_pred_raw[k]] for k in y_pred_raw.keys()}\n",
+    "    y_pred_raw = {k: [v.transpose('time', 'y', 'x') for v in y_pred_raw[k]] for k in y_pred_raw.keys()}"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "df3f018e-4723-4878-b72a-0586b68e6dfd",
+   "id": "534e020d-96b2-403c-b8e4-86de98fbbe3b",
    "metadata": {},
    "outputs": [],
    "source": [
     "# align datasets and mask missing values\n",
-    "y_prob = []\n",
-    "for i, y_p in enumerate(y_pred):\n",
-    "    \n",
-    "    # check whether evaluating precipitation or temperatures\n",
-    "    if len(y_p.data_vars) > 1:\n",
-    "        _, _, y_p, y_p_prob = xr.align(y_true, y_refe, y_p[PREDICTAND], y_p.prob, join='override')\n",
-    "        y_p_prob = y_p_prob.where(~missing, other=np.nan)  # mask missing values\n",
-    "        y_prob.append(y_p_prob)\n",
-    "    else:\n",
-    "        _, _, y_p = xr.align(y_true, y_refe, y_p[PREDICTAND], join='override')\n",
-    "    \n",
-    "    # mask missing values\n",
-    "    y_p = y_p.where(~missing, other=np.nan)\n",
-    "    y_pred[i] = y_p"
+    "y_prob = {}\n",
+    "y_pred = {}\n",
+    "for loss, models in y_pred_raw.items():\n",
+    "    y_pred[loss], y_prob[loss] = [], []\n",
+    "    for y_p in models:\n",
+    "        # check whether evaluating precipitation or temperatures\n",
+    "        if len(y_p.data_vars) > 1:\n",
+    "            _, _, y_p, y_p_prob = xr.align(y_true, y_refe, y_p[PREDICTAND], y_p.prob, join='override')\n",
+    "            y_p_prob = y_p_prob.where(~missing, other=np.nan)  # mask missing values\n",
+    "            y_prob[loss].append(y_p_prob)\n",
+    "        else:\n",
+    "            _, _, y_p = xr.align(y_true, y_refe, y_p[PREDICTAND], join='override')\n",
+    "\n",
+    "        # mask missing values\n",
+    "        y_p = y_p.where(~missing, other=np.nan)\n",
+    "        y_pred[loss].append(y_p)"
    ]
   },
   {
@@ -309,7 +323,8 @@
    "outputs": [],
    "source": [
     "# create ensemble dataset\n",
-    "ensemble = xr.Dataset({'member_{}'.format(i): member for i, member in enumerate(y_pred)}).to_array('members')"
+    "ensemble = {k: xr.Dataset({'Member-{}'.format(i): member for i, member in enumerate(y_pred[k])}).to_array('members')\n",
+    "            for k in y_pred.keys() if y_pred[k]}"
    ]
   },
   {
@@ -320,8 +335,8 @@
    "outputs": [],
    "source": [
     "# full ensemble mean prediction and standard deviation\n",
-    "ensemble_mean_full = ensemble.mean(dim='members')\n",
-    "ensemble_std_full = ensemble.std(dim='members')"
+    "ensemble_mean_full = {k: v.mean(dim='members') for k, v in ensemble.items()}\n",
+    "ensemble_std_full = {k: v.std(dim='members') for k, v in ensemble.items()}"
    ]
   },
   {
@@ -332,9 +347,9 @@
    "outputs": [],
    "source": [
     "# ensemble mean prediction using three random members\n",
-    "ensemble_3 = np.random.randint(0, len(ensemble.members), size=3)\n",
-    "ensemble_mean_3 = ensemble[ensemble_3, ...].mean(dim='members')\n",
-    "ensemble_std_3 = ensemble[ensemble_3, ...].std(dim='members')"
+    "ensemble_3 = np.random.randint(0, len(ensemble['L1Loss'].members), size=3)\n",
+    "ensemble_mean_3 = {k: v[ensemble_3, ...].mean(dim='members') for k, v in ensemble.items()}\n",
+    "ensemble_std_3 = {k: v[ensemble_3, ...].std(dim='members') for k, v in ensemble.items()}"
    ]
   },
   {
@@ -345,9 +360,9 @@
    "outputs": [],
    "source": [
     "# ensemble mean prediction using five random members\n",
-    "ensemble_5 = np.random.randint(0, len(ensemble.members), size=5)\n",
-    "ensemble_mean_5 = ensemble[ensemble_5, ...].mean(dim='members')\n",
-    "ensemble_std_5 = ensemble[ensemble_5, ...].std(dim='members')"
+    "ensemble_5 = np.random.randint(0, len(ensemble['L1Loss'].members), size=5)\n",
+    "ensemble_mean_5 = {k: v[ensemble_5, ...].mean(dim='members') for k, v in ensemble.items()}\n",
+    "ensemble_std_5 = {k: v[ensemble_5, ...].std(dim='members') for k, v in ensemble.items()}"
    ]
   },
   {
@@ -417,17 +432,6 @@
     "rmse_refe_qm = (y_refe_qm_yearly_avg - y_true_yearly_avg) ** 2"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "5b49ff5b-b4f5-48f9-8cd7-49bb0f2af7da",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# create dataframe for mean bias, mae, and rmse\n",
-    "df_ref = pd.DataFrame([], columns=['bias', 'mae', 'rmse', 'product'])"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -435,11 +439,21 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# absolute values for the reference datasets\n",
-    "for product, metrics in zip(['Era-5', 'Era-5 QM'], [[bias_refe, mae_refe, rmse_refe], [bias_refe_qm, mae_refe_qm, rmse_refe_qm]]):\n",
-    "    values = pd.DataFrame([[np.sqrt(m.mean().values.item()) if name == 'rmse' else m.mean().values.item() for name, m in zip(['bias', 'mae', 'rmse'], metrics)] + [product]],\n",
-    "                          columns=df_ref.columns)\n",
-    "    df_ref = df_ref.append(values, ignore_index=True)"
+    "# compute validation metrics for reference datasets\n",
+    "filename = RESULTS.joinpath(PREDICTAND, 'reference.csv')\n",
+    "if filename.exists():\n",
+    "    # check if validation metrics for reference already exist\n",
+    "    df_ref = pd.read_csv(filename)\n",
+    "else:\n",
+    "    # compute validation metrics\n",
+    "    df_ref = pd.DataFrame([], columns=['bias', 'mae', 'rmse', 'product'])\n",
+    "    for product, metrics in zip(['Era-5', 'Era-5 QM'], [[bias_refe, mae_refe, rmse_refe], [bias_refe_qm, mae_refe_qm, rmse_refe_qm]]):\n",
+    "        values = pd.DataFrame([[np.sqrt(m.mean().values.item()) if name == 'rmse' else m.mean().values.item() for\n",
+    "                                name, m in zip(['bias', 'mae', 'rmse'], metrics)] + [product]], columns=df_ref.columns)\n",
+    "        df_ref = df_ref.append(values, ignore_index=True)\n",
+    "    \n",
+    "    # save metrics to disk\n",
+    "    df_ref.to_csv(filename, index=False)"
    ]
   },
   {
@@ -454,10 +468,10 @@
   },
   {
    "cell_type": "markdown",
-   "id": "3f8d75ad-1da2-4f80-aef9-996c0463d1a2",
+   "id": "630ce1c5-b018-437f-a7cf-8c8d99cd8f84",
    "metadata": {},
    "source": [
-    "### Absolute values for each ensemble member"
+    "Calculate yearly average bias, MAE, and RMSE over entire reference period for model predictions."
    ]
   },
   {
@@ -468,10 +482,10 @@
    "outputs": [],
    "source": [
     "# yearly average bias, mae, and rmse for each ensemble member\n",
-    "y_pred_yearly_avg = ensemble.groupby('time.year').mean(dim='time')\n",
-    "bias_pred = y_pred_yearly_avg - y_true_yearly_avg\n",
-    "mae_pred = np.abs(y_pred_yearly_avg - y_true_yearly_avg)\n",
-    "rmse_pred = (y_pred_yearly_avg - y_true_yearly_avg) ** 2"
+    "y_pred_yearly_avg = {k: v.groupby('time.year').mean(dim='time') for k, v in ensemble.items()}\n",
+    "bias_pred = {k: v - y_true_yearly_avg for k, v in y_pred_yearly_avg.items()}\n",
+    "mae_pred = {k: np.abs(v - y_true_yearly_avg) for k, v in y_pred_yearly_avg.items()}\n",
+    "rmse_pred = {k: (v - y_true_yearly_avg) ** 2 for k, v in y_pred_yearly_avg.items()}"
    ]
   },
   {
@@ -483,99 +497,213 @@
    },
    "outputs": [],
    "source": [
-    "# absolute values for each ensemble member\n",
-    "df_pred = pd.DataFrame([], columns=['bias', 'mae', 'rmse', 'product'])\n",
-    "for i in range(len(bias_pred)):\n",
-    "    values = pd.DataFrame([[np.sqrt(m.mean().values.item()) if name == 'rmse' else m.mean().values.item()\n",
-    "                            for name, m in zip(['bias', 'mae', 'rmse'], [bias_pred[i], mae_pred[i], rmse_pred[i]])] + [bias_pred[i].members.item()]],\n",
-    "                          columns=df_pred.columns)\n",
-    "    df_pred = df_pred.append(values, ignore_index=True)"
+    "# compute validation metrics for model predictions\n",
+    "filename = RESULTS.joinpath(PREDICTAND, 'prediction.csv')\n",
+    "if filename.exists():\n",
+    "    # check if validation metrics for predictions already exist\n",
+    "    df_pred = pd.read_csv(filename)\n",
+    "else:\n",
+    "    # validation metrics for each ensemble member\n",
+    "    df_pred = pd.DataFrame([], columns=['bias', 'mae', 'rmse', 'product', 'loss'])\n",
+    "    for k in y_pred_yearly_avg.keys():\n",
+    "        for i in range(len(bias_pred[k])):\n",
+    "            values = pd.DataFrame([[np.sqrt(m.mean().values.item()) if name == 'rmse' else m.mean().values.item()\n",
+    "                                    for name, m in zip(['bias', 'mae', 'rmse'], [bias_pred[k][i], mae_pred[k][i], rmse_pred[k][i]])] +\n",
+    "                                   [bias_pred[k][i].members.item()] + [k]],\n",
+    "                                  columns=df_pred.columns)\n",
+    "            df_pred = df_pred.append(values, ignore_index=True)\n",
+    "        \n",
+    "    # validation metrics for ensembles\n",
+    "    for name, ens in zip(['Ensemble-3', 'Ensemble-5', 'Ensemble-{:d}'.format(len(ensemble['L1Loss']))],\n",
+    "                         [ensemble_mean_3, ensemble_mean_5, ensemble_mean_full]):\n",
+    "        for k, v in ens.items():\n",
+    "            yearly_avg = v.groupby('time.year').mean(dim='time')\n",
+    "            bias = (yearly_avg - y_true_yearly_avg).mean().values.item()\n",
+    "            mae = np.abs(yearly_avg - y_true_yearly_avg).mean().values.item()\n",
+    "            rmse = np.sqrt(((yearly_avg - y_true_yearly_avg) ** 2).mean().values.item())\n",
+    "            values = pd.DataFrame([[bias, mae, rmse, name, k]], columns=df_pred.columns)\n",
+    "            df_pred = df_pred.append(values, ignore_index=True)\n",
+    "    \n",
+    "    # save metrics to disk\n",
+    "    df_pred.to_csv(filename, index=False)"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "a1fdb3bd-84cd-4ebf-80f4-a0275e372315",
+   "id": "902e299c-a927-41b1-b2ae-987c30dee8cf",
    "metadata": {},
    "source": [
-    "### Absolute values for ensemble predictions"
+    "## Plot results"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "5a5755de-c000-43aa-9d16-637f021691ae",
+   "id": "bdca9b54-3e05-49c8-b1b2-b8c782017306",
    "metadata": {},
    "outputs": [],
    "source": [
-    "for name, ens in zip(['Ensemble-3', 'Ensemble-5', 'Ensemble-{:d}'.format(len(ensemble))], [ensemble_mean_3, ensemble_mean_5, ensemble_mean_full]):\n",
-    "    yearly_avg = ens.groupby('time.year').mean(dim='time')\n",
-    "    bias = (yearly_avg - y_true_yearly_avg).mean().values.item()\n",
-    "    mae = np.abs(yearly_avg - y_true_yearly_avg).mean().values.item()\n",
-    "    rmse = np.sqrt(((yearly_avg - y_true_yearly_avg) ** 2).mean().values.item())\n",
-    "    values = pd.DataFrame([[bias, mae, rmse, name]], columns=df_pred.columns)\n",
-    "    df_pred = df_pred.append(values, ignore_index=True)"
+    "# create a sequential colormap: for reference data, single ensemble members, and ensemble mean predictions\n",
+    "# palette = sns.color_palette('YlOrRd_r', 10) + sns.color_palette('Greens', 3)\n",
+    "palette = sns.color_palette('Blues', len(LOSS))"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "902e299c-a927-41b1-b2ae-987c30dee8cf",
-   "metadata": {},
+   "id": "3cfcd2de-cd37-42d5-b53d-e8abfd21e242",
+   "metadata": {
+    "tags": []
+   },
    "source": [
-    "## Plot results"
+    "### Absolute values: single members vs. ensemble"
    ]
   },
   {
-   "cell_type": "markdown",
-   "id": "3cfcd2de-cd37-42d5-b53d-e8abfd21e242",
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "48751f7f-9c26-471d-a75e-b7bb2fcb71be",
    "metadata": {},
+   "outputs": [],
    "source": [
-    "### Absolute values"
+    "# dataframe of single members and ensembles only\n",
+    "members = df_pred[~np.isin(df_pred['product'], ['Ensemble-{}'.format(i) for i in [3, 5, 10]])]\n",
+    "ensemble = df_pred[~np.isin(df_pred['product'], ['Member-{}'.format(i) for i in range(10)])]"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "bdca9b54-3e05-49c8-b1b2-b8c782017306",
+   "id": "d7c8e987-0257-4263-ac4b-718a614c458f",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# create a sequential colormap: for reference data, single ensemble members, and ensemble mean predictions\n",
-    "palette = sns.color_palette('YlOrRd_r', 10) + sns.color_palette('Greens', 3)"
+    "# initialize figure\n",
+    "fig = plt.figure(figsize=(16, 5))\n",
+    "\n",
+    "# create grid for different subplots\n",
+    "grid = gridspec.GridSpec(ncols=5, nrows=1, width_ratios=[3, 1, 1, 3, 1], wspace=0.05, hspace=0)\n",
+    "\n",
+    "# add subplots\n",
+    "ax1 = fig.add_subplot(grid[0])\n",
+    "ax2 = fig.add_subplot(grid[1], sharey=ax1)\n",
+    "ax3 = fig.add_subplot(grid[3])\n",
+    "ax4 = fig.add_subplot(grid[4], sharey=ax3)\n",
+    "axes = [ax1, ax2, ax3, ax4]\n",
+    "\n",
+    "# plot bias: single members vs. ensemble\n",
+    "sns.barplot(x='product', y='bias', hue='loss', data=members, palette=palette, ax=ax1);\n",
+    "sns.barplot(x='product', y='bias', hue='loss', data=ensemble, palette=palette, ax=ax2);\n",
+    "\n",
+    "# plot mae: single members vs. ensemble\n",
+    "sns.barplot(x='product', y='mae', hue='loss', data=members, palette=palette, ax=ax3);\n",
+    "sns.barplot(x='product', y='mae', hue='loss', data=ensemble, palette=palette, ax=ax4);\n",
+    "\n",
+    "# axes limits and ticks\n",
+    "y_lim_bias = (-50, 50) if PREDICTAND == 'pr' else (-1, 1)\n",
+    "y_lim_mae = (0, 2) if PREDICTAND == 'pr' else (0, 1)\n",
+    "y_ticks_bias = (np.arange(y_lim_bias[0], y_lim_bias[1] + 10, 10) if PREDICTAND == 'pr' else\n",
+    "                np.arange(y_lim_bias[0], y_lim_bias[1] + 0.2, 0.2))\n",
+    "y_ticks_mae = (np.arange(y_lim_mae[0], y_lim_mae[1] + 10, 10) if PREDICTAND == 'pr' else\n",
+    "               np.arange(y_lim_mae[0], y_lim_mae[1] + 0.2, 0.2))\n",
+    "\n",
+    "# axis for bias\n",
+    "ax1.set_ylabel('Bias (%)' if PREDICTAND == 'pr' else 'Bias (°C)')\n",
+    "ax1.set_ylim(y_lim_bias)\n",
+    "ax1.set_yticks(y_ticks_bias)\n",
+    "\n",
+    "# axis for mae\n",
+    "ax3.set_ylabel('Mean absolute error (mm)' if PREDICTAND == 'pr' else 'Mean absolute error (°C)')\n",
+    "ax3.set_ylim(y_lim_mae)\n",
+    "ax3.set_yticks(y_ticks_mae)\n",
+    "\n",
+    "# adjust axis for ensemble predictions\n",
+    "for ax in [ax2, ax4]:\n",
+    "    ax.yaxis.tick_right()\n",
+    "    ax.set_ylabel('')\n",
+    "\n",
+    "# axis fontsize and legend\n",
+    "for ax in axes:\n",
+    "    ax.tick_params('both', labelsize=14)\n",
+    "    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)\n",
+    "    ax.yaxis.label.set_size(14)\n",
+    "    ax.set_xlabel('')\n",
+    "    \n",
+    "    # adjust legend\n",
+    "    h, _ = ax.get_legend_handles_labels()\n",
+    "    ax.get_legend().remove()\n",
+    "\n",
+    "# show single legend\n",
+    "ax4.legend(bbox_to_anchor=(1.3, 1.05), loc=2, frameon=False, fontsize=14);\n",
+    "\n",
+    "# save figure\n",
+    "fig.savefig('./Figures/{}_members_vs_ensemble.pdf'.format(PREDICTAND), bbox_inches='tight')"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "32f1a652-a4a2-4738-ba6f-93fe3ac43658",
+   "cell_type": "markdown",
+   "id": "590ffbaf-0e8d-4b63-9264-ad86078d50c9",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "# absolute values for metrics for both reference and model predictions\n",
-    "df = pd.concat([df_ref, df_pred], ignore_index=True)\n",
-    "df"
+    "### Absolute values: ensemble vs. reference"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "d7c8e987-0257-4263-ac4b-718a614c458f",
+   "id": "b1a9b1b7-9cd7-4998-afbb-11e64e91b333",
    "metadata": {},
    "outputs": [],
    "source": [
     "# initialize figure\n",
-    "fig, axes = plt.subplots(1, 2, figsize=(16, 5))\n",
-    "sns.barplot(x='product', y='bias', data=df_pred, palette=palette, ax=axes[0])\n",
-    "sns.barplot(x='product', y='mae', data=df_pred, palette=palette, ax=axes[1])\n",
-    "# sns.barplot(x='product', y='rmse', data=df, palette=palette, ax=axes[2])\n",
+    "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
     "\n",
-    "# axes limits and ticks\n",
-    "axes[0].set_ylim(-0.5, 0.5)\n",
-    "axes[1].set_ylim(0, 1)\n",
-    "# axes[2].set_ylim(0, 1)\n",
+    "# plot bias: ensemble predictions vs. reference\n",
+    "sns.barplot(x='product', y='bias', hue='loss', data=ensemble, palette=palette, ax=axes[0]);\n",
+    "\n",
+    "# plot mae: ensemble predictions vs. reference\n",
+    "sns.barplot(x='product', y='mae', hue='loss', data=ensemble, palette=palette, ax=axes[1]);\n",
+    "\n",
+    "# plot rmse: ensemble predictions vs. reference\n",
+    "sns.barplot(x='product', y='rmse', hue='loss', data=ensemble, palette=palette, ax=axes[2]);\n",
+    "\n",
+    "# add metrics for reference\n",
+    "for ax, metric in zip(axes, ['bias', 'mae', 'rmse']):\n",
+    "    for product, ls in zip(df_ref['product'], ['-', '--']):\n",
+    "        ax.hlines(df_ref[metric][df_ref['product'] == product].item(), xmin=-0.5, xmax=2.5,\n",
+    "                  color='k', ls=ls, label=product)\n",
+    "\n",
+    "# axis for bias\n",
+    "axes[0].set_ylabel('Bias (%)' if PREDICTAND == 'pr' else 'Bias (°C)')\n",
+    "axes[0].set_ylim(y_lim_bias)\n",
+    "axes[0].set_yticks(y_ticks_bias)\n",
+    "\n",
+    "# axis for mae\n",
+    "axes[1].set_ylabel('Mean absolute error (mm)' if PREDICTAND == 'pr' else 'Mean absolute error (°C)')\n",
+    "axes[1].set_ylim(y_lim_mae)\n",
+    "axes[1].set_yticks(y_ticks_mae)\n",
+    "\n",
+    "# axis for rmse\n",
+    "axes[2].set_ylabel('RMSE (mm)' if PREDICTAND == 'pr' else 'RMSE (°C)')\n",
+    "axes[2].set_ylim(y_lim_mae)\n",
+    "axes[2].set_yticks(y_ticks_mae)\n",
+    "\n",
+    "# axis fontsize and legend\n",
     "for ax in axes:\n",
     "    ax.tick_params('both', labelsize=14)\n",
     "    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)\n",
-    "    ax.yaxis.label.set_size(14)"
+    "    ax.yaxis.label.set_size(14)\n",
+    "    ax.set_xlabel('')\n",
+    "    \n",
+    "    # adjust legend\n",
+    "    h, _ = ax.get_legend_handles_labels()\n",
+    "    ax.get_legend().remove()\n",
+    "\n",
+    "# show single legend\n",
+    "axes[-1].legend(bbox_to_anchor=(1.05, 1.05), loc=2, frameon=False, fontsize=14);\n",
+    "\n",
+    "# save figure\n",
+    "fig.subplots_adjust(wspace=0.25)\n",
+    "fig.savefig('./Figures/{}_ensemble_vs_reference.pdf'.format(PREDICTAND), bbox_inches='tight')"
    ]
   },
   {
-- 
GitLab