Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
C
Climax
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
earth_observation_public
Climax
Commits
87e4e1b8
Commit
87e4e1b8
authored
3 years ago
by
Frisinghelli Daniel
Browse files
Options
Downloads
Patches
Plain Diff
Preliminary validation metrics for temperature.
parent
971e181e
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
Notebooks/eval_temperature.ipynb
+240
-49
240 additions, 49 deletions
Notebooks/eval_temperature.ipynb
with
240 additions
and
49 deletions
Notebooks/eval_temperature.ipynb
+
240
−
49
View file @
87e4e1b8
...
...
@@ -32,7 +32,7 @@
"outputs": [],
"source": [
"# define the model parameters\n",
"PREDICTAND = 'tasm
in
'\n",
"PREDICTAND = 'tasm
ax
'\n",
"MODEL = 'USegNet'\n",
"PPREDICTORS = 'ztuvq'\n",
"PLEVELS = ['500', '850']\n",
...
...
@@ -70,6 +70,7 @@
"# builtins\n",
"import datetime\n",
"import warnings\n",
"import calendar\n",
"\n",
"# externals\n",
"import xarray as xr\n",
...
...
@@ -77,6 +78,7 @@
"import matplotlib.pyplot as plt\n",
"import scipy.stats as stats\n",
"from IPython.display import Image\n",
"from sklearn.metrics import r2_score\n",
"\n",
"# locals\n",
"from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH\n",
...
...
@@ -159,6 +161,89 @@
"## Model validation: temperature"
]
},
{
"cell_type": "markdown",
"id": "ab15d557-c7ea-40c0-9977-a3d410fea784",
"metadata": {},
"source": [
"### Coefficient of determination"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "619d5dc9-4d36-43a3-b23c-a4ea51229c78",
"metadata": {},
"outputs": [],
"source": [
"# get predicted and observed values over entire time series and grid points\n",
"y_pred_values = y_pred[PREDICTAND].values.flatten()\n",
"y_true_values = y_true[PREDICTAND].values.flatten()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49dff6ce-a629-460b-a43b-d1a0ef447351",
"metadata": {},
"outputs": [],
"source": [
"# apply mask of valid pixels\n",
"mask = (~np.isnan(y_pred_values) & ~np.isnan(y_true_values))\n",
"y_pred_values = y_pred_values[mask]\n",
"y_true_values = y_true_values[mask]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e9e46770-4176-4257-8ea2-7050d3325e98",
"metadata": {},
"outputs": [],
"source": [
"# calculate coefficient of determination\n",
"r2 = r2_score(y_true_values, y_pred_values)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "703ab604-5193-4032-92eb-80f2cff9fc2c",
"metadata": {},
"outputs": [],
"source": [
"# scatter plot of observations vs. predictions\n",
"fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
"\n",
"# plot only a subset of data: otherwise plot is overloaded ...\n",
"subset = np.random.choice(np.arange(0, len(y_pred_values)), size=int(1e7), replace=False)\n",
"ax.plot(y_true_values[subset], y_pred_values[subset], 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n",
"\n",
"# plot entire dataset\n",
"# ax.plot(y_true_values, y_pred_values, 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n",
"\n",
"# plot 1:1 mapping line\n",
"interval = np.arange(-40, 45, 5)\n",
"ax.plot(interval, interval, color='k', lw=2, ls='--')\n",
"\n",
"# add coefficient of determination: calculated on entire dataset!\n",
"ax.text(interval[-1] - 1, interval[0] + 1, s='Coefficient of determination R$^2$ = {:.2f}'.format(r2), ha='right', fontsize=14)\n",
"\n",
"# format axes\n",
"ax.set_ylim(-40, 40)\n",
"ax.set_xlim(-40, 40)\n",
"ax.set_xticks(interval)\n",
"ax.set_xticklabels(interval, fontsize=14)\n",
"ax.set_yticks(interval)\n",
"ax.set_yticklabels(interval, fontsize=14)\n",
"ax.set_xlabel('Observed', fontsize=14)\n",
"ax.set_ylabel('Predicted', fontsize=14)\n",
"ax.set_title('{} (°C): 1991 - 2010'.format(NAMES[PREDICTAND].capitalize()), fontsize=16, pad=10);\n",
"\n",
"# save figure\n",
"fig.savefig('../Notebooks/Figures/{}_r2.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')"
]
},
{
"cell_type": "markdown",
"id": "5e8be24e-8ca2-4582-98c0-b56c6db289d2",
...
...
@@ -186,8 +271,8 @@
"y_pred_yearly_avg = y_pred.groupby('time.year').mean(dim='time')\n",
"y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time')\n",
"bias_yearly_avg = y_pred_yearly_avg - y_true_yearly_avg\n",
"for var in bias:\n",
" print('Yearly average bias {}: {:.2f}'.format(var, bias_yearly_avg[var].mean().item()))"
"for var in bias
_yearly_avg
:\n",
" print('Yearly average bias
of
{}: {:.2f}'.format(var, bias_yearly_avg[var].mean().item()))"
]
},
{
...
...
@@ -199,8 +284,8 @@
"source": [
"# mean absolute error over reference period\n",
"mae_avg = np.abs(y_pred_yearly_avg - y_true_yearly_avg).mean()\n",
"for var in mae:\n",
" print('Yearly average
mean absolute error
{}: {:.2f}'.format(var, mae[var].item()))"
"for var in mae
_avg
:\n",
" print('Yearly average
MAE of
{}: {:.2f}'.format(var, mae
_avg
[var].item()))"
]
},
{
...
...
@@ -212,8 +297,8 @@
"source": [
"# root mean squared error over reference period\n",
"rmse_avg = ((y_pred_yearly_avg - y_true_yearly_avg) ** 2).mean()\n",
"for var in rmse:\n",
" print('
Root mean squared error
{}: {:.2f}'.format(var, rmse[var].item()))"
"for var in rmse
_avg
:\n",
" print('
Yearly average RMSE of
{}: {:.2f}'.format(var, rmse
_avg
[var].item()))"
]
},
{
...
...
@@ -242,6 +327,7 @@
"outputs": [],
"source": [
"# plot average of observation, prediction, and bias\n",
"vmin, vmax = (-15, 15) if PREDICTAND == 'tasmin' else (0, 25)\n",
"fig, axes = plt.subplots(len(y_pred_yearly_avg.data_vars), 3, figsize=(24, len(y_pred_yearly_avg.data_vars) * 6),\n",
" sharex=True, sharey=True)\n",
"axes = axes.reshape(len(y_pred_yearly_avg.data_vars), -1)\n",
...
...
@@ -252,12 +338,12 @@
" im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n",
" ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n",
" else:\n",
" im1 = ax.imshow(ds[var].mean(dim='year').values, origin='lower', cmap='RdYlBu_r', vmin=
-15
, vmax=
15
)\n",
" im1 = ax.imshow(ds[var].mean(dim='year').values, origin='lower', cmap='RdYlBu_r', vmin=
vmin
, vmax=
vmax
)\n",
" \n",
"# set titles\n",
"axes[0, 0].set_title('Observed', fontsize=16);\n",
"axes[0, 1].set_title('Predicted', fontsize=16);\n",
"axes[0, 2].set_title('Bias', fontsize=16);\n",
"axes[0, 0].set_title('Observed', fontsize=16
, pad=10
);\n",
"axes[0, 1].set_title('Predicted', fontsize=16
, pad=10
);\n",
"axes[0, 2].set_title('Bias', fontsize=16
, pad=10
);\n",
"\n",
"# adjust axes\n",
"for ax in axes.flat:\n",
...
...
@@ -275,11 +361,23 @@
"\n",
"# add colorbar for bias\n",
"axes = axes.flatten()\n",
"cbar_ax = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n",
" 0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])\n",
"cbar = fig.colorbar(im2, cax=cbar_ax)\n",
"cbar.set_label(label='Bias / (°C)', fontsize=16)\n",
"cbar.ax.tick_params(labelsize=14)\n",
"cbar_ax_bias = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n",
" 0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])\n",
"cbar_bias = fig.colorbar(im2, cax=cbar_ax_bias)\n",
"cbar_bias.set_label(label='Bias / (°C)', fontsize=16)\n",
"cbar_bias.ax.tick_params(labelsize=14)\n",
"\n",
"# add colorbar for predictand\n",
"cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1,\n",
" axes[-1].get_position().x0 - axes[0].get_position().x0,\n",
" 0.05])\n",
"cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal')\n",
"cbar_predictand.set_label(label='{} / (°C)'.format(NAMES[PREDICTAND].capitalize()), fontsize=16)\n",
"cbar_predictand.ax.tick_params(labelsize=14)\n",
"\n",
"# add metrics: MAE and RMSE\n",
"axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.2f}°C'.format(mae_avg[PREDICTAND].item()), fontsize=14, ha='right')\n",
"axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.2f}°C$^2$'.format(rmse_avg[PREDICTAND].item()), fontsize=14, ha='right')\n",
"\n",
"# save figure\n",
"fig.savefig('../Notebooks/Figures/{}_average_bias.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')"
...
...
@@ -323,7 +421,7 @@
"source": [
"# print average bias per season\n",
"for var in bias_snl.data_vars:\n",
" for season in bias_snl
.tasmin
.season:\n",
" for season in bias_snl
[PREDICTAND]
.season:\n",
" print('Average bias of {} for season {}: {:.2f}'.format(var, season.values.item(), bias_snl[var].sel(season=season).mean().item()))"
]
},
...
...
@@ -388,20 +486,54 @@
},
{
"cell_type": "markdown",
"id": "
c70b3
69
d
-2
d16-42e3-9300-4a18757ad1b2
",
"id": "
416002
69-2
f8c-4717-8f74-b3dfaef60359
",
"metadata": {},
"source": [
"
### Bias of extreme values
"
"
Calculate the mean annual cycle:
"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "
4acfc3f2-20ed-498c-ab35-f392ae0e64f9
",
"id": "
c9f27c01-4dfc-4d16-8d29-00e69b7794cd
",
"metadata": {},
"outputs": [],
"source": [
"# TODO: smooth quantiles"
"# group timeseries by month and calculate mean over time and space\n",
"y_pred_ac = y_pred.groupby('time.month').mean(dim=('time', 'y', 'x'))\n",
"y_true_ac = y_true.groupby('time.month').mean(dim=('time', 'y', 'x'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bcc63e73-2636-49c1-a66b-241eb5407e2e",
"metadata": {},
"outputs": [],
"source": [
"# plot mean annual cycle\n",
"fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
"ax.plot(y_pred_ac[PREDICTAND].values, ls='--', color='k', label='Predicted')\n",
"ax.plot(y_true_ac[PREDICTAND].values, ls='-', color='k', label='Observed')\n",
"ax.legend(frameon=False, fontsize=14);\n",
"ax.set_yticks(np.arange(np.floor(y_true_ac[PREDICTAND].min().item()), np.ceil(y_true_ac[PREDICTAND].max().item()) + 1, 1))\n",
"ax.set_yticklabels(np.arange(np.floor(y_true_ac[PREDICTAND].min().item()), np.ceil(y_true_ac[PREDICTAND].max().item()) + 1, 1), fontsize=12)\n",
"ax.set_xticks(np.arange(0, 12))\n",
"ax.set_xticklabels([calendar.month_name[i + 1] for i in np.arange(0, 12)], rotation=90, fontsize=12)\n",
"ax.set_title('Mean annual cycle of {}: 1991 - 2010'.format(NAMES[PREDICTAND]), pad=20, fontsize=16);\n",
"ax.set_ylabel('{} / (°C)'.format(NAMES[PREDICTAND].capitalize()), fontsize=14)\n",
"ax.set_xlabel('Month', fontsize=14);\n",
"\n",
"# save figure\n",
"fig.savefig('../Notebooks/Figures/{}_mean_annual_cycle.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')"
]
},
{
"cell_type": "markdown",
"id": "c70b369d-2d16-42e3-9300-4a18757ad1b2",
"metadata": {},
"source": [
"### Bias of extreme values"
]
},
{
...
...
@@ -411,67 +543,126 @@
"metadata": {},
"outputs": [],
"source": [
"#
perce
ntile
s
of interest\n",
"
perce
ntile
s
=
[
0.0
1, 0.02, 0.98, 0.99]
"
"#
extreme qua
ntile of interest\n",
"
qua
ntile = 0.0
2 if PREDICTAND == 'tasmin' else 0.98
"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "
51137d23-a380-4d48-a005-fd1edaf554eb
",
"id": "
c3da76d3-7261-4084-b5ec-f65682fd6596
",
"metadata": {},
"outputs": [],
"source": [
"# calculate
percentiles over reference period
\n",
"# calculate
extreme quantile for each year
\n",
"with warnings.catch_warnings():\n",
" warnings.simplefilter('ignore')\n",
" y_pred_dist = y_pred.quantile(q=percentiles, dim='time')\n",
" y_true_dist = y_true.quantile(q=percentiles, dim='time')"
" warnings.simplefilter('ignore', category=RuntimeWarning)\n",
" y_pred_ex = y_pred.groupby('time.year').quantile(quantile, dim='time')\n",
" y_true_ex = y_true.groupby('time.year').quantile(quantile, dim='time')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "20db96c8-e04f-4acb-886b-abb740863fbb",
"metadata": {},
"outputs": [],
"source": [
"# calculate bias in extreme quantile for each year\n",
"bias_ex = y_pred_ex - y_true_ex\n",
"for var in bias_ex:\n",
" print('Yearly average bias for P{:.0f} of {}: {:.2f}'.format(quantile * 100, var, bias_ex[var].mean().item()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "
b461c774-c5fc-4a50-8609-34a7e7674a34
",
"id": "
c7aeff55-deef-4d3c-a251-eac877c9afd9
",
"metadata": {},
"outputs": [],
"source": [
"# calculate bias in each percentile over entire reference period\n",
"bias_dist = y_pred_dist - y_true_dist"
"# mean absolute error in extreme quantile\n",
"mae_ex = np.abs(y_pred_ex - y_true_ex).mean()\n",
"for var in mae_ex:\n",
" print('Yearly average MAE for P{:.0f} of {}: {:.2f}'.format(quantile * 100, var, mae_avg[var].item()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "
bd2ec314-87ed-407f-af8b-5e2c6785e9cc
",
"id": "
2dc6b7cf-7ae1-4b0e-8483-376eab59f5dd
",
"metadata": {},
"outputs": [],
"source": [
"# calculate correlation coefficient for extreme values\n",
"for var in y_pred_dist:\n",
" for q in percentiles:\n",
" y_p = y_pred_dist[var].sel(quantile=q).values[~np.isnan(y_pred_dist[var].sel(quantile=q))]\n",
" y_t = y_true_dist[var].sel(quantile=q).values[~np.isnan(y_true_dist[var].sel(quantile=q))]\n",
" r, _ = stats.pearsonr(y_p, y_t)\n",
" print('Pearson correlation for {}, q={:.2f}: R={:.2f}'.format(var, q, r))"
"# root mean squared error over reference period\n",
"rmse_ex = ((y_pred_ex - y_true_ex) ** 2).mean()\n",
"for var in rmse_ex:\n",
" print('Yearly average RMSE for P{:.0f} of {}: {:.2f}'.format(quantile * 100, var, rmse_ex[var].item()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "
04a1b4fc-8fc0-4e1e-adbe-bf1eaafeac5d
",
"id": "
cfc196fa-0999-4603-959c-2c82f038c8fa
",
"metadata": {},
"outputs": [],
"source": [
"# plot bias in each percentile\n",
"fig, axes = plt.subplots(len(y_pred_dist.data_vars), len(percentiles), sharex=True, sharey=True, figsize=(32, 6))\n",
"axes = axes.reshape(len(y_pred_dist.data_vars), -1)\n",
"for ax, var in zip(axes, y_pred_dist):\n",
" # iterate over percentiles\n",
" for axis, q in zip(ax, percentiles):\n",
" ds = bias_dist.sel(quantile=q).to_array()\n",
" ds.plot(ax=axis, vmin=-2, vmax=2, cmap='RdBu_r') \n",
" axis.text(x=bias_dist.x[-1], y=bias_dist.y[0], s='Avg: {:.2f}'.format(ds.mean().item()), ha='right', va='bottom')"
"# plot extremes of observation, prediction, and bias\n",
"vmin, vmax = (-20, 0) if PREDICTAND == 'tasmin' else (10, 40)\n",
"fig, axes = plt.subplots(len(y_pred_ex.data_vars), 3, figsize=(24, len(y_pred_ex.data_vars) * 6),\n",
" sharex=True, sharey=True)\n",
"axes = axes.reshape(len(y_pred_ex.data_vars), -1)\n",
"for i, var in enumerate(y_pred_ex):\n",
" for ds, ax in zip([y_true_ex, y_pred_ex, bias_ex], axes[i, ...]):\n",
" if ds is bias_ex:\n",
" ds = ds[var].mean(dim='year')\n",
" im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n",
" ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n",
" else:\n",
" im1 = ax.imshow(ds[var].mean(dim='year').values, origin='lower', cmap='Blues_r' if PREDICTAND == 'tasmin' else 'Reds',\n",
" vmin=vmin, vmax=vmax)\n",
" \n",
"# set titles\n",
"axes[0, 0].set_title('Observed', fontsize=16, pad=10);\n",
"axes[0, 1].set_title('Predicted', fontsize=16, pad=10);\n",
"axes[0, 2].set_title('Bias', fontsize=16, pad=10);\n",
"\n",
"# adjust axes\n",
"for ax in axes.flat:\n",
" ax.axes.get_xaxis().set_ticklabels([])\n",
" ax.axes.get_xaxis().set_ticks([])\n",
" ax.axes.get_yaxis().set_ticklabels([])\n",
" ax.axes.get_yaxis().set_ticks([])\n",
" ax.axes.axis('tight')\n",
" ax.set_xlabel('')\n",
" ax.set_ylabel('')\n",
"\n",
"# adjust figure\n",
"fig.suptitle('Average P{:.0f} of {}: 1991 - 2010'.format(quantile * 100, NAMES[PREDICTAND]), fontsize=20);\n",
"fig.subplots_adjust(hspace=0, wspace=0, top=0.85)\n",
"\n",
"# add colorbar for bias\n",
"axes = axes.flatten()\n",
"cbar_ax = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n",
" 0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])\n",
"cbar = fig.colorbar(im2, cax=cbar_ax)\n",
"cbar.set_label(label='Bias / (°C)', fontsize=16)\n",
"cbar.ax.tick_params(labelsize=14)\n",
"\n",
"# add colorbar for predictand\n",
"cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1,\n",
" axes[-1].get_position().x0 - axes[0].get_position().x0,\n",
" 0.05])\n",
"cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal')\n",
"cbar_predictand.set_label(label='{} / (°C)'.format(NAMES[PREDICTAND].capitalize()), fontsize=16)\n",
"cbar_predictand.ax.tick_params(labelsize=14)\n",
"\n",
"# add metrics: MAE and RMSE\n",
"axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.2f}°C'.format(mae_ex[PREDICTAND].item()), fontsize=14, ha='right')\n",
"axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.2f}°C$^2$'.format(rmse_ex[PREDICTAND].item()), fontsize=14, ha='right')\n",
"\n",
"# save figure\n",
"fig.savefig('../Notebooks/Figures/{}_average_bias_p{:.0f}.png'.format(PREDICTAND, quantile * 100), dpi=300, bbox_inches='tight')"
]
}
],
...
...
%% Cell type:markdown id:f15afea1-9ea4-4201-bdd7-32ae377db6a9 tags:
# Evaluate ERA-5 downscaling
%% Cell type:markdown id:7d2c4c86-18e4-44c4-bdf7-0e8249614749 tags:
We used
**1981-1991 as training**
period and
**1991-2010 as reference**
period. The results shown in this notebook are based on the model predictions on the reference period.
%% Cell type:markdown id:4186c89c-b55b-4559-a818-8b712baaf44e tags:
Define the predictand and the model to evaluate:
%% Cell type:code id:bb24b6ed-2d0a-44e0-b9a9-abdcb2a8294d tags:
```
python
# define the model parameters
PREDICTAND
=
'
tasm
in
'
PREDICTAND
=
'
tasm
ax
'
MODEL
=
'
USegNet
'
PPREDICTORS
=
'
ztuvq
'
PLEVELS
=
[
'
500
'
,
'
850
'
]
SPREDICTORS
=
'
p
'
DEM
=
'
dem
'
DOY
=
'
doy
'
```
%% Cell type:code id:686a7c1d-e71d-4954-8d79-26b9a84f648e tags:
```
python
# mapping from predictands to variable names
NAMES
=
{
'
tasmin
'
:
'
minimum temperature
'
,
'
tasmax
'
:
'
maximum temperature
'
,
'
pr
'
:
'
precipitation
'
}
```
%% Cell type:markdown id:4d5d12c5-50fd-4c5c-9240-c3df78e49b44 tags:
### Imports
%% Cell type:code id:1afb0fab-5d2a-4875-9032-29b99c6dec89 tags:
```
python
# builtins
import
datetime
import
warnings
import
calendar
# externals
import
xarray
as
xr
import
numpy
as
np
import
matplotlib.pyplot
as
plt
import
scipy.stats
as
stats
from
IPython.display
import
Image
from
sklearn.metrics
import
r2_score
# locals
from
climax.main.io
import
ERA5_PATH
,
OBS_PATH
,
TARGET_PATH
from
pysegcnn.core.utils
import
search_files
```
%% Cell type:markdown id:2c2d0b1d-9630-4cfe-ae9d-35b05b94a1a5 tags:
### Model architecture
%% Cell type:code id:f98e1c9f-0581-43ea-a569-fd05fdaf36c1 tags:
```
python
Image
(
"
./Figures/architecture.png
"
,
width
=
900
,
height
=
400
)
```
%% Cell type:markdown id:f47caa41-9380-4c02-8785-4febcf2cb2d0 tags:
### Load datasets
%% Cell type:code id:020dfe33-ce3c-467f-ad0a-295cc338b1a9 tags:
```
python
# model predictions and observations NetCDF
y_pred
=
TARGET_PATH
.
joinpath
(
PREDICTAND
,
'
_
'
.
join
([
MODEL
,
PREDICTAND
,
PPREDICTORS
,
*
PLEVELS
,
SPREDICTORS
,
DEM
,
DOY
])
+
'
.nc
'
)
if
PREDICTAND
==
'
tas
'
:
# read both tasmax and tasmin
tasmax
=
xr
.
open_dataset
(
search_files
(
OBS_PATH
.
joinpath
(
'
tasmax
'
),
'
.nc$
'
).
pop
())
tasmin
=
xr
.
open_dataset
(
search_files
(
OBS_PATH
.
joinpath
(
'
tasmin
'
),
'
.nc$
'
).
pop
())
y_true
=
xr
.
merge
([
tasmax
,
tasmin
])
else
:
y_true
=
xr
.
open_dataset
(
search_files
(
OBS_PATH
.
joinpath
(
PREDICTAND
),
'
.nc$
'
).
pop
())
```
%% Cell type:code id:1dc2a386-d63b-4c6a-8e63-00365927559d tags:
```
python
# load datasets
y_pred
=
xr
.
open_dataset
(
y_pred
)
y_true
=
y_true
.
sel
(
time
=
y_pred
.
time
)
# subset to time period covered by predictions
```
%% Cell type:code id:966d85fb-9185-408f-ac2b-1e4ca829ccd1 tags:
```
python
# align datasets and mask missing values in model predictions
y_true
,
y_pred
=
xr
.
align
(
y_true
,
y_pred
,
join
=
'
override
'
)
y_pred
=
y_pred
.
where
(
~
np
.
isnan
(
y_true
),
other
=
np
.
nan
)
```
%% Cell type:markdown id:ddebdf9f-862c-461e-aa57-cd344d54eee9 tags:
## Model validation: temperature
%% Cell type:markdown id:ab15d557-c7ea-40c0-9977-a3d410fea784 tags:
### Coefficient of determination
%% Cell type:code id:619d5dc9-4d36-43a3-b23c-a4ea51229c78 tags:
```
python
# get predicted and observed values over entire time series and grid points
y_pred_values
=
y_pred
[
PREDICTAND
].
values
.
flatten
()
y_true_values
=
y_true
[
PREDICTAND
].
values
.
flatten
()
```
%% Cell type:code id:49dff6ce-a629-460b-a43b-d1a0ef447351 tags:
```
python
# apply mask of valid pixels
mask
=
(
~
np
.
isnan
(
y_pred_values
)
&
~
np
.
isnan
(
y_true_values
))
y_pred_values
=
y_pred_values
[
mask
]
y_true_values
=
y_true_values
[
mask
]
```
%% Cell type:code id:e9e46770-4176-4257-8ea2-7050d3325e98 tags:
```
python
# calculate coefficient of determination
r2
=
r2_score
(
y_true_values
,
y_pred_values
)
```
%% Cell type:code id:703ab604-5193-4032-92eb-80f2cff9fc2c tags:
```
python
# scatter plot of observations vs. predictions
fig
,
ax
=
plt
.
subplots
(
1
,
1
,
figsize
=
(
10
,
10
))
# plot only a subset of data: otherwise plot is overloaded ...
subset
=
np
.
random
.
choice
(
np
.
arange
(
0
,
len
(
y_pred_values
)),
size
=
int
(
1e7
),
replace
=
False
)
ax
.
plot
(
y_true_values
[
subset
],
y_pred_values
[
subset
],
'
o
'
,
alpha
=
.
5
,
markeredgecolor
=
'
grey
'
,
markerfacecolor
=
'
none
'
,
markersize
=
3
);
# plot entire dataset
# ax.plot(y_true_values, y_pred_values, 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);
# plot 1:1 mapping line
interval
=
np
.
arange
(
-
40
,
45
,
5
)
ax
.
plot
(
interval
,
interval
,
color
=
'
k
'
,
lw
=
2
,
ls
=
'
--
'
)
# add coefficient of determination: calculated on entire dataset!
ax
.
text
(
interval
[
-
1
]
-
1
,
interval
[
0
]
+
1
,
s
=
'
Coefficient of determination R$^2$ = {:.2f}
'
.
format
(
r2
),
ha
=
'
right
'
,
fontsize
=
14
)
# format axes
ax
.
set_ylim
(
-
40
,
40
)
ax
.
set_xlim
(
-
40
,
40
)
ax
.
set_xticks
(
interval
)
ax
.
set_xticklabels
(
interval
,
fontsize
=
14
)
ax
.
set_yticks
(
interval
)
ax
.
set_yticklabels
(
interval
,
fontsize
=
14
)
ax
.
set_xlabel
(
'
Observed
'
,
fontsize
=
14
)
ax
.
set_ylabel
(
'
Predicted
'
,
fontsize
=
14
)
ax
.
set_title
(
'
{} (°C): 1991 - 2010
'
.
format
(
NAMES
[
PREDICTAND
].
capitalize
()),
fontsize
=
16
,
pad
=
10
);
# save figure
fig
.
savefig
(
'
../Notebooks/Figures/{}_r2.png
'
.
format
(
PREDICTAND
),
dpi
=
300
,
bbox_inches
=
'
tight
'
)
```
%% Cell type:markdown id:5e8be24e-8ca2-4582-98c0-b56c6db289d2 tags:
### Bias
%% Cell type:markdown id:a4f7177c-7d09-401f-957b-0e493b9ef5d0 tags:
Calculate yearly average bias over entire reference period:
%% Cell type:code id:746bf95f-a78b-4da8-a063-1fa48e3c5da8 tags:
```
python
# yearly average bias over reference period
y_pred_yearly_avg
=
y_pred
.
groupby
(
'
time.year
'
).
mean
(
dim
=
'
time
'
)
y_true_yearly_avg
=
y_true
.
groupby
(
'
time.year
'
).
mean
(
dim
=
'
time
'
)
bias_yearly_avg
=
y_pred_yearly_avg
-
y_true_yearly_avg
for
var
in
bias
:
print
(
'
Yearly average bias {}: {:.2f}
'
.
format
(
var
,
bias_yearly_avg
[
var
].
mean
().
item
()))
for
var
in
bias
_yearly_avg
:
print
(
'
Yearly average bias
of
{}: {:.2f}
'
.
format
(
var
,
bias_yearly_avg
[
var
].
mean
().
item
()))
```
%% Cell type:code id:2ef19dce-29ac-4f6f-9999-e67745a4afd1 tags:
```
python
# mean absolute error over reference period
mae_avg
=
np
.
abs
(
y_pred_yearly_avg
-
y_true_yearly_avg
).
mean
()
for
var
in
mae
:
print
(
'
Yearly average
mean absolute error
{}: {:.2f}
'
.
format
(
var
,
mae
[
var
].
item
()))
for
var
in
mae
_avg
:
print
(
'
Yearly average
MAE of
{}: {:.2f}
'
.
format
(
var
,
mae
_avg
[
var
].
item
()))
```
%% Cell type:code id:4c4dd156-f763-482c-84e6-c4329bfd3fe4 tags:
```
python
# root mean squared error over reference period
rmse_avg
=
((
y_pred_yearly_avg
-
y_true_yearly_avg
)
**
2
).
mean
()
for
var
in
rmse
:
print
(
'
Root mean squared error
{}: {:.2f}
'
.
format
(
var
,
rmse
[
var
].
item
()))
for
var
in
rmse
_avg
:
print
(
'
Yearly average RMSE of
{}: {:.2f}
'
.
format
(
var
,
rmse
_avg
[
var
].
item
()))
```
%% Cell type:code id:ed54a50e-41c2-4a7f-9839-05357996f0c4 tags:
```
python
# Pearson's correlation coefficient over reference period
for
var
in
y_pred_yearly_avg
:
correlations
=
[]
for
year
in
y_pred_yearly_avg
.
year
:
y_p
=
y_pred_yearly_avg
[
var
].
sel
(
year
=
year
).
values
y_t
=
y_true_yearly_avg
[
var
].
sel
(
year
=
year
).
values
r
,
_
=
stats
.
pearsonr
(
y_p
[
~
np
.
isnan
(
y_p
)],
y_t
[
~
np
.
isnan
(
y_t
)])
correlations
.
append
(
r
)
print
(
'
Yearly average Pearson correlation coefficient for {}: {:.2f}
'
.
format
(
var
,
np
.
asarray
(
r
).
mean
()))
```
%% Cell type:code id:760f86ce-9e04-4938-b24f-d2819fbf622e tags:
```
python
# plot average of observation, prediction, and bias
vmin
,
vmax
=
(
-
15
,
15
)
if
PREDICTAND
==
'
tasmin
'
else
(
0
,
25
)
fig
,
axes
=
plt
.
subplots
(
len
(
y_pred_yearly_avg
.
data_vars
),
3
,
figsize
=
(
24
,
len
(
y_pred_yearly_avg
.
data_vars
)
*
6
),
sharex
=
True
,
sharey
=
True
)
axes
=
axes
.
reshape
(
len
(
y_pred_yearly_avg
.
data_vars
),
-
1
)
for
i
,
var
in
enumerate
(
y_pred_yearly_avg
):
for
ds
,
ax
in
zip
([
y_true_yearly_avg
,
y_pred_yearly_avg
,
bias_yearly_avg
],
axes
[
i
,
...]):
if
ds
is
bias_yearly_avg
:
ds
=
ds
[
var
].
mean
(
dim
=
'
year
'
)
im2
=
ax
.
imshow
(
ds
.
values
,
origin
=
'
lower
'
,
cmap
=
'
RdBu_r
'
,
vmin
=-
2
,
vmax
=
2
)
ax
.
text
(
x
=
ds
.
shape
[
0
]
-
2
,
y
=
2
,
s
=
'
Average: {:.2f}°C
'
.
format
(
ds
.
mean
().
item
()),
fontsize
=
14
,
ha
=
'
right
'
)
else
:
im1
=
ax
.
imshow
(
ds
[
var
].
mean
(
dim
=
'
year
'
).
values
,
origin
=
'
lower
'
,
cmap
=
'
RdYlBu_r
'
,
vmin
=
-
15
,
vmax
=
15
)
im1
=
ax
.
imshow
(
ds
[
var
].
mean
(
dim
=
'
year
'
).
values
,
origin
=
'
lower
'
,
cmap
=
'
RdYlBu_r
'
,
vmin
=
vmin
,
vmax
=
vmax
)
# set titles
axes
[
0
,
0
].
set_title
(
'
Observed
'
,
fontsize
=
16
);
axes
[
0
,
1
].
set_title
(
'
Predicted
'
,
fontsize
=
16
);
axes
[
0
,
2
].
set_title
(
'
Bias
'
,
fontsize
=
16
);
axes
[
0
,
0
].
set_title
(
'
Observed
'
,
fontsize
=
16
,
pad
=
10
);
axes
[
0
,
1
].
set_title
(
'
Predicted
'
,
fontsize
=
16
,
pad
=
10
);
axes
[
0
,
2
].
set_title
(
'
Bias
'
,
fontsize
=
16
,
pad
=
10
);
# adjust axes
for
ax
in
axes
.
flat
:
ax
.
axes
.
get_xaxis
().
set_ticklabels
([])
ax
.
axes
.
get_xaxis
().
set_ticks
([])
ax
.
axes
.
get_yaxis
().
set_ticklabels
([])
ax
.
axes
.
get_yaxis
().
set_ticks
([])
ax
.
axes
.
axis
(
'
tight
'
)
ax
.
set_xlabel
(
''
)
ax
.
set_ylabel
(
''
)
# adjust figure
fig
.
suptitle
(
'
Average {}: 1991 - 2010
'
.
format
(
NAMES
[
PREDICTAND
]),
fontsize
=
20
);
fig
.
subplots_adjust
(
hspace
=
0
,
wspace
=
0
,
top
=
0.85
)
# add colorbar for bias
axes
=
axes
.
flatten
()
cbar_ax
=
fig
.
add_axes
([
axes
[
-
1
].
get_position
().
x1
+
0.01
,
axes
[
-
1
].
get_position
().
y0
,
0.01
,
axes
[
-
1
].
get_position
().
y1
-
axes
[
-
1
].
get_position
().
y0
])
cbar
=
fig
.
colorbar
(
im2
,
cax
=
cbar_ax
)
cbar
.
set_label
(
label
=
'
Bias / (°C)
'
,
fontsize
=
16
)
cbar
.
ax
.
tick_params
(
labelsize
=
14
)
cbar_ax_bias
=
fig
.
add_axes
([
axes
[
-
1
].
get_position
().
x1
+
0.01
,
axes
[
-
1
].
get_position
().
y0
,
0.01
,
axes
[
-
1
].
get_position
().
y1
-
axes
[
-
1
].
get_position
().
y0
])
cbar_bias
=
fig
.
colorbar
(
im2
,
cax
=
cbar_ax_bias
)
cbar_bias
.
set_label
(
label
=
'
Bias / (°C)
'
,
fontsize
=
16
)
cbar_bias
.
ax
.
tick_params
(
labelsize
=
14
)
# add colorbar for predictand
cbar_ax_predictand
=
fig
.
add_axes
([
axes
[
0
].
get_position
().
x0
,
axes
[
0
].
get_position
().
y0
-
0.1
,
axes
[
-
1
].
get_position
().
x0
-
axes
[
0
].
get_position
().
x0
,
0.05
])
cbar_predictand
=
fig
.
colorbar
(
im1
,
cax
=
cbar_ax_predictand
,
orientation
=
'
horizontal
'
)
cbar_predictand
.
set_label
(
label
=
'
{} / (°C)
'
.
format
(
NAMES
[
PREDICTAND
].
capitalize
()),
fontsize
=
16
)
cbar_predictand
.
ax
.
tick_params
(
labelsize
=
14
)
# add metrics: MAE and RMSE
axes
[
1
].
text
(
x
=
ds
.
shape
[
0
]
-
2
,
y
=
2
,
s
=
'
MAE = {:.2f}°C
'
.
format
(
mae_avg
[
PREDICTAND
].
item
()),
fontsize
=
14
,
ha
=
'
right
'
)
axes
[
1
].
text
(
x
=
ds
.
shape
[
0
]
-
2
,
y
=
12
,
s
=
'
RMSE = {:.2f}°C$^2$
'
.
format
(
rmse_avg
[
PREDICTAND
].
item
()),
fontsize
=
14
,
ha
=
'
right
'
)
# save figure
fig
.
savefig
(
'
../Notebooks/Figures/{}_average_bias.png
'
.
format
(
PREDICTAND
),
dpi
=
300
,
bbox_inches
=
'
tight
'
)
```
%% Cell type:markdown id:aa4ef730-5a9d-40dc-a318-2f43a4cf1cd2 tags:
### Seasonal bias
%% Cell type:markdown id:eda455a2-e8ee-4644-bb85-b0cf76acd11a tags:
Calculate seasonal bias:
%% Cell type:code id:24aadff5-b19d-4f4b-a4c5-32ee656e64cd tags:
```
python
# group data by season: (DJF, MAM, JJA, SON)
y_true_snl
=
y_true
.
groupby
(
'
time.season
'
).
mean
(
dim
=
'
time
'
)
y_pred_snl
=
y_pred
.
groupby
(
'
time.season
'
).
mean
(
dim
=
'
time
'
)
bias_snl
=
y_pred_snl
-
y_true_snl
```
%% Cell type:code id:2d0d5a46-0652-4289-9e1c-6d47aafcfef0 tags:
```
python
# print average bias per season
for
var
in
bias_snl
.
data_vars
:
for
season
in
bias_snl
.
tasmin
.
season
:
for
season
in
bias_snl
[
PREDICTAND
]
.
season
:
print
(
'
Average bias of {} for season {}: {:.2f}
'
.
format
(
var
,
season
.
values
.
item
(),
bias_snl
[
var
].
sel
(
season
=
season
).
mean
().
item
()))
```
%% Cell type:markdown id:c4232ae9-a557-4d61-8ab3-d7eda6201f98 tags:
Plot seasonal differences, taken from the
[
xarray documentation
](
xarray.pydata.org/en/stable/examples/monthly-means.html
)
.
%% Cell type:code id:b39a6cc0-614c-452d-bb85-bc10e5179948 tags:
```
python
# plot seasonal differences
seasons
=
(
'
DJF
'
,
'
JJA
'
)
fig
,
axes
=
plt
.
subplots
(
nrows
=
1
,
ncols
=
len
(
seasons
)
+
1
,
figsize
=
(
24
,
8
),
sharex
=
True
,
sharey
=
True
)
axes
=
axes
.
flatten
()
# plot annual average bias
ds
=
bias_yearly_avg
[
PREDICTAND
].
mean
(
dim
=
'
year
'
)
axes
[
0
].
imshow
(
ds
.
values
,
origin
=
'
lower
'
,
cmap
=
'
RdBu_r
'
,
vmin
=-
2
,
vmax
=
2
)
axes
[
0
].
set_title
(
'
Annual
'
,
fontsize
=
16
);
axes
[
0
].
text
(
x
=
ds
.
shape
[
0
]
-
2
,
y
=
2
,
s
=
'
Average: {:.2f}°C
'
.
format
(
ds
.
mean
().
item
()),
fontsize
=
14
,
ha
=
'
right
'
)
# plot seasonal average bias
for
ax
,
season
in
zip
(
axes
[
1
:],
seasons
):
ds
=
bias_snl
[
PREDICTAND
].
sel
(
season
=
season
)
ax
.
imshow
(
ds
.
values
,
origin
=
'
lower
'
,
cmap
=
'
RdBu_r
'
,
vmin
=-
2
,
vmax
=
2
)
ax
.
set_title
(
season
,
fontsize
=
16
);
ax
.
text
(
x
=
ds
.
shape
[
0
]
-
2
,
y
=
2
,
s
=
'
Average: {:.2f}°C
'
.
format
(
ds
.
mean
().
item
()),
fontsize
=
14
,
ha
=
'
right
'
)
# adjust axes
for
ax
in
axes
.
flat
:
ax
.
axes
.
get_xaxis
().
set_ticklabels
([])
ax
.
axes
.
get_xaxis
().
set_ticks
([])
ax
.
axes
.
get_yaxis
().
set_ticklabels
([])
ax
.
axes
.
get_yaxis
().
set_ticks
([])
ax
.
axes
.
axis
(
'
tight
'
)
ax
.
set_xlabel
(
''
)
ax
.
set_ylabel
(
''
)
# adjust figure
fig
.
suptitle
(
'
Average bias of {}: 1991 - 2010
'
.
format
(
NAMES
[
PREDICTAND
]),
fontsize
=
20
);
fig
.
subplots_adjust
(
hspace
=
0
,
wspace
=
0
,
top
=
0.85
)
# add colorbar for bias
axes
=
axes
.
flatten
()
cbar_ax
=
fig
.
add_axes
([
axes
[
-
1
].
get_position
().
x1
+
0.01
,
axes
[
-
1
].
get_position
().
y0
,
0.01
,
axes
[
-
1
].
get_position
().
y1
-
axes
[
-
1
].
get_position
().
y0
])
cbar
=
fig
.
colorbar
(
im2
,
cax
=
cbar_ax
)
cbar
.
set_label
(
label
=
'
Bias / (°C)
'
,
fontsize
=
16
)
cbar
.
ax
.
tick_params
(
labelsize
=
14
)
# save figure
fig
.
savefig
(
'
../Notebooks/Figures/{}_average_bias_seasonal.png
'
.
format
(
PREDICTAND
),
dpi
=
300
,
bbox_inches
=
'
tight
'
)
```
%% Cell type:markdown id:
c70b3
69
d
-2
d16-42e3-9300-4a18757ad1b2
tags:
%% Cell type:markdown id:
416002
69-2
f8c-4717-8f74-b3dfaef60359
tags:
### Bias of extreme values
Calculate the mean annual cycle:
%% Cell type:code id:
4acfc3f2-20ed-498c-ab35-f392ae0e64f9
tags:
%% Cell type:code id:
c9f27c01-4dfc-4d16-8d29-00e69b7794cd
tags:
```
python
# TODO: smooth quantiles
# group timeseries by month and calculate mean over time and space
y_pred_ac
=
y_pred
.
groupby
(
'
time.month
'
).
mean
(
dim
=
(
'
time
'
,
'
y
'
,
'
x
'
))
y_true_ac
=
y_true
.
groupby
(
'
time.month
'
).
mean
(
dim
=
(
'
time
'
,
'
y
'
,
'
x
'
))
```
%% Cell type:code id:bcc63e73-2636-49c1-a66b-241eb5407e2e tags:
```
python
# plot mean annual cycle
fig
,
ax
=
plt
.
subplots
(
1
,
1
,
figsize
=
(
10
,
10
))
ax
.
plot
(
y_pred_ac
[
PREDICTAND
].
values
,
ls
=
'
--
'
,
color
=
'
k
'
,
label
=
'
Predicted
'
)
ax
.
plot
(
y_true_ac
[
PREDICTAND
].
values
,
ls
=
'
-
'
,
color
=
'
k
'
,
label
=
'
Observed
'
)
ax
.
legend
(
frameon
=
False
,
fontsize
=
14
);
ax
.
set_yticks
(
np
.
arange
(
np
.
floor
(
y_true_ac
[
PREDICTAND
].
min
().
item
()),
np
.
ceil
(
y_true_ac
[
PREDICTAND
].
max
().
item
())
+
1
,
1
))
ax
.
set_yticklabels
(
np
.
arange
(
np
.
floor
(
y_true_ac
[
PREDICTAND
].
min
().
item
()),
np
.
ceil
(
y_true_ac
[
PREDICTAND
].
max
().
item
())
+
1
,
1
),
fontsize
=
12
)
ax
.
set_xticks
(
np
.
arange
(
0
,
12
))
ax
.
set_xticklabels
([
calendar
.
month_name
[
i
+
1
]
for
i
in
np
.
arange
(
0
,
12
)],
rotation
=
90
,
fontsize
=
12
)
ax
.
set_title
(
'
Mean annual cycle of {}: 1991 - 2010
'
.
format
(
NAMES
[
PREDICTAND
]),
pad
=
20
,
fontsize
=
16
);
ax
.
set_ylabel
(
'
{} / (°C)
'
.
format
(
NAMES
[
PREDICTAND
].
capitalize
()),
fontsize
=
14
)
ax
.
set_xlabel
(
'
Month
'
,
fontsize
=
14
);
# save figure
fig
.
savefig
(
'
../Notebooks/Figures/{}_mean_annual_cycle.png
'
.
format
(
PREDICTAND
),
dpi
=
300
,
bbox_inches
=
'
tight
'
)
```
%% Cell type:markdown id:c70b369d-2d16-42e3-9300-4a18757ad1b2 tags:
### Bias of extreme values
%% Cell type:code id:33198267-8eb6-4b84-bb6d-c903b27ccbc5 tags:
```
python
#
perce
ntile
s
of interest
perce
ntile
s
=
[
0.0
1
,
0.02
,
0.98
,
0.99
]
#
extreme qua
ntile of interest
qua
ntile
=
0.0
2
if
PREDICTAND
==
'
tasmin
'
else
0.98
```
%% Cell type:code id:
51137d23-a380-4d48-a005-fd1edaf554eb
tags:
%% Cell type:code id:
c3da76d3-7261-4084-b5ec-f65682fd6596
tags:
```
python
# calculate
percentiles over reference period
# calculate
extreme quantile for each year
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
'
ignore
'
)
y_pred_dist
=
y_pred
.
quantile
(
q
=
percentiles
,
dim
=
'
time
'
)
y_true_dist
=
y_true
.
quantile
(
q
=
percentiles
,
dim
=
'
time
'
)
warnings
.
simplefilter
(
'
ignore
'
,
category
=
RuntimeWarning
)
y_pred_ex
=
y_pred
.
groupby
(
'
time.year
'
).
quantile
(
quantile
,
dim
=
'
time
'
)
y_true_ex
=
y_true
.
groupby
(
'
time.year
'
).
quantile
(
quantile
,
dim
=
'
time
'
)
```
%% Cell type:code id:20db96c8-e04f-4acb-886b-abb740863fbb tags:
```
python
# calculate bias in extreme quantile for each year
bias_ex
=
y_pred_ex
-
y_true_ex
for
var
in
bias_ex
:
print
(
'
Yearly average bias for P{:.0f} of {}: {:.2f}
'
.
format
(
quantile
*
100
,
var
,
bias_ex
[
var
].
mean
().
item
()))
```
%% Cell type:code id:
b461c774-c5fc-4a50-8609-34a7e7674a34
tags:
%% Cell type:code id:
c7aeff55-deef-4d3c-a251-eac877c9afd9
tags:
```
python
# calculate bias in each percentile over entire reference period
bias_dist
=
y_pred_dist
-
y_true_dist
# mean absolute error in extreme quantile
mae_ex
=
np
.
abs
(
y_pred_ex
-
y_true_ex
).
mean
()
for
var
in
mae_ex
:
print
(
'
Yearly average MAE for P{:.0f} of {}: {:.2f}
'
.
format
(
quantile
*
100
,
var
,
mae_avg
[
var
].
item
()))
```
%% Cell type:code id:
bd2ec314-87ed-407f-af8b-5e2c6785e9cc
tags:
%% Cell type:code id:
2dc6b7cf-7ae1-4b0e-8483-376eab59f5dd
tags:
```
python
# calculate correlation coefficient for extreme values
for
var
in
y_pred_dist
:
for
q
in
percentiles
:
y_p
=
y_pred_dist
[
var
].
sel
(
quantile
=
q
).
values
[
~
np
.
isnan
(
y_pred_dist
[
var
].
sel
(
quantile
=
q
))]
y_t
=
y_true_dist
[
var
].
sel
(
quantile
=
q
).
values
[
~
np
.
isnan
(
y_true_dist
[
var
].
sel
(
quantile
=
q
))]
r
,
_
=
stats
.
pearsonr
(
y_p
,
y_t
)
print
(
'
Pearson correlation for {}, q={:.2f}: R={:.2f}
'
.
format
(
var
,
q
,
r
))
# root mean squared error over reference period
rmse_ex
=
((
y_pred_ex
-
y_true_ex
)
**
2
).
mean
()
for
var
in
rmse_ex
:
print
(
'
Yearly average RMSE for P{:.0f} of {}: {:.2f}
'
.
format
(
quantile
*
100
,
var
,
rmse_ex
[
var
].
item
()))
```
%% Cell type:code id:
04a1b4fc-8fc0-4e1e-adbe-bf1eaafeac5d
tags:
%% Cell type:code id:
cfc196fa-0999-4603-959c-2c82f038c8fa
tags:
```
python
# plot bias in each percentile
fig
,
axes
=
plt
.
subplots
(
len
(
y_pred_dist
.
data_vars
),
len
(
percentiles
),
sharex
=
True
,
sharey
=
True
,
figsize
=
(
32
,
6
))
axes
=
axes
.
reshape
(
len
(
y_pred_dist
.
data_vars
),
-
1
)
for
ax
,
var
in
zip
(
axes
,
y_pred_dist
):
# iterate over percentiles
for
axis
,
q
in
zip
(
ax
,
percentiles
):
ds
=
bias_dist
.
sel
(
quantile
=
q
).
to_array
()
ds
.
plot
(
ax
=
axis
,
vmin
=-
2
,
vmax
=
2
,
cmap
=
'
RdBu_r
'
)
axis
.
text
(
x
=
bias_dist
.
x
[
-
1
],
y
=
bias_dist
.
y
[
0
],
s
=
'
Avg: {:.2f}
'
.
format
(
ds
.
mean
().
item
()),
ha
=
'
right
'
,
va
=
'
bottom
'
)
# plot extremes of observation, prediction, and bias
vmin
,
vmax
=
(
-
20
,
0
)
if
PREDICTAND
==
'
tasmin
'
else
(
10
,
40
)
fig
,
axes
=
plt
.
subplots
(
len
(
y_pred_ex
.
data_vars
),
3
,
figsize
=
(
24
,
len
(
y_pred_ex
.
data_vars
)
*
6
),
sharex
=
True
,
sharey
=
True
)
axes
=
axes
.
reshape
(
len
(
y_pred_ex
.
data_vars
),
-
1
)
for
i
,
var
in
enumerate
(
y_pred_ex
):
for
ds
,
ax
in
zip
([
y_true_ex
,
y_pred_ex
,
bias_ex
],
axes
[
i
,
...]):
if
ds
is
bias_ex
:
ds
=
ds
[
var
].
mean
(
dim
=
'
year
'
)
im2
=
ax
.
imshow
(
ds
.
values
,
origin
=
'
lower
'
,
cmap
=
'
RdBu_r
'
,
vmin
=-
2
,
vmax
=
2
)
ax
.
text
(
x
=
ds
.
shape
[
0
]
-
2
,
y
=
2
,
s
=
'
Average: {:.2f}°C
'
.
format
(
ds
.
mean
().
item
()),
fontsize
=
14
,
ha
=
'
right
'
)
else
:
im1
=
ax
.
imshow
(
ds
[
var
].
mean
(
dim
=
'
year
'
).
values
,
origin
=
'
lower
'
,
cmap
=
'
Blues_r
'
if
PREDICTAND
==
'
tasmin
'
else
'
Reds
'
,
vmin
=
vmin
,
vmax
=
vmax
)
# set titles
axes
[
0
,
0
].
set_title
(
'
Observed
'
,
fontsize
=
16
,
pad
=
10
);
axes
[
0
,
1
].
set_title
(
'
Predicted
'
,
fontsize
=
16
,
pad
=
10
);
axes
[
0
,
2
].
set_title
(
'
Bias
'
,
fontsize
=
16
,
pad
=
10
);
# adjust axes
for
ax
in
axes
.
flat
:
ax
.
axes
.
get_xaxis
().
set_ticklabels
([])
ax
.
axes
.
get_xaxis
().
set_ticks
([])
ax
.
axes
.
get_yaxis
().
set_ticklabels
([])
ax
.
axes
.
get_yaxis
().
set_ticks
([])
ax
.
axes
.
axis
(
'
tight
'
)
ax
.
set_xlabel
(
''
)
ax
.
set_ylabel
(
''
)
# adjust figure
fig
.
suptitle
(
'
Average P{:.0f} of {}: 1991 - 2010
'
.
format
(
quantile
*
100
,
NAMES
[
PREDICTAND
]),
fontsize
=
20
);
fig
.
subplots_adjust
(
hspace
=
0
,
wspace
=
0
,
top
=
0.85
)
# add colorbar for bias
axes
=
axes
.
flatten
()
cbar_ax
=
fig
.
add_axes
([
axes
[
-
1
].
get_position
().
x1
+
0.01
,
axes
[
-
1
].
get_position
().
y0
,
0.01
,
axes
[
-
1
].
get_position
().
y1
-
axes
[
-
1
].
get_position
().
y0
])
cbar
=
fig
.
colorbar
(
im2
,
cax
=
cbar_ax
)
cbar
.
set_label
(
label
=
'
Bias / (°C)
'
,
fontsize
=
16
)
cbar
.
ax
.
tick_params
(
labelsize
=
14
)
# add colorbar for predictand
cbar_ax_predictand
=
fig
.
add_axes
([
axes
[
0
].
get_position
().
x0
,
axes
[
0
].
get_position
().
y0
-
0.1
,
axes
[
-
1
].
get_position
().
x0
-
axes
[
0
].
get_position
().
x0
,
0.05
])
cbar_predictand
=
fig
.
colorbar
(
im1
,
cax
=
cbar_ax_predictand
,
orientation
=
'
horizontal
'
)
cbar_predictand
.
set_label
(
label
=
'
{} / (°C)
'
.
format
(
NAMES
[
PREDICTAND
].
capitalize
()),
fontsize
=
16
)
cbar_predictand
.
ax
.
tick_params
(
labelsize
=
14
)
# add metrics: MAE and RMSE
axes
[
1
].
text
(
x
=
ds
.
shape
[
0
]
-
2
,
y
=
2
,
s
=
'
MAE = {:.2f}°C
'
.
format
(
mae_ex
[
PREDICTAND
].
item
()),
fontsize
=
14
,
ha
=
'
right
'
)
axes
[
1
].
text
(
x
=
ds
.
shape
[
0
]
-
2
,
y
=
12
,
s
=
'
RMSE = {:.2f}°C$^2$
'
.
format
(
rmse_ex
[
PREDICTAND
].
item
()),
fontsize
=
14
,
ha
=
'
right
'
)
# save figure
fig
.
savefig
(
'
../Notebooks/Figures/{}_average_bias_p{:.0f}.png
'
.
format
(
PREDICTAND
,
quantile
*
100
),
dpi
=
300
,
bbox_inches
=
'
tight
'
)
```
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment