diff --git a/Notebooks/eval_temperature.ipynb b/Notebooks/eval_temperature.ipynb index bb34b7c493915f15b2ac1d28d7e50562fa35bdd6..23ea0a6ed61c03ab6e96c8613ef258475706beac 100644 --- a/Notebooks/eval_temperature.ipynb +++ b/Notebooks/eval_temperature.ipynb @@ -32,7 +32,7 @@ "outputs": [], "source": [ "# define the model parameters\n", - "PREDICTAND = 'tasmin'\n", + "PREDICTAND = 'tasmax'\n", "MODEL = 'USegNet'\n", "PPREDICTORS = 'ztuvq'\n", "PLEVELS = ['500', '850']\n", @@ -70,6 +70,7 @@ "# builtins\n", "import datetime\n", "import warnings\n", + "import calendar\n", "\n", "# externals\n", "import xarray as xr\n", @@ -77,6 +78,7 @@ "import matplotlib.pyplot as plt\n", "import scipy.stats as stats\n", "from IPython.display import Image\n", + "from sklearn.metrics import r2_score\n", "\n", "# locals\n", "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH\n", @@ -159,6 +161,89 @@ "## Model validation: temperature" ] }, + { + "cell_type": "markdown", + "id": "ab15d557-c7ea-40c0-9977-a3d410fea784", + "metadata": {}, + "source": [ + "### Coefficient of determination" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "619d5dc9-4d36-43a3-b23c-a4ea51229c78", + "metadata": {}, + "outputs": [], + "source": [ + "# get predicted and observed values over entire time series and grid points\n", + "y_pred_values = y_pred[PREDICTAND].values.flatten()\n", + "y_true_values = y_true[PREDICTAND].values.flatten()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49dff6ce-a629-460b-a43b-d1a0ef447351", + "metadata": {}, + "outputs": [], + "source": [ + "# apply mask of valid pixels\n", + "mask = (~np.isnan(y_pred_values) & ~np.isnan(y_true_values))\n", + "y_pred_values = y_pred_values[mask]\n", + "y_true_values = y_true_values[mask]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9e46770-4176-4257-8ea2-7050d3325e98", + "metadata": {}, + "outputs": [], + "source": [ + "# calculate coefficient of determination\n", + "r2 = r2_score(y_true_values, y_pred_values)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "703ab604-5193-4032-92eb-80f2cff9fc2c", + "metadata": {}, + "outputs": [], + "source": [ + "# scatter plot of observations vs. predictions\n", + "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n", + "\n", + "# plot only a subset of data: otherwise plot is overloaded ...\n", + "subset = np.random.choice(np.arange(0, len(y_pred_values)), size=int(1e7), replace=False)\n", + "ax.plot(y_true_values[subset], y_pred_values[subset], 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n", + "\n", + "# plot entire dataset\n", + "# ax.plot(y_true_values, y_pred_values, 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n", + "\n", + "# plot 1:1 mapping line\n", + "interval = np.arange(-40, 45, 5)\n", + "ax.plot(interval, interval, color='k', lw=2, ls='--')\n", + "\n", + "# add coefficient of determination: calculated on entire dataset!\n", + "ax.text(interval[-1] - 1, interval[0] + 1, s='Coefficient of determination R$^2$ = {:.2f}'.format(r2), ha='right', fontsize=14)\n", + "\n", + "# format axes\n", + "ax.set_ylim(-40, 40)\n", + "ax.set_xlim(-40, 40)\n", + "ax.set_xticks(interval)\n", + "ax.set_xticklabels(interval, fontsize=14)\n", + "ax.set_yticks(interval)\n", + "ax.set_yticklabels(interval, fontsize=14)\n", + "ax.set_xlabel('Observed', fontsize=14)\n", + "ax.set_ylabel('Predicted', fontsize=14)\n", + "ax.set_title('{} (°C): 1991 - 2010'.format(NAMES[PREDICTAND].capitalize()), fontsize=16, pad=10);\n", + "\n", + "# save figure\n", + "fig.savefig('../Notebooks/Figures/{}_r2.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" + ] + }, { "cell_type": "markdown", "id": "5e8be24e-8ca2-4582-98c0-b56c6db289d2", @@ -186,8 +271,8 @@ "y_pred_yearly_avg = y_pred.groupby('time.year').mean(dim='time')\n", "y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time')\n", "bias_yearly_avg = y_pred_yearly_avg - y_true_yearly_avg\n", - "for var in bias:\n", - " print('Yearly average bias {}: {:.2f}'.format(var, bias_yearly_avg[var].mean().item()))" + "for var in bias_yearly_avg:\n", + " print('Yearly average bias of {}: {:.2f}'.format(var, bias_yearly_avg[var].mean().item()))" ] }, { @@ -199,8 +284,8 @@ "source": [ "# mean absolute error over reference period\n", "mae_avg = np.abs(y_pred_yearly_avg - y_true_yearly_avg).mean()\n", - "for var in mae:\n", - " print('Yearly average mean absolute error {}: {:.2f}'.format(var, mae[var].item()))" + "for var in mae_avg:\n", + " print('Yearly average MAE of {}: {:.2f}'.format(var, mae_avg[var].item()))" ] }, { @@ -212,8 +297,8 @@ "source": [ "# root mean squared error over reference period\n", "rmse_avg = ((y_pred_yearly_avg - y_true_yearly_avg) ** 2).mean()\n", - "for var in rmse:\n", - " print('Root mean squared error {}: {:.2f}'.format(var, rmse[var].item()))" + "for var in rmse_avg:\n", + " print('Yearly average RMSE of {}: {:.2f}'.format(var, rmse_avg[var].item()))" ] }, { @@ -242,6 +327,7 @@ "outputs": [], "source": [ "# plot average of observation, prediction, and bias\n", + "vmin, vmax = (-15, 15) if PREDICTAND == 'tasmin' else (0, 25)\n", "fig, axes = plt.subplots(len(y_pred_yearly_avg.data_vars), 3, figsize=(24, len(y_pred_yearly_avg.data_vars) * 6),\n", " sharex=True, sharey=True)\n", "axes = axes.reshape(len(y_pred_yearly_avg.data_vars), -1)\n", @@ -252,12 +338,12 @@ " im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n", " ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n", " else:\n", - " im1 = ax.imshow(ds[var].mean(dim='year').values, origin='lower', cmap='RdYlBu_r', vmin=-15, vmax=15)\n", + " im1 = ax.imshow(ds[var].mean(dim='year').values, origin='lower', cmap='RdYlBu_r', vmin=vmin, vmax=vmax)\n", " \n", "# set titles\n", - "axes[0, 0].set_title('Observed', fontsize=16);\n", - "axes[0, 1].set_title('Predicted', fontsize=16);\n", - "axes[0, 2].set_title('Bias', fontsize=16);\n", + "axes[0, 0].set_title('Observed', fontsize=16, pad=10);\n", + "axes[0, 1].set_title('Predicted', fontsize=16, pad=10);\n", + "axes[0, 2].set_title('Bias', fontsize=16, pad=10);\n", "\n", "# adjust axes\n", "for ax in axes.flat:\n", @@ -275,11 +361,23 @@ "\n", "# add colorbar for bias\n", "axes = axes.flatten()\n", - "cbar_ax = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n", - " 0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])\n", - "cbar = fig.colorbar(im2, cax=cbar_ax)\n", - "cbar.set_label(label='Bias / (°C)', fontsize=16)\n", - "cbar.ax.tick_params(labelsize=14)\n", + "cbar_ax_bias = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n", + " 0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])\n", + "cbar_bias = fig.colorbar(im2, cax=cbar_ax_bias)\n", + "cbar_bias.set_label(label='Bias / (°C)', fontsize=16)\n", + "cbar_bias.ax.tick_params(labelsize=14)\n", + "\n", + "# add colorbar for predictand\n", + "cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1,\n", + " axes[-1].get_position().x0 - axes[0].get_position().x0,\n", + " 0.05])\n", + "cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal')\n", + "cbar_predictand.set_label(label='{} / (°C)'.format(NAMES[PREDICTAND].capitalize()), fontsize=16)\n", + "cbar_predictand.ax.tick_params(labelsize=14)\n", + "\n", + "# add metrics: MAE and RMSE\n", + "axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.2f}°C'.format(mae_avg[PREDICTAND].item()), fontsize=14, ha='right')\n", + "axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.2f}°C$^2$'.format(rmse_avg[PREDICTAND].item()), fontsize=14, ha='right')\n", "\n", "# save figure\n", "fig.savefig('../Notebooks/Figures/{}_average_bias.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" @@ -323,7 +421,7 @@ "source": [ "# print average bias per season\n", "for var in bias_snl.data_vars:\n", - " for season in bias_snl.tasmin.season:\n", + " for season in bias_snl[PREDICTAND].season:\n", " print('Average bias of {} for season {}: {:.2f}'.format(var, season.values.item(), bias_snl[var].sel(season=season).mean().item()))" ] }, @@ -388,20 +486,54 @@ }, { "cell_type": "markdown", - "id": "c70b369d-2d16-42e3-9300-4a18757ad1b2", + "id": "41600269-2f8c-4717-8f74-b3dfaef60359", "metadata": {}, "source": [ - "### Bias of extreme values" + "Calculate the mean annual cycle:" ] }, { "cell_type": "code", "execution_count": null, - "id": "4acfc3f2-20ed-498c-ab35-f392ae0e64f9", + "id": "c9f27c01-4dfc-4d16-8d29-00e69b7794cd", "metadata": {}, "outputs": [], "source": [ - "# TODO: smooth quantiles" + "# group timeseries by month and calculate mean over time and space\n", + "y_pred_ac = y_pred.groupby('time.month').mean(dim=('time', 'y', 'x'))\n", + "y_true_ac = y_true.groupby('time.month').mean(dim=('time', 'y', 'x'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bcc63e73-2636-49c1-a66b-241eb5407e2e", + "metadata": {}, + "outputs": [], + "source": [ + "# plot mean annual cycle\n", + "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n", + "ax.plot(y_pred_ac[PREDICTAND].values, ls='--', color='k', label='Predicted')\n", + "ax.plot(y_true_ac[PREDICTAND].values, ls='-', color='k', label='Observed')\n", + "ax.legend(frameon=False, fontsize=14);\n", + "ax.set_yticks(np.arange(np.floor(y_true_ac[PREDICTAND].min().item()), np.ceil(y_true_ac[PREDICTAND].max().item()) + 1, 1))\n", + "ax.set_yticklabels(np.arange(np.floor(y_true_ac[PREDICTAND].min().item()), np.ceil(y_true_ac[PREDICTAND].max().item()) + 1, 1), fontsize=12)\n", + "ax.set_xticks(np.arange(0, 12))\n", + "ax.set_xticklabels([calendar.month_name[i + 1] for i in np.arange(0, 12)], rotation=90, fontsize=12)\n", + "ax.set_title('Mean annual cycle of {}: 1991 - 2010'.format(NAMES[PREDICTAND]), pad=20, fontsize=16);\n", + "ax.set_ylabel('{} / (°C)'.format(NAMES[PREDICTAND].capitalize()), fontsize=14)\n", + "ax.set_xlabel('Month', fontsize=14);\n", + "\n", + "# save figure\n", + "fig.savefig('../Notebooks/Figures/{}_mean_annual_cycle.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "id": "c70b369d-2d16-42e3-9300-4a18757ad1b2", + "metadata": {}, + "source": [ + "### Bias of extreme values" ] }, { @@ -411,67 +543,126 @@ "metadata": {}, "outputs": [], "source": [ - "# percentiles of interest\n", - "percentiles = [0.01, 0.02, 0.98, 0.99]" + "# extreme quantile of interest\n", + "quantile = 0.02 if PREDICTAND == 'tasmin' else 0.98" ] }, { "cell_type": "code", "execution_count": null, - "id": "51137d23-a380-4d48-a005-fd1edaf554eb", + "id": "c3da76d3-7261-4084-b5ec-f65682fd6596", "metadata": {}, "outputs": [], "source": [ - "# calculate percentiles over reference period\n", + "# calculate extreme quantile for each year\n", "with warnings.catch_warnings():\n", - " warnings.simplefilter('ignore')\n", - " y_pred_dist = y_pred.quantile(q=percentiles, dim='time')\n", - " y_true_dist = y_true.quantile(q=percentiles, dim='time')" + " warnings.simplefilter('ignore', category=RuntimeWarning)\n", + " y_pred_ex = y_pred.groupby('time.year').quantile(quantile, dim='time')\n", + " y_true_ex = y_true.groupby('time.year').quantile(quantile, dim='time')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20db96c8-e04f-4acb-886b-abb740863fbb", + "metadata": {}, + "outputs": [], + "source": [ + "# calculate bias in extreme quantile for each year\n", + "bias_ex = y_pred_ex - y_true_ex\n", + "for var in bias_ex:\n", + " print('Yearly average bias for P{:.0f} of {}: {:.2f}'.format(quantile * 100, var, bias_ex[var].mean().item()))" ] }, { "cell_type": "code", "execution_count": null, - "id": "b461c774-c5fc-4a50-8609-34a7e7674a34", + "id": "c7aeff55-deef-4d3c-a251-eac877c9afd9", "metadata": {}, "outputs": [], "source": [ - "# calculate bias in each percentile over entire reference period\n", - "bias_dist = y_pred_dist - y_true_dist" + "# mean absolute error in extreme quantile\n", + "mae_ex = np.abs(y_pred_ex - y_true_ex).mean()\n", + "for var in mae_ex:\n", + " print('Yearly average MAE for P{:.0f} of {}: {:.2f}'.format(quantile * 100, var, mae_avg[var].item()))" ] }, { "cell_type": "code", "execution_count": null, - "id": "bd2ec314-87ed-407f-af8b-5e2c6785e9cc", + "id": "2dc6b7cf-7ae1-4b0e-8483-376eab59f5dd", "metadata": {}, "outputs": [], "source": [ - "# calculate correlation coefficient for extreme values\n", - "for var in y_pred_dist:\n", - " for q in percentiles:\n", - " y_p = y_pred_dist[var].sel(quantile=q).values[~np.isnan(y_pred_dist[var].sel(quantile=q))]\n", - " y_t = y_true_dist[var].sel(quantile=q).values[~np.isnan(y_true_dist[var].sel(quantile=q))]\n", - " r, _ = stats.pearsonr(y_p, y_t)\n", - " print('Pearson correlation for {}, q={:.2f}: R={:.2f}'.format(var, q, r))" + "# root mean squared error over reference period\n", + "rmse_ex = ((y_pred_ex - y_true_ex) ** 2).mean()\n", + "for var in rmse_ex:\n", + " print('Yearly average RMSE for P{:.0f} of {}: {:.2f}'.format(quantile * 100, var, rmse_ex[var].item()))" ] }, { "cell_type": "code", "execution_count": null, - "id": "04a1b4fc-8fc0-4e1e-adbe-bf1eaafeac5d", + "id": "cfc196fa-0999-4603-959c-2c82f038c8fa", "metadata": {}, "outputs": [], "source": [ - "# plot bias in each percentile\n", - "fig, axes = plt.subplots(len(y_pred_dist.data_vars), len(percentiles), sharex=True, sharey=True, figsize=(32, 6))\n", - "axes = axes.reshape(len(y_pred_dist.data_vars), -1)\n", - "for ax, var in zip(axes, y_pred_dist):\n", - " # iterate over percentiles\n", - " for axis, q in zip(ax, percentiles):\n", - " ds = bias_dist.sel(quantile=q).to_array()\n", - " ds.plot(ax=axis, vmin=-2, vmax=2, cmap='RdBu_r') \n", - " axis.text(x=bias_dist.x[-1], y=bias_dist.y[0], s='Avg: {:.2f}'.format(ds.mean().item()), ha='right', va='bottom')" + "# plot extremes of observation, prediction, and bias\n", + "vmin, vmax = (-20, 0) if PREDICTAND == 'tasmin' else (10, 40)\n", + "fig, axes = plt.subplots(len(y_pred_ex.data_vars), 3, figsize=(24, len(y_pred_ex.data_vars) * 6),\n", + " sharex=True, sharey=True)\n", + "axes = axes.reshape(len(y_pred_ex.data_vars), -1)\n", + "for i, var in enumerate(y_pred_ex):\n", + " for ds, ax in zip([y_true_ex, y_pred_ex, bias_ex], axes[i, ...]):\n", + " if ds is bias_ex:\n", + " ds = ds[var].mean(dim='year')\n", + " im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n", + " ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n", + " else:\n", + " im1 = ax.imshow(ds[var].mean(dim='year').values, origin='lower', cmap='Blues_r' if PREDICTAND == 'tasmin' else 'Reds',\n", + " vmin=vmin, vmax=vmax)\n", + " \n", + "# set titles\n", + "axes[0, 0].set_title('Observed', fontsize=16, pad=10);\n", + "axes[0, 1].set_title('Predicted', fontsize=16, pad=10);\n", + "axes[0, 2].set_title('Bias', fontsize=16, pad=10);\n", + "\n", + "# adjust axes\n", + "for ax in axes.flat:\n", + " ax.axes.get_xaxis().set_ticklabels([])\n", + " ax.axes.get_xaxis().set_ticks([])\n", + " ax.axes.get_yaxis().set_ticklabels([])\n", + " ax.axes.get_yaxis().set_ticks([])\n", + " ax.axes.axis('tight')\n", + " ax.set_xlabel('')\n", + " ax.set_ylabel('')\n", + "\n", + "# adjust figure\n", + "fig.suptitle('Average P{:.0f} of {}: 1991 - 2010'.format(quantile * 100, NAMES[PREDICTAND]), fontsize=20);\n", + "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)\n", + "\n", + "# add colorbar for bias\n", + "axes = axes.flatten()\n", + "cbar_ax = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n", + " 0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])\n", + "cbar = fig.colorbar(im2, cax=cbar_ax)\n", + "cbar.set_label(label='Bias / (°C)', fontsize=16)\n", + "cbar.ax.tick_params(labelsize=14)\n", + "\n", + "# add colorbar for predictand\n", + "cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1,\n", + " axes[-1].get_position().x0 - axes[0].get_position().x0,\n", + " 0.05])\n", + "cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal')\n", + "cbar_predictand.set_label(label='{} / (°C)'.format(NAMES[PREDICTAND].capitalize()), fontsize=16)\n", + "cbar_predictand.ax.tick_params(labelsize=14)\n", + "\n", + "# add metrics: MAE and RMSE\n", + "axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.2f}°C'.format(mae_ex[PREDICTAND].item()), fontsize=14, ha='right')\n", + "axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.2f}°C$^2$'.format(rmse_ex[PREDICTAND].item()), fontsize=14, ha='right')\n", + "\n", + "# save figure\n", + "fig.savefig('../Notebooks/Figures/{}_average_bias_p{:.0f}.png'.format(PREDICTAND, quantile * 100), dpi=300, bbox_inches='tight')" ] } ],