diff --git a/Notebooks/eval_precipitation.ipynb b/Notebooks/eval_precipitation.ipynb index 51d33ebbd60e806fe8898d5dfd8a3967df2b00a8..b1ea5e217ce0f05df12783555122bab52e44ddc2 100644 --- a/Notebooks/eval_precipitation.ipynb +++ b/Notebooks/eval_precipitation.ipynb @@ -16,6 +16,26 @@ "We used **1981-1991 as training** period and **1991-2010 as reference** period. The results shown in this notebook are based on the model predictions on the reference period." ] }, + { + "cell_type": "markdown", + "id": "ad1769d4-9c0c-4e3f-9adf-ef02bb43c047", + "metadata": {}, + "source": [ + "**Predictors on pressure levels (500, 850)**:\n", + "- Geopotential (z)\n", + "- Temperature (t)\n", + "- Zonal wind (u)\n", + "- Meridional wind (v)\n", + "- Specific humidity (q)\n", + "\n", + "**Predictors on surface**:\n", + "- Mean sea level pressure (msl)\n", + "\n", + "**Auxiliary predictors**:\n", + "- Elevation from Copernicus EU-DEM v1.1 (dem)\n", + "- Day of the year (doy)" + ] + }, { "cell_type": "markdown", "id": "f9334da7-17d1-45ef-9ed9-5c2bee9fcdcc", @@ -38,6 +58,7 @@ "PLEVELS = ['500', '850']\n", "SPREDICTORS = 'p'\n", "DEM = 'dem'\n", + "DEM_FEATURES = ''\n", "DOY = 'doy'" ] }, @@ -104,6 +125,54 @@ "Image(\"./Figures/architecture.png\", width=900, height=400)" ] }, + { + "cell_type": "markdown", + "id": "c8833efe-c715-4872-9aee-a0b5766f5c67", + "metadata": {}, + "source": [ + "### Loss function" + ] + }, + { + "cell_type": "markdown", + "id": "bb801367-6872-4a0b-bff5-70cb6746e057", + "metadata": {}, + "source": [ + "For precipitation, the network is optimizing the negative log-likelihood of a Bernoulli-Gamma distribution after [Cannon (2008)](http://journals.ametsoc.org/doi/10.1175/2008JHM960.1)." + ] + }, + { + "cell_type": "markdown", + "id": "a8775d5e-5ad4-47e4-8fef-6dc230e15dee", + "metadata": {}, + "source": [ + "Bernoulli-Gamma distribution:" + ] + }, + { + "cell_type": "markdown", + "id": "ab10f8de-d8d2-4427-b9c8-5d68803543c3", + "metadata": {}, + "source": [ + "$$P(y \\mid, p, \\alpha, \\beta) = \\begin{cases} 1 - p, & \\text{for } y = 0\\\\ p \\cdot \\frac{y^{\\alpha -1} \\exp(-y/\\beta)}{\\beta^{\\alpha} \\tau(\\alpha)}, & \\text{for } y > 0\\end{cases}$$" + ] + }, + { + "cell_type": "markdown", + "id": "6b6dbd06-1f0e-4c52-84f2-b7ff31c75726", + "metadata": {}, + "source": [ + "Log-likelihood function:" + ] + }, + { + "cell_type": "markdown", + "id": "e41c7b39-f352-4a98-820f-9a7345b3283c", + "metadata": {}, + "source": [ + "$$\\mathcal{J}(p, \\alpha, \\beta \\mid y) = \\underbrace{(1 - P(y > 0)) \\log(1 - p)}_{\\text{Bernoulli}} + \\underbrace{P(y > 0) \\cdot \\left(\\log(p) + (\\alpha - 1) \\log(y) - \\frac{y}{\\beta} - \\alpha \\log(\\beta) - \\log(\\tau(\\alpha))\\right)}_{\\text{Gamma}}$$" + ] + }, { "cell_type": "markdown", "id": "5a0c55f0-79fb-4501-b3cf-b5414399a3d9", @@ -115,12 +184,29 @@ { "cell_type": "code", "execution_count": null, - "id": "a5db133d-2c36-4e84-879e-20e617e821f1", + "id": "efa76f8e-c089-47ff-a001-d4c2a11c4d6d", "metadata": {}, "outputs": [], "source": [ - "# model predictions and observations NetCDF\n", - "y_pred = TARGET_PATH.joinpath(PREDICTAND, '_'.join([MODEL, PREDICTAND, PPREDICTORS, *PLEVELS, SPREDICTORS, DEM, DOY]) + '.nc')\n", + "# construct file pattern to match\n", + "PATTERN = '_'.join([MODEL, PREDICTAND, PPREDICTORS, *PLEVELS])\n", + "PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5db133d-2c36-4e84-879e-20e617e821f1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# model predictions and observations NetCDF \n", + "y_pred = xr.open_dataset(search_files(TARGET_PATH.joinpath(PREDICTAND), '.'.join([PATTERN, 'nc$'])).pop())\n", "y_true = xr.open_dataset(search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop())" ] }, @@ -131,9 +217,8 @@ "metadata": {}, "outputs": [], "source": [ - "# load datasets\n", - "y_pred = xr.open_dataset(y_pred)\n", - "y_true = y_true.sel(time=y_pred.time) # subset to time period covered by predictions" + "# subset to time period covered by predictions\n", + "y_true = y_true.sel(time=y_pred.time) " ] }, { @@ -418,7 +503,7 @@ "# print average bias per season\n", "for var in bias_snl.data_vars:\n", " for season in bias_snl[NAMES[PREDICTAND]].season:\n", - " print('Average bias of {} for season {}: {:.1f}%'.format(var, season.values.item(), bias_snl[var].sel(season=season).mean().item()))" + " print('Average bias of mean {} for season {}: {:.1f}%'.format(var, season.values.item(), bias_snl[var].sel(season=season).mean().item()))" ] }, { @@ -433,7 +518,9 @@ "cell_type": "code", "execution_count": null, "id": "2b2390e4-8bf1-41bd-b7d1-33e92fe8bc65", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "# plot seasonal differences\n", @@ -450,7 +537,7 @@ "# plot seasonal average bias\n", "for ax, season in zip(axes[1:], seasons):\n", " ds = bias_snl[NAMES[PREDICTAND]].sel(season=season)\n", - " ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-50, vmax=50)\n", + " ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n", " ax.set_title(season, fontsize=16);\n", " ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')\n", "\n", @@ -616,12 +703,109 @@ "fig.savefig('../Notebooks/Figures/{}_average_bias_p{:.0f}.png'.format(PREDICTAND, quantile * 100), dpi=300, bbox_inches='tight')" ] }, + { + "cell_type": "markdown", + "id": "f7758b71-e844-4a47-b741-25cc2d277814", + "metadata": {}, + "source": [ + "### Bias of extremes: winter vs. summer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de16b327-dc4b-4645-b3a0-abc8f72c4225", + "metadata": {}, + "outputs": [], + "source": [ + "# group data by season and compute extreme percentile\n", + "with warnings.catch_warnings():\n", + " warnings.simplefilter('ignore', category=RuntimeWarning)\n", + " y_true_ex_snl = y_true.groupby('time.season').quantile(quantile, dim='time')\n", + " y_pred_ex_snl = y_pred_pr.groupby('time.season').quantile(quantile, dim='time')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "185a375c-b7bc-42cf-a57d-08268e824f21", + "metadata": {}, + "outputs": [], + "source": [ + "# compute relative bias in seasonal extremes\n", + "bias_ex_snl = ((y_pred_ex_snl - y_true_ex_snl) / y_true_ex_snl) * 100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5f82b16-7ebf-4b73-80dd-0bc7e6a305a1", + "metadata": {}, + "outputs": [], + "source": [ + "# print average bias in extreme per season\n", + "for var in bias_ex_snl.data_vars:\n", + " for season in bias_ex_snl[NAMES[PREDICTAND]].season:\n", + " print('Average bias of P{:.0f} {} for season {}: {:.1f}%'.format(quantile * 100, var, season.values.item(), bias_ex_snl[var].sel(season=season).mean().item()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "322ee9c4-92db-4041-8a36-45aa8d2021b5", + "metadata": {}, + "outputs": [], + "source": [ + "# plot seasonal differences\n", + "seasons = ('DJF', 'JJA')\n", + "fig, axes = plt.subplots(nrows=1, ncols=len(seasons) + 1, figsize=(24,8), sharex=True, sharey=True)\n", + "axes = axes.flatten()\n", + "\n", + "# plot annual average bias of extreme\n", + "ds = bias_ex[NAMES[PREDICTAND]].mean(dim='year')\n", + "axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n", + "axes[0].set_title('Annual', fontsize=16);\n", + "axes[0].text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')\n", + "\n", + "# plot seasonal average bias of extreme\n", + "for ax, season in zip(axes[1:], seasons):\n", + " ds = bias_ex_snl[NAMES[PREDICTAND]].sel(season=season)\n", + " ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n", + " ax.set_title(season, fontsize=16);\n", + " ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')\n", + "\n", + "# adjust axes\n", + "for ax in axes.flat:\n", + " ax.axes.get_xaxis().set_ticklabels([])\n", + " ax.axes.get_xaxis().set_ticks([])\n", + " ax.axes.get_yaxis().set_ticklabels([])\n", + " ax.axes.get_yaxis().set_ticks([])\n", + " ax.axes.axis('tight')\n", + " ax.set_xlabel('')\n", + " ax.set_ylabel('')\n", + "\n", + "# adjust figure\n", + "fig.suptitle('Average bias of P{:.0f} of {}: 1991 - 2010'.format(quantile * 100, NAMES[PREDICTAND]), fontsize=20);\n", + "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)\n", + "\n", + "# add colorbar for bias\n", + "axes = axes.flatten()\n", + "cbar_ax = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n", + " 0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])\n", + "cbar = fig.colorbar(im2, cax=cbar_ax)\n", + "cbar.set_label(label='Relative bias / (%)', fontsize=16)\n", + "cbar.ax.tick_params(labelsize=14)\n", + "\n", + "# save figure\n", + "fig.savefig('../Notebooks/Figures/{}_average_bias_seasonal_ex.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" + ] + }, { "cell_type": "markdown", "id": "1f1a76c5-0bcb-4f79-832c-caf958e5703a", "metadata": {}, "source": [ - "### Number of wet vs. dry days" + "### Frequency of wet days" ] }, { @@ -631,7 +815,7 @@ "metadata": {}, "outputs": [], "source": [ - "# minimum precipitation amount defining a wet day\n", + "# minimum precipitation (mm / day) defining a wet day\n", "WET_DAY_THRESHOLD = 1" ] }, @@ -643,8 +827,45 @@ "outputs": [], "source": [ "# true and predicted frequency of wet days\n", - "p_true = (y_true >= WET_DAY_THRESHOLD).groupby('time.year').mean(dim='time').astype(np.float32)\n", - "p_pred = (y_pred_pr >= WET_DAY_THRESHOLD).groupby('time.year').mean(dim='time').astype(np.float32)" + "mask = (~np.isnan(y_true)) & (~np.isnan(y_pred_pr))\n", + "wet_days_true = (y_true >= WET_DAY_THRESHOLD).where(mask, other=np.nan).astype(np.float32)\n", + "wet_days_pred = (y_pred_pr >= WET_DAY_THRESHOLD).where(mask, other=np.nan).astype(np.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ad9ad08-ad1f-4a76-9020-c09a442385c9", + "metadata": {}, + "outputs": [], + "source": [ + "# number of wet days in reference period: annual\n", + "n_wet_days_true = wet_days_true.sum(dim='time', skipna=False)\n", + "n_wet_days_pred = wet_days_pred.sum(dim='time', skipna=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54088572-1873-4949-9937-0ea099e9c2b6", + "metadata": {}, + "outputs": [], + "source": [ + "# frequency of wet days in reference period: annual\n", + "f_wet_days_true = (n_wet_days_true / len(wet_days_true.time)) * 100\n", + "f_wet_days_pred = (n_wet_days_pred / len(wet_days_pred.time)) * 100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3534d1d9-f50d-40f3-b1f0-d34c4152a37c", + "metadata": {}, + "outputs": [], + "source": [ + "# frequency of wet days in reference period: seasonal\n", + "f_wet_days_true_snl = wet_days_true.groupby('time.season').mean(dim='time', skipna=False)\n", + "f_wet_days_pred_snl = wet_days_pred.groupby('time.season').mean(dim='time', skipna=False)" ] }, { @@ -654,8 +875,11 @@ "metadata": {}, "outputs": [], "source": [ - "# bias of wet vs. dry days\n", - "bias_wet = p_pred - p_true" + "# relative bias of frequency of wet vs. dry days: annual\n", + "bias_wet = ((f_wet_days_pred - f_wet_days_true) / f_wet_days_true) * 100\n", + "\n", + "# relative bias of frequency of wet vs. dry days: seasonal\n", + "bias_wet_snl = ((f_wet_days_pred_snl - f_wet_days_true_snl) / f_wet_days_true_snl) * 100" ] }, { @@ -666,16 +890,97 @@ "outputs": [], "source": [ "# plot average of observation, prediction, and bias\n", - "vmin, vmax = 0, 5\n", + "fig, axes = plt.subplots(2, 3, figsize=(24, 12), sharex=True, sharey=True)\n", + "axes = axes.flatten()\n", + "\n", + "# plot annual average bias of extreme\n", + "ds = bias_wet[NAMES[PREDICTAND]]\n", + "im = axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n", + "axes[0].set_title('Annual', fontsize=16);\n", + "axes[0].text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')\n", + "\n", + "# plot seasonal average bias of extreme\n", + "for ax, season in zip(axes[1:], bias_wet_snl.season):\n", + " ds = bias_wet_snl[NAMES[PREDICTAND]].sel(season=season)\n", + " ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n", + " ax.set_title(season.item(), fontsize=16);\n", + " ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')\n", + "\n", + "# adjust axes\n", + "for ax in axes.flat:\n", + " ax.axes.get_xaxis().set_ticklabels([])\n", + " ax.axes.get_xaxis().set_ticks([])\n", + " ax.axes.get_yaxis().set_ticklabels([])\n", + " ax.axes.get_yaxis().set_ticks([])\n", + " ax.axes.axis('tight')\n", + " ax.set_xlabel('')\n", + " ax.set_ylabel('')\n", + " \n", + "# turn off last axis\n", + "axes[-1].set_visible(False)\n", + "\n", + "# adjust figure\n", + "fig.suptitle('Frequency of wet days (>= {:.1f} mm): 1991 - 2010'.format(WET_DAY_THRESHOLD), fontsize=20);\n", + "fig.subplots_adjust(hspace=0.1, wspace=0, top=0.925)\n", + "\n", + "# add colorbar\n", + "cbar_ax_predictand = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n", + " 0.01, axes[0].get_position().y1 - axes[-1].get_position().y0])\n", + "cbar_predictand = fig.colorbar(im, cax=cbar_ax_predictand)\n", + "cbar_predictand.set_label(label='Relative bias / (%)', fontsize=16)\n", + "cbar_predictand.ax.tick_params(labelsize=14)\n", + "\n", + "# save figure\n", + "fig.savefig('../Notebooks/Figures/{}_bias_wet_days.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "id": "fde67674-8bed-45b2-b4ca-4a507084b5a8", + "metadata": {}, + "source": [ + "### Mean wet day precipitation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa2d2b5a-e192-4c34-975c-d4e61c969a15", + "metadata": {}, + "outputs": [], + "source": [ + "# calculate mean wet day precipitation\n", + "dii_true = (y_true * wet_days_true).sum(dim='time', skipna=False) / n_wet_days_true\n", + "dii_pred = (y_pred_pr * wet_days_pred).sum(dim='time', skipna=False) / n_wet_days_pred" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6331bb6-afe8-4e62-9fa8-60bce9af42ec", + "metadata": {}, + "outputs": [], + "source": [ + "# calculate relative bias of mean wet day precipitation\n", + "bias_dii = ((dii_pred - dii_true) / dii_true) * 100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36e3cbac-233d-4892-9756-941d91981534", + "metadata": {}, + "outputs": [], + "source": [ + "# plot average of observation, prediction, and bias\n", "fig, axes = plt.subplots(1, 3, figsize=(24, 6), sharex=True, sharey=True)\n", - "for i, var in enumerate(p_true):\n", - " for ds, ax in zip([p_true, p_pred, bias_wet], axes):\n", - " if ds is bias_wet:\n", - " ds = ds[var].mean(dim='year')\n", - " im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-1, vmax=1)\n", - " ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')\n", + "for i, var in enumerate(dii_true):\n", + " for ds, ax in zip([dii_true, dii_pred, bias_dii], axes):\n", + " if ds is bias_dii:\n", + " im2 = ax.imshow(ds[var].values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n", + " ax.text(x=ds[var].shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds[var].mean().item()), fontsize=14, ha='right')\n", " else:\n", - " im1 = ax.imshow(ds[var].mean(dim='year').values, origin='lower', cmap='BuPu', vmin=vmin, vmax=vmax)\n", + " im1 = ax.imshow(ds[var].values, origin='lower', cmap='BuPu', vmin=0, vmax=15)\n", " \n", "# set titles\n", "axes[0].set_title('Observed', fontsize=16, pad=10);\n", @@ -693,7 +998,7 @@ " ax.set_ylabel('')\n", "\n", "# adjust figure\n", - "fig.suptitle('Average {}: 1991 - 2010'.format(NAMES[PREDICTAND]), fontsize=20);\n", + "fig.suptitle('Mean wet day (>= {:.1f} mm) precipitation: 1991 - 2010'.format(WET_DAY_THRESHOLD), fontsize=20);\n", "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)\n", "\n", "# add colorbar for bias\n", @@ -709,39 +1014,11 @@ " axes[-1].get_position().x0 - axes[0].get_position().x0,\n", " 0.05])\n", "cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal')\n", - "cbar_predictand.set_label(label='{} / '.format(NAMES[PREDICTAND].capitalize()) + '(mm day$^{-1}$)', fontsize=16)\n", + "cbar_predictand.set_label(label='Mean wet day precipitation / (mm day$^{-1}$)', fontsize=16)\n", "cbar_predictand.ax.tick_params(labelsize=14)\n", "\n", - "# add metrics: MAE and RMSE\n", - "axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.1f}'.format(mae_avg[NAMES[PREDICTAND]].item()) + 'mm day$^{-1}$', fontsize=14, ha='right')\n", - "axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.1f}'.format(rmse_avg[NAMES[PREDICTAND]].item()) + 'mm$^2$ day$^{-2}$', fontsize=14, ha='right')\n", - "\n", "# save figure\n", - "fig.savefig('../Notebooks/Figures/{}_average_bias_wet_days.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fbd79a2c-d88e-4ca8-8dad-80488eccf7cb", - "metadata": {}, - "outputs": [], - "source": [ - "# apply mask of valid pixels\n", - "mask = (~np.isnan(p_true) & ~np.isnan(p_pred))\n", - "p_pred = p_pred[mask]\n", - "p_true = p_true[mask]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fbf52265-126b-49b3-a83c-419cfaaac5f0", - "metadata": {}, - "outputs": [], - "source": [ - "# compute classification report\n", - "report = classification_report(p_true, p_pred, target_names=['Dry days', 'Wet days'], output_dict=True)" + "fig.savefig('../Notebooks/Figures/{}_bias_wet_days_p.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" ] }, { @@ -812,7 +1089,7 @@ "# plot classifier with no skill\n", "interval = np.arange(-0.05, 1.1, 0.05)\n", "ax.plot([0, 1], [0, 1], lw=2, linestyle='--', color='k')\n", - "ax.text(0.95, 0.975, 'Random Classifier', ha='right', va='top', rotation=45)\n", + "ax.text(0.95, 0.975, 'Random Classifier', ha='right', va='top', rotation=45, fontsize=12)\n", "\n", "# plot perfect classifier\n", "ax.plot(0, 1, '-o', markersize=5, markerfacecolor='k', markeredgecolor='none')\n", diff --git a/Notebooks/eval_temperature.ipynb b/Notebooks/eval_temperature.ipynb index ec47bf27ca6e35eed18e79976fe60a51e314258a..1cf81452ea1033837a45a51e786d686b2d368081 100644 --- a/Notebooks/eval_temperature.ipynb +++ b/Notebooks/eval_temperature.ipynb @@ -16,6 +16,26 @@ "We used **1981-1991 as training** period and **1991-2010 as reference** period. The results shown in this notebook are based on the model predictions on the reference period." ] }, + { + "cell_type": "markdown", + "id": "3b06b8f0-090e-4da6-87bf-be19c0dddd7d", + "metadata": {}, + "source": [ + "**Predictors on pressure levels (500, 850)**:\n", + "- Geopotential (z)\n", + "- Temperature (t)\n", + "- Zonal wind (u)\n", + "- Meridional wind (v)\n", + "- Specific humidity (q)\n", + "\n", + "**Predictors on surface**:\n", + "- Mean sea level pressure (msl)\n", + "\n", + "**Auxiliary predictors**:\n", + "- Elevation from Copernicus EU-DEM v1.1 (dem)\n", + "- Day of the year (doy)" + ] + }, { "cell_type": "markdown", "id": "4186c89c-b55b-4559-a818-8b712baaf44e", @@ -38,6 +58,7 @@ "PLEVELS = ['500', '850']\n", "SPREDICTORS = 'p'\n", "DEM = 'dem'\n", + "DEM_FEATURES = ''\n", "DOY = 'doy'" ] }, @@ -103,6 +124,30 @@ "Image(\"./Figures/architecture.png\", width=900, height=400)" ] }, + { + "cell_type": "markdown", + "id": "1db1f750-7d74-48ea-a385-94595a1da724", + "metadata": {}, + "source": [ + "### Loss function" + ] + }, + { + "cell_type": "markdown", + "id": "eb4ac3ea-213a-4391-8c88-909f6288de63", + "metadata": {}, + "source": [ + "For temperature, the network is optimizing the mean-squared error (negative log-likelihood of normal distribution):" + ] + }, + { + "cell_type": "markdown", + "id": "83bd9d8f-91b4-45f5-b58e-3551a942661a", + "metadata": {}, + "source": [ + "$$\\mathcal{J}(y_{pred} \\mid y_{true}) = \\left(y_{pred} - y_{true}\\right)^2$$" + ] + }, { "cell_type": "markdown", "id": "f47caa41-9380-4c02-8785-4febcf2cb2d0", @@ -111,6 +156,21 @@ "### Load datasets" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8cb1011-32dd-4e32-9692-4a4aa009869f", + "metadata": {}, + "outputs": [], + "source": [ + "# construct file pattern to match\n", + "PATTERN = '_'.join([MODEL, PREDICTAND, PPREDICTORS, *PLEVELS])\n", + "PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN" + ] + }, { "cell_type": "code", "execution_count": null, @@ -119,7 +179,7 @@ "outputs": [], "source": [ "# model predictions and observations NetCDF\n", - "y_pred = TARGET_PATH.joinpath(PREDICTAND, '_'.join([MODEL, PREDICTAND, PPREDICTORS, *PLEVELS, SPREDICTORS, DEM, DOY]) + '.nc')\n", + "y_pred = xr.open_dataset(search_files(TARGET_PATH.joinpath(PREDICTAND), '(.*)'.join([PATTERN, '.nc$'])).pop())\n", "if PREDICTAND == 'tas':\n", " # read both tasmax and tasmin\n", " tasmax = xr.open_dataset(search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())\n", @@ -136,9 +196,8 @@ "metadata": {}, "outputs": [], "source": [ - "# load datasets\n", - "y_pred = xr.open_dataset(y_pred)\n", - "y_true = y_true.sel(time=y_pred.time) # subset to time period covered by predictions" + "# subset to time period covered by predictions\n", + "y_true = y_true.sel(time=y_pred.time)" ] }, { @@ -185,7 +244,9 @@ "cell_type": "code", "execution_count": null, "id": "49dff6ce-a629-460b-a43b-d1a0ef447351", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "# apply mask of valid pixels\n", @@ -441,21 +502,20 @@ "outputs": [], "source": [ "# plot seasonal differences\n", - "seasons = ('DJF', 'JJA')\n", - "fig, axes = plt.subplots(nrows=1, ncols=len(seasons) + 1, figsize=(24,8), sharex=True, sharey=True)\n", + "fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(24, 12), sharex=True, sharey=True)\n", "axes = axes.flatten()\n", "\n", "# plot annual average bias\n", "ds = bias_yearly_avg[PREDICTAND].mean(dim='year')\n", - "axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n", + "im = axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n", "axes[0].set_title('Annual', fontsize=16);\n", "axes[0].text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n", "\n", "# plot seasonal average bias\n", - "for ax, season in zip(axes[1:], seasons):\n", + "for ax, season in zip(axes[1:], bias_snl.season):\n", " ds = bias_snl[PREDICTAND].sel(season=season)\n", " ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n", - " ax.set_title(season, fontsize=16);\n", + " ax.set_title(season.item(), fontsize=16);\n", " ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n", "\n", "# adjust axes\n", @@ -467,18 +527,20 @@ " ax.axes.axis('tight')\n", " ax.set_xlabel('')\n", " ax.set_ylabel('')\n", + " \n", + "# turn off last axis\n", + "axes[-1].set_visible(False)\n", "\n", "# adjust figure\n", "fig.suptitle('Average bias of {}: 1991 - 2010'.format(NAMES[PREDICTAND]), fontsize=20);\n", - "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)\n", + "fig.subplots_adjust(hspace=0.1, wspace=0, top=0.925)\n", "\n", - "# add colorbar for bias\n", - "axes = axes.flatten()\n", - "cbar_ax = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n", - " 0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])\n", - "cbar = fig.colorbar(im2, cax=cbar_ax)\n", - "cbar.set_label(label='Bias / (°C)', fontsize=16)\n", - "cbar.ax.tick_params(labelsize=14)\n", + "# add colorbar\n", + "cbar_ax_predictand = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n", + " 0.01, axes[0].get_position().y1 - axes[-1].get_position().y0])\n", + "cbar_predictand = fig.colorbar(im, cax=cbar_ax_predictand)\n", + "cbar_predictand.set_label(label='Bias / (°C)', fontsize=16)\n", + "cbar_predictand.ax.tick_params(labelsize=14)\n", "\n", "# save figure\n", "fig.savefig('../Notebooks/Figures/{}_average_bias_seasonal.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" @@ -604,7 +666,9 @@ "cell_type": "code", "execution_count": null, "id": "cfc196fa-0999-4603-959c-2c82f038c8fa", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "# plot extremes of observation, prediction, and bias\n",