Commit 1d0f412e authored by Frisinghelli Daniel

Refactor.

parent 432286c1
%% Cell type:markdown id:4735431f-6741-437e-b0fc-dd6d8eaa22ca tags: %% Cell type:markdown id:4735431f-6741-437e-b0fc-dd6d8eaa22ca tags:
# Evaluate ERA-5 downscaling: precipitation # Evaluate ERA-5 downscaling: precipitation
%% Cell type:markdown id:a87da113-4b0f-4ac8-9721-19c85848acec tags: %% Cell type:markdown id:a87da113-4b0f-4ac8-9721-19c85848acec tags:
We used **1981-1991 as the training** period and **1991-2010 as the reference** period. The results shown in this notebook are based on the model predictions for the reference period. We used **1981-1991 as the training** period and **1991-2010 as the reference** period. The results shown in this notebook are based on the model predictions for the reference period.
%% Cell type:markdown id:ad1769d4-9c0c-4e3f-9adf-ef02bb43c047 tags: %% Cell type:markdown id:ad1769d4-9c0c-4e3f-9adf-ef02bb43c047 tags:
**Predictors on pressure levels (500, 850)**: **Predictors on pressure levels (500, 850)**:
- Geopotential (z) - Geopotential (z)
- Temperature (t) - Temperature (t)
- Zonal wind (u) - Zonal wind (u)
- Meridional wind (v) - Meridional wind (v)
- Specific humidity (q) - Specific humidity (q)
**Predictors on surface**: **Predictors on surface**:
- Mean sea level pressure (msl) - Mean sea level pressure (msl)
**Auxiliary predictors**: **Auxiliary predictors**:
- Elevation from Copernicus EU-DEM v1.1 (dem) - Elevation from Copernicus EU-DEM v1.1 (dem)
- Day of the year (doy) - Day of the year (doy)
%% Cell type:markdown id:f9334da7-17d1-45ef-9ed9-5c2bee9fcdcc tags: %% Cell type:markdown id:f9334da7-17d1-45ef-9ed9-5c2bee9fcdcc tags:
Define the predictand and the model to evaluate: Define the predictand and the model to evaluate:
%% Cell type:code id:a81acde7-16a2-4087-bc08-95b084adbd06 tags: %% Cell type:code id:a81acde7-16a2-4087-bc08-95b084adbd06 tags:
``` python ``` python
# define the model parameters # define the model parameters
PREDICTAND = 'pr' PREDICTAND = 'pr'
MODEL = 'USegNet' MODEL = 'USegNet'
PPREDICTORS = 'ztuvq' PPREDICTORS = 'ztuvq'
# PPREDICTORS = '' # PPREDICTORS = ''
PLEVELS = ['500', '850'] PLEVELS = ['500', '850']
# PLEVELS = [] # PLEVELS = []
SPREDICTORS = 'p' SPREDICTORS = 'p'
DEM = '' DEM = 'dem'
DEM_FEATURES = 'dem' DEM_FEATURES = ''
DOY = '' DOY = ''
WET_DAY_THRESHOLD = '2' WET_DAY_THRESHOLD = '1'
# LOSS = 'MSELoss' # LOSS = 'MSELoss'
LOSS = 'BernoulliGammaLoss' LOSS = 'BernoulliGammaLoss'
SEASON = 'season'
``` ```
%% Cell type:markdown id:dd188df0-69ee-44b0-82b2-d212994dc271 tags: %% Cell type:markdown id:dd188df0-69ee-44b0-82b2-d212994dc271 tags:
### Imports ### Imports
%% Cell type:code id:06792bf2-b33b-4728-ba60-d60fab46779d tags: %% Cell type:code id:06792bf2-b33b-4728-ba60-d60fab46779d tags:
``` python ``` python
# builtins # builtins
import datetime import datetime
import warnings import warnings
import calendar import calendar
# externals # externals
import xarray as xr import xarray as xr
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import scipy.stats as stats import scipy.stats as stats
from IPython.display import Image from IPython.display import Image
from sklearn.metrics import r2_score, roc_curve, auc, classification_report from sklearn.metrics import r2_score, roc_curve, auc, classification_report
# locals # locals
from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH
from pysegcnn.core.utils import search_files from pysegcnn.core.utils import search_files
from pysegcnn.core.graphics import plot_classification_report from pysegcnn.core.graphics import plot_classification_report
``` ```
%% Cell type:code id:cd14134e-f9be-4935-877b-ef6d34e03d2e tags: %% Cell type:code id:cd14134e-f9be-4935-877b-ef6d34e03d2e tags:
``` python ``` python
# mapping from predictands to variable names # mapping from predictands to variable names
NAMES = {'tasmin': 'minimum temperature', 'tasmax': 'maximum temperature', 'pr': 'precipitation'} NAMES = {'tasmin': 'minimum temperature', 'tasmax': 'maximum temperature', 'pr': 'precipitation'}
``` ```
%% Cell type:markdown id:a3d92474-2bc1-4035-8938-c5fbb07ae891 tags: %% Cell type:markdown id:a3d92474-2bc1-4035-8938-c5fbb07ae891 tags:
### Model architecture ### Model architecture
%% Cell type:code id:c68e022b-41c2-438b-bfb0-e27eddee89bf tags: %% Cell type:code id:c68e022b-41c2-438b-bfb0-e27eddee89bf tags:
``` python ``` python
Image("./Figures/architecture.png", width=900, height=400) Image("./Figures/architecture.png", width=900, height=400)
``` ```
%% Cell type:markdown id:c8833efe-c715-4872-9aee-a0b5766f5c67 tags: %% Cell type:markdown id:c8833efe-c715-4872-9aee-a0b5766f5c67 tags:
### Loss function ### Loss function
%% Cell type:markdown id:bb801367-6872-4a0b-bff5-70cb6746e057 tags: %% Cell type:markdown id:bb801367-6872-4a0b-bff5-70cb6746e057 tags:
For precipitation, the network minimizes the negative log-likelihood of a Bernoulli-Gamma distribution, following [Cannon (2008)](http://journals.ametsoc.org/doi/10.1175/2008JHM960.1). For precipitation, the network minimizes the negative log-likelihood of a Bernoulli-Gamma distribution, following [Cannon (2008)](http://journals.ametsoc.org/doi/10.1175/2008JHM960.1).
%% Cell type:markdown id:a8775d5e-5ad4-47e4-8fef-6dc230e15dee tags: %% Cell type:markdown id:a8775d5e-5ad4-47e4-8fef-6dc230e15dee tags:
Bernoulli-Gamma distribution: Bernoulli-Gamma distribution:
%% Cell type:markdown id:ab10f8de-d8d2-4427-b9c8-5d68803543c3 tags: %% Cell type:markdown id:ab10f8de-d8d2-4427-b9c8-5d68803543c3 tags:
$$P(y \mid p, \alpha, \beta) = \begin{cases} 1 - p, & \text{for } y = 0\\ p \cdot \frac{y^{\alpha -1} \exp(-y/\beta)}{\beta^{\alpha} \Gamma(\alpha)}, & \text{for } y > 0\end{cases}$$ $$P(y \mid p, \alpha, \beta) = \begin{cases} 1 - p, & \text{for } y = 0\\ p \cdot \frac{y^{\alpha -1} \exp(-y/\beta)}{\beta^{\alpha} \Gamma(\alpha)}, & \text{for } y > 0\end{cases}$$
%% Cell type:markdown id:6b6dbd06-1f0e-4c52-84f2-b7ff31c75726 tags: %% Cell type:markdown id:6b6dbd06-1f0e-4c52-84f2-b7ff31c75726 tags:
Log-likelihood function: Log-likelihood function:
%% Cell type:markdown id:e41c7b39-f352-4a98-820f-9a7345b3283c tags: %% Cell type:markdown id:e41c7b39-f352-4a98-820f-9a7345b3283c tags:
$$\mathcal{J}(p, \alpha, \beta \mid y) = \underbrace{(1 - P(y > 0)) \log(1 - p)}_{\text{Bernoulli}} + \underbrace{P(y > 0) \cdot \left(\log(p) + (\alpha - 1) \log(y) - \frac{y}{\beta} - \alpha \log(\beta) - \log(\Gamma(\alpha))\right)}_{\text{Gamma}}$$ $$\mathcal{J}(p, \alpha, \beta \mid y) = \underbrace{(1 - P(y > 0)) \log(1 - p)}_{\text{Bernoulli}} + \underbrace{P(y > 0) \cdot \left(\log(p) + (\alpha - 1) \log(y) - \frac{y}{\beta} - \alpha \log(\beta) - \log(\Gamma(\alpha))\right)}_{\text{Gamma}}$$
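The log-likelihood above can be sketched for a single observation in plain Python. This is an illustration only, not the `BernoulliGammaLoss` used to train the network: the function name `bernoulli_gamma_nll`, the stabilising constant `eps` and the example values are assumptions, and the indicator for a wet day is taken to be `y > 0`. The logarithm of the gamma function in the last term is evaluated with `math.lgamma`.

```python
import math


def bernoulli_gamma_nll(y, p, alpha, beta, eps=1e-12):
    """Negative log-likelihood of one observation y under a Bernoulli-Gamma
    distribution with rain probability p, shape alpha and scale beta.

    Hypothetical helper for illustration; eps only guards log(0).
    """
    if y > 0:
        # wet day: log(p) plus the log-density of the Gamma distribution
        log_lik = (math.log(p + eps)
                   + (alpha - 1.0) * math.log(y)
                   - y / beta
                   - alpha * math.log(beta)
                   - math.lgamma(alpha))
    else:
        # dry day: only the Bernoulli term contributes
        log_lik = math.log(1.0 - p + eps)
    return -log_lik


# dry day with a predicted rain probability of 20 %
nll_dry = bernoulli_gamma_nll(0.0, p=0.2, alpha=1.0, beta=5.0)

# wet day with 3 mm; for alpha = 1 the Gamma part is an exponential density
nll_wet = bernoulli_gamma_nll(3.0, p=0.8, alpha=1.0, beta=5.0)
```

In a training setting the same expression would be evaluated per pixel and averaged over the batch, with `p`, `alpha` and `beta` being network outputs.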
%% Cell type:markdown id:5a0c55f0-79fb-4501-b3cf-b5414399a3d9 tags: %% Cell type:markdown id:5a0c55f0-79fb-4501-b3cf-b5414399a3d9 tags:
### Load datasets ### Load datasets
%% Cell type:code id:efa76f8e-c089-47ff-a001-d4c2a11c4d6d tags: %% Cell type:code id:efa76f8e-c089-47ff-a001-d4c2a11c4d6d tags:
``` python ``` python
# construct file pattern to match # construct file pattern to match
PATTERN = '_'.join([MODEL, PREDICTAND, PPREDICTORS, *PLEVELS]) PATTERN = '_'.join([MODEL, PREDICTAND, PPREDICTORS, *PLEVELS])
PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN
PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN
PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN
PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN
PATTERN = '_'.join([PATTERN, '{}mm'.format(str(WET_DAY_THRESHOLD).replace('.', ''))]) if WET_DAY_THRESHOLD else PATTERN PATTERN = '_'.join([PATTERN, '{}mm'.format(str(WET_DAY_THRESHOLD).replace('.', ''))]) if WET_DAY_THRESHOLD else PATTERN
PATTERN = '_'.join([PATTERN, LOSS]) PATTERN = '_'.join([PATTERN, LOSS])
PATTERN = '_'.join([PATTERN, SEASON]) if SEASON else PATTERN
PATTERN PATTERN
``` ```
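The chain of conditional joins in the cell above can also be written as one helper. The sketch below is an assumption, not part of the commit: the name `build_pattern` is hypothetical, and unlike the cell it skips every empty component, including an empty `PPREDICTORS`, `PLEVELS` or `LOSS`. For the parameter values in the right-hand column of this diff both variants give the same string.

```python
def build_pattern(model, predictand, ppredictors, plevels, spredictors='',
                  dem='', dem_features='', doy='', wet_day_threshold='',
                  loss='', season=''):
    """Join the non-empty model parameters into a file name pattern."""
    parts = [model, predictand, ppredictors, *plevels,
             spredictors, dem, dem_features, doy]
    if wet_day_threshold:
        # '0.5' becomes '05mm', '1' becomes '1mm'
        parts.append('{}mm'.format(str(wet_day_threshold).replace('.', '')))
    parts.extend([loss, season])
    # empty strings are dropped so no double underscores appear
    return '_'.join(part for part in parts if part)


pattern = build_pattern('USegNet', 'pr', 'ztuvq', ['500', '850'],
                        spredictors='p', dem='dem', dem_features='', doy='',
                        wet_day_threshold='1', loss='BernoulliGammaLoss',
                        season='season')
print(pattern)
```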
%% Cell type:code id:ecaba394-f802-481f-8274-c44e2f5fdf1a tags: %% Cell type:code id:ecaba394-f802-481f-8274-c44e2f5fdf1a tags:
``` python ``` python
# digital elevation model # digital elevation model
dem = xr.open_dataset(search_files(DEM_PATH, 'eu_dem_v11_stt.nc').pop()) dem = xr.open_dataset(search_files(DEM_PATH, 'eu_dem_v11_stt.nc').pop())
dem = dem.Band1.to_dataset().rename({'Band1': 'elevation'}) dem = dem.Band1.to_dataset().rename({'Band1': 'elevation'})
``` ```
%% Cell type:code id:a5db133d-2c36-4e84-879e-20e617e821f1 tags: %% Cell type:code id:a5db133d-2c36-4e84-879e-20e617e821f1 tags:
``` python ``` python
# model predictions and observations NetCDF # model predictions and observations NetCDF
y_pred = xr.open_dataset(search_files(TARGET_PATH.joinpath(PREDICTAND), '.'.join([PATTERN, 'nc$'])).pop()) y_pred = xr.open_dataset(search_files(TARGET_PATH.joinpath(PREDICTAND), '.'.join([PATTERN, 'nc$'])).pop())
y_true = xr.open_dataset(search_files(OBS_PATH.joinpath(PREDICTAND), 'OBS_pr(.*).nc$').pop()) y_true = xr.open_dataset(search_files(OBS_PATH.joinpath(PREDICTAND), 'OBS_pr(.*).nc$').pop())
try: try:
y_pred = y_pred.rename({'pr': 'precipitation'}) y_pred = y_pred.rename({'pr': 'precipitation'})
except ValueError: except ValueError:
pass pass
``` ```
%% Cell type:code id:325f3086-85f0-4c28-b37d-7370d2d92405 tags: %% Cell type:code id:325f3086-85f0-4c28-b37d-7370d2d92405 tags:
``` python ``` python
# reference dataset: ERA-5 precipitation # reference dataset: ERA-5 precipitation
y_refe = xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', 'total_precipitation'), '.nc$').pop()) y_refe = xr.open_dataset(search_files(ERA5_PATH.joinpath('ERA5', 'total_precipitation'), '.nc$').pop())
``` ```
%% Cell type:code id:528c3116-7707-45ca-b811-0adad7bc20f3 tags: %% Cell type:code id:528c3116-7707-45ca-b811-0adad7bc20f3 tags:
``` python ``` python
# subset to time period covered by predictions # subset to time period covered by predictions
y_true = y_true.sel(time=y_pred.time) y_true = y_true.sel(time=y_pred.time)
y_refe = y_refe.sel(time=y_pred.time).drop_vars('lambert_azimuthal_equal_area') y_refe = y_refe.sel(time=y_pred.time).drop_vars('lambert_azimuthal_equal_area')
y_refe = y_refe.rename({'tp': 'precipitation'}) y_refe = y_refe.rename({'tp': 'precipitation'})
``` ```
%% Cell type:code id:70b903cb-e597-45d3-b575-0ebaf7a45649 tags: %% Cell type:code id:70b903cb-e597-45d3-b575-0ebaf7a45649 tags:
``` python ``` python
# align datasets and mask missing values in model predictions # align datasets
if LOSS == 'BernoulliGammaLoss': if len(y_pred.data_vars) > 1:
y_true, y_refe, y_pred_pr, y_pred_prob = xr.align(y_true.precipitation, y_refe.precipitation, y_pred.precipitation, y_pred.prob, join='override') y_true, y_refe, y_pred_pr, y_pred_prob = xr.align(y_true.precipitation, y_refe.precipitation, y_pred.precipitation, y_pred.prob, join='override')
y_pred_prob = y_pred_prob.where(~np.isnan(y_true), other=np.nan)
else: else:
y_true, y_refe, y_pred_pr = xr.align(y_true.precipitation, y_refe.precipitation, y_pred.precipitation, join='override') y_true, y_refe, y_pred_pr = xr.align(y_true.precipitation, y_refe.precipitation, y_pred.precipitation, join='override')
y_pred_pr = y_pred_pr.where(~np.isnan(y_true), other=np.nan) ```
y_refe = y_refe.where(~np.isnan(y_true), other=np.nan)
%% Cell type:code id:aa3f8e67-6b49-46d4-a956-a0cee7b3923a tags:
``` python
# mask missing values
mask = ~np.isnan(y_true)
y_pred_pr = y_pred_pr.where(mask, other=np.nan)
y_refe = y_refe.where(mask, other=np.nan)
if len(y_pred.data_vars) > 1:
y_pred_prob = y_pred_prob.where(mask, other=np.nan)
``` ```
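The masking step keeps a predicted value only where the observation is valid. A minimal NumPy sketch of the same idea, with made-up 2 x 2 arrays standing in for the gridded datasets, could look as follows. `np.where(mask, values, np.nan)` plays the role of the `.where(mask, other=np.nan)` call in the cell above.

```python
import numpy as np

# made-up observations with one missing pixel and matching predictions
y_true = np.array([[0.0, np.nan],
                   [2.5, 1.0]])
y_pred = np.array([[0.3, 4.0],
                   [2.0, 0.0]])

# True where an observation exists
mask = ~np.isnan(y_true)

# keep predictions on valid pixels, set all others to NaN
y_pred_masked = np.where(mask, y_pred, np.nan)
```

Computing the mask once and reusing it for every dataset ensures that all later statistics are evaluated on the same set of pixels.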
%% Cell type:code id:4e71997f-7808-463a-8b45-dcac639ebe88 tags: %% Cell type:code id:4e71997f-7808-463a-8b45-dcac639ebe88 tags:
``` python ``` python
# align digital elevation model # align digital elevation model
_, dem = xr.align(y_true.isel(time=0), dem, join='override') _, dem = xr.align(y_true.isel(time=0), dem, join='override')
dem = dem.where(~np.isnan(y_true.isel(time=0)), other=np.nan) dem = dem.where(~np.isnan(y_true.isel(time=0)), other=np.nan)
``` ```
%% Cell type:markdown id:b269a131-cf5b-4c6c-9f8e-a5408250aa83 tags: %% Cell type:markdown id:b269a131-cf5b-4c6c-9f8e-a5408250aa83 tags:
## Model validation: precipitation amount ## Model validation: precipitation amount
%% Cell type:markdown id:0fa1fe82-0d6e-4676-b5b8-9eac2fd28ffb tags: %% Cell type:markdown id:0fa1fe82-0d6e-4676-b5b8-9eac2fd28ffb tags:
### Coefficient of determination: monthly mean ### Coefficient of determination: monthly mean
%% Cell type:code id:4d11cee6-1ebd-4424-8ec6-92e5a196bac4 tags: %% Cell type:code id:4d11cee6-1ebd-4424-8ec6-92e5a196bac4 tags:
``` python ``` python
# calculate monthly mean precipitation (mm / month) # calculate monthly mean precipitation (mm / month)
y_pred_values = y_pred_pr.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').values y_pred_values = y_pred_pr.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').values
y_true_values = y_true.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').values y_true_values = y_true.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').values
``` ```
%% Cell type:code id:6e55cdd5-8adb-40c5-9742-46f4fc3d4be9 tags: %% Cell type:code id:6e55cdd5-8adb-40c5-9742-46f4fc3d4be9 tags:
``` python ``` python
# apply mask of valid pixels # apply mask of valid pixels
mask = (~np.isnan(y_pred_values) & ~np.isnan(y_true_values)) mask = (~np.isnan(y_pred_values) & ~np.isnan(y_true_values))
y_pred_values = y_pred_values[mask] y_pred_values = y_pred_values[mask]
y_true_values = y_true_values[mask] y_true_values = y_true_values[mask]
``` ```
%% Cell type:code id:13a9ff21-34ea-4db1-9c0a-bcb7ea1001f7 tags: %% Cell type:code id:13a9ff21-34ea-4db1-9c0a-bcb7ea1001f7 tags:
``` python ``` python
# calculate coefficient of determination # calculate coefficient of determination
r2 = r2_score(y_true_values, y_pred_values) r2 = r2_score(y_true_values, y_pred_values)
r2 r2
``` ```
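For reference, the coefficient of determination is one minus the ratio of the residual sum of squares to the total sum of squares, which is the definition `r2_score` follows. The sketch below recomputes it with NumPy on four made-up monthly values; the function name `r2_manual` and the numbers are assumptions for illustration.

```python
import numpy as np


def r2_manual(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)          # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # total sum of squares
    return 1.0 - ss_res / ss_tot


observed = [80.0, 100.0, 120.0, 140.0]    # made-up monthly sums (mm / month)
predicted = [90.0, 100.0, 110.0, 150.0]
r2_example = r2_manual(observed, predicted)
```

Note that the score is not symmetric: swapping observations and predictions changes `SS_tot` and therefore the result.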
%% Cell type:code id:a1b33431-9f2b-4e42-bbb0-f4fa258edd98 tags: %% Cell type:code id:a1b33431-9f2b-4e42-bbb0-f4fa258edd98 tags:
``` python ``` python
# group timeseries by month and calculate mean over time and space # group timeseries by month and calculate mean over time and space
y_pred_ac = y_pred_pr.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True).values.squeeze() y_pred_ac = y_pred_pr.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True).values.squeeze()
y_true_ac = y_true.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True).values.squeeze() y_true_ac = y_true.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True).values.squeeze()
``` ```
%% Cell type:code id:c683858e-8b7f-4c76-a40f-6f68397b3479 tags: %% Cell type:code id:c683858e-8b7f-4c76-a40f-6f68397b3479 tags:
``` python ``` python
# scatter plot of observations vs. predictions # scatter plot of observations vs. predictions
fig, ax = plt.subplots(1, 1, figsize=(10, 10)) fig, ax = plt.subplots(1, 1, figsize=(10, 10))
# plot only a subset of data: otherwise plot is overloaded ... # plot only a subset of data: otherwise plot is overloaded ...
# subset = np.random.choice(np.arange(0, len(y_pred_values)), size=int(1e3), replace=False) # subset = np.random.choice(np.arange(0, len(y_pred_values)), size=int(1e3), replace=False)
# ax.plot(y_true_values[subset], y_pred_values[subset], 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3); # ax.plot(y_true_values[subset], y_pred_values[subset], 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);
# plot entire dataset # plot entire dataset
ax.plot(y_true_values, y_pred_values, 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3); ax.plot(y_true_values, y_pred_values, 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);
# plot 1:1 mapping line # plot 1:1 mapping line
interval = np.arange(0, 300, 50) interval = np.arange(0, 300, 50)
ax.plot(interval, interval, color='k', lw=2, ls='--') ax.plot(interval, interval, color='k', lw=2, ls='--')
# add coefficient of determination: calculated on entire dataset! # add coefficient of determination: calculated on entire dataset!
ax.text(interval[-1] - 2, interval[0] + 2, s='Coefficient of determination R$^2$ = {:.2f}'.format(r2), ha='right', fontsize=18) ax.text(interval[-1] - 2, interval[0] + 2, s='Coefficient of determination R$^2$ = {:.2f}'.format(r2), ha='right', fontsize=18)
# format axes # format axes
ax.set_ylim(interval[0], interval[-1]) ax.set_ylim(interval[0], interval[-1])
ax.set_xlim(interval[0], interval[-1]) ax.set_xlim(interval[0], interval[-1])
ax.set_xticks(interval) ax.set_xticks(interval)
ax.set_xticklabels(interval, fontsize=16) ax.set_xticklabels(interval, fontsize=16)
ax.set_yticks(interval) ax.set_yticks(interval)
ax.set_yticklabels(interval, fontsize=16) ax.set_yticklabels(interval, fontsize=16)
ax.set_xlabel('Observed', fontsize=18) ax.set_xlabel('Observed', fontsize=18)
ax.set_ylabel('Predicted', fontsize=18) ax.set_ylabel('Predicted', fontsize=18)
ax.set_title('Monthly mean {} (mm / month)'.format(NAMES[PREDICTAND]), fontsize=20, pad=10); ax.set_title('Monthly mean {} (mm / month)'.format(NAMES[PREDICTAND]), fontsize=20, pad=10);
# add axis for annual cycle # add axis for annual cycle
axins = inset_axes(ax, width="30%", height="40%", loc=2, borderpad=0.25) axins = inset_axes(ax, width="30%", height="40%", loc=2, borderpad=0.25)
axins.plot(y_pred_ac, ls='--', color='k', label='Predicted') axins.plot(y_pred_ac, ls='--', color='k', label='Predicted')
axins.plot(y_true_ac, ls='-', color='k', label='Observed') axins.plot(y_true_ac, ls='-', color='k', label='Observed')
axins.legend(frameon=False, fontsize=12, loc='lower center'); axins.legend(frameon=False, fontsize=12, loc='lower center');
axins.set_yticks(np.arange(0, 200, 50)) axins.set_yticks(np.arange(0, 200, 50))
axins.set_yticklabels(np.arange(0, 200, 50), fontsize=12) axins.set_yticklabels(np.arange(0, 200, 50), fontsize=12)
axins.yaxis.tick_right() axins.yaxis.tick_right()
axins.set_xticks(np.arange(0, 12)) axins.set_xticks(np.arange(0, 12))
axins.set_xticklabels([calendar.month_name[i + 1] for i in np.arange(0, 12)], rotation=90, fontsize=12) axins.set_xticklabels([calendar.month_name[i + 1] for i in np.arange(0, 12)], rotation=90, fontsize=12)
# save figure # save figure
fig.savefig('../Notebooks/Figures/{}_r2.png'.format(PREDICTAND), dpi=300, bbox_inches='tight') fig.savefig('../Notebooks/Figures/{}_r2.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')
``` ```
%% Cell type:markdown id:3538a109-4d3a-418e-9d4d-c3c120256ab2 tags: %% Cell type:markdown id:3538a109-4d3a-418e-9d4d-c3c120256ab2 tags:
### Bias ### Bias
%% Cell type:markdown id:43c922bd-9e9b-4812-b096-f3bde06fb249 tags: %% Cell type:markdown id:43c922bd-9e9b-4812-b096-f3bde06fb249 tags:
Calculate yearly average bias over entire reference period: Calculate yearly average bias over entire reference period:
%% Cell type:code id:52efc445-054e-4184-9667-4e761bc12a84 tags: %% Cell type:code id:52efc445-054e-4184-9667-4e761bc12a84 tags:
``` python ``` python
# yearly average bias over reference period # yearly average bias over reference period
y_pred_yearly_avg = y_pred_pr.groupby('time.year').mean(dim='time') y_pred_yearly_avg = y_pred_pr.groupby('time.year').mean(dim='time')
y_refe_yearly_avg = y_refe.groupby('time.year').mean(dim='time') y_refe_yearly_avg = y_refe.groupby('time.year').mean(dim='time')
y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time') y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time')
bias_yearly_avg = ((y_pred_yearly_avg - y_true_yearly_avg) / y_true_yearly_avg) * 100 bias_yearly_avg = ((y_pred_yearly_avg - y_true_yearly_avg) / y_true_yearly_avg) * 100
bias_yearly_avg_ref = ((y_refe_yearly_avg - y_true_yearly_avg) / y_true_yearly_avg) * 100 bias_yearly_avg_ref = ((y_refe_yearly_avg - y_true_yearly_avg) / y_true_yearly_avg) * 100
print('(Model) Yearly average relative bias: {:.2f}%'.format(bias_yearly_avg.mean().item())) print('(Model) Yearly average relative bias: {:.2f}%'.format(bias_yearly_avg.mean().item()))
print('(ERA-5) Yearly average relative bias: {:.2f}%'.format(bias_yearly_avg_ref.mean().item())) print('(ERA-5) Yearly average relative bias: {:.2f}%'.format(bias_yearly_avg_ref.mean().item()))
``` ```
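The relative bias used above is the difference between prediction and observation divided by the observation, in percent. A small NumPy sketch with made-up values shows how it behaves; the function name `relative_bias` is hypothetical, and `np.nanmean` is used here to skip missing pixels.

```python
import numpy as np


def relative_bias(pred, obs):
    """Relative bias in percent; NaN where the observation is missing."""
    return (pred - obs) / obs * 100.0


obs = np.array([[2.0, 4.0],
                [np.nan, 5.0]])    # made-up yearly means (mm / day)
pred = np.array([[2.5, 3.0],
                 [1.0, 5.0]])

bias = relative_bias(pred, obs)
mean_bias = np.nanmean(bias)
```

In this example a pixel with +25 % and a pixel with -25 % cancel, so the spatial mean is zero although the individual errors are not. This is why the mean bias is reported together with MAE and RMSE.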
%% Cell type:code id:92a094cb-2a88-4f4a-8c87-1025672d6fe7 tags: %% Cell type:code id:92a094cb-2a88-4f4a-8c87-1025672d6fe7 tags:
``` python ``` python
# mean absolute error over reference period # mean absolute error over reference period
mae_avg = np.abs(y_pred_yearly_avg - y_true_yearly_avg) mae_avg = np.abs(y_pred_yearly_avg - y_true_yearly_avg)
mae_avg_ref = np.abs(y_refe_yearly_avg - y_true_yearly_avg) mae_avg_ref = np.abs(y_refe_yearly_avg - y_true_yearly_avg)
print('(Model) Yearly average MAE: {:.2f} mm'.format(mae_avg.mean().item())) print('(Model) Yearly average MAE: {:.2f} mm'.format(mae_avg.mean().item()))
print('(ERA-5) Yearly average MAE: {:.2f} mm'.format(mae_avg_ref.mean().item())) print('(ERA-5) Yearly average MAE: {:.2f} mm'.format(mae_avg_ref.mean().item()))
``` ```
%% Cell type:code id:25d397e2-1b46-4f48-b39d-8632a7e56288 tags: %% Cell type:code id:25d397e2-1b46-4f48-b39d-8632a7e56288 tags:
``` python ``` python
# root mean squared error over reference period # root mean squared error over reference period
rmse_avg = np.sqrt(((y_pred_yearly_avg - y_true_yearly_avg) ** 2).mean()) rmse_avg = np.sqrt(((y_pred_yearly_avg - y_true_yearly_avg) ** 2).mean())
rmse_avg_ref = np.sqrt(((y_refe_yearly_avg - y_true_yearly_avg) ** 2).mean()) rmse_avg_ref = np.sqrt(((y_refe_yearly_avg - y_true_yearly_avg) ** 2).mean())
print('(Model) Yearly average RMSE: {:.2f} mm / day'.format(rmse_avg.item())) print('(Model) Yearly average RMSE: {:.2f} mm / day'.format(rmse_avg.item()))
print('(ERA-5) Yearly average RMSE: {:.2f} mm / day'.format(rmse_avg_ref.item())) print('(ERA-5) Yearly average RMSE: {:.2f} mm / day'.format(rmse_avg_ref.item()))
``` ```
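A minimal sketch of how MAE, MSE and RMSE relate, using made-up values: the RMSE is the square root of the mean squared error, so it has the same unit as the data (mm / day), whereas the MSE has the squared unit.

```python
import numpy as np

obs = np.array([1.0, 2.0, 3.0, 4.0])     # made-up observations (mm / day)
pred = np.array([1.0, 2.0, 3.0, 8.0])    # one large error of 4 mm / day

err = pred - obs
mae = np.mean(np.abs(err))      # mean absolute error
mse = np.mean(err ** 2)         # mean squared error (mm^2 / day^2)
rmse = np.sqrt(mse)             # root mean squared error (mm / day)
```

Because the errors are squared before averaging, a single large error raises the RMSE (2.0 here) more than the MAE (1.0 here).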
%% Cell type:code id:d6bbdcd6-3920-4856-b90a-d56f0e5ab2df tags: %% Cell type:code id:d6bbdcd6-3920-4856-b90a-d56f0e5ab2df tags:
``` python ``` python
# Pearson's correlation coefficient over reference period # Pearson's correlation coefficient over reference period
for year in y_pred_yearly_avg.year: for year in y_pred_yearly_avg.year:
y_p = y_pred_yearly_avg.sel(year=year).values y_p = y_pred_yearly_avg.sel(year=year).values
y_t = y_true_yearly_avg.sel(year=year).values y_t = y_true_yearly_avg.sel(year=year).values
valid = ~np.isnan(y_p) & ~np.isnan(y_t) valid = ~np.isnan(y_p) & ~np.isnan(y_t)
r, _ = stats.pearsonr(y_p[valid], y_t[valid]) r, _ = stats.pearsonr(y_p[valid], y_t[valid])
print('({:0d}) Pearson correlation: {:.2f}'.format(year.item(), np.asarray(r).mean())) print('({:0d}) Pearson correlation: {:.2f}'.format(year.item(), np.asarray(r).mean()))
``` ```
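When two gridded fields are correlated, both need to be reduced to the same set of valid pixels, otherwise the flattened arrays may differ in length or pair up the wrong pixels. The sketch below uses one joint mask and `np.corrcoef`; the function name `masked_pearson` and the arrays are assumptions for illustration.

```python
import numpy as np


def masked_pearson(a, b):
    """Pearson correlation over the pixels that are valid in both fields."""
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    valid = ~np.isnan(a) & ~np.isnan(b)    # joint mask
    return np.corrcoef(a[valid], b[valid])[0, 1]


pred = np.array([[1.0, 2.0],
                 [np.nan, 4.0]])    # one missing prediction
obs = np.array([[2.0, 4.0],
                [6.0, 8.0]])

# the remaining pairs (1, 2), (2, 4), (4, 8) are perfectly linear
r_example = masked_pearson(pred, obs)
```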
%% Cell type:code id:ca683581-6b3b-4abd-ad46-157ebded19c6 tags: %% Cell type:code id:ca683581-6b3b-4abd-ad46-157ebded19c6 tags:
``` python ``` python
# plot yearly average relative bias of reference vs. prediction # plot yearly average relative bias of reference vs. prediction
vmin, vmax = 0, 5 vmin, vmax = 0, 5
fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True) fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)
# plot bias of ERA-5 reference # plot bias of ERA-5 reference
reference = bias_yearly_avg_ref.mean(dim='year') reference = bias_yearly_avg_ref.mean(dim='year')
im1 = axes[0].imshow(reference.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40) im1 = axes[0].imshow(reference.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
axes[0].text(x=reference.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(reference.mean().item()), fontsize=14, ha='right') axes[0].text(x=reference.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(reference.mean().item()), fontsize=14, ha='right')
# plot bias of model # plot bias of model
prediction = bias_yearly_avg.mean(dim='year') prediction = bias_yearly_avg.mean(dim='year')
im2 = axes[1].imshow(prediction.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40) im2 = axes[1].imshow(prediction.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
axes[1].text(x=reference.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(prediction.mean().item()), fontsize=14, ha='right') axes[1].text(x=reference.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(prediction.mean().item()), fontsize=14, ha='right')
# plot topography # plot topography
im_dem = axes[2].imshow(dem['elevation'].values, origin='lower', cmap='terrain', vmin=0, vmax=4000) im_dem = axes[2].imshow(dem['elevation'].values, origin='lower', cmap='terrain', vmin=0, vmax=4000)
# set titles # set titles
axes[0].set_title('ERA-5', fontsize=14, pad=10); axes[0].set_title('ERA-5', fontsize=14, pad=10);
axes[1].set_title('DCEDN', fontsize=14, pad=10); axes[1].set_title('DCEDN', fontsize=14, pad=10);
axes[2].set_title('Copernicus EU-DEM v1.1', fontsize=14, pad=10) axes[2].set_title('Copernicus EU-DEM v1.1', fontsize=14, pad=10)
# adjust axes # adjust axes
for ax in axes.flat: for ax in axes.flat:
ax.axes.get_xaxis().set_ticklabels([]) ax.axes.get_xaxis().set_ticklabels([])
ax.axes.get_xaxis().set_ticks([]) ax.axes.get_xaxis().set_ticks([])
ax.axes.get_yaxis().set_ticklabels([]) ax.axes.get_yaxis().set_ticklabels([])
ax.axes.get_yaxis().set_ticks([]) ax.axes.get_yaxis().set_ticks([])
ax.axes.axis('tight') ax.axes.axis('tight')
ax.set_xlabel('') ax.set_xlabel('')
ax.set_ylabel('') ax.set_ylabel('')
ax.set_axis_off() ax.set_axis_off()
# adjust figure # adjust figure
# fig.suptitle('Average yearly mean absolute error: 1991 - 2010', fontsize=20); # fig.suptitle('Average yearly mean absolute error: 1991 - 2010', fontsize=20);
fig.subplots_adjust(hspace=0, wspace=0, top=0.85) fig.subplots_adjust(hspace=0, wspace=0, top=0.85)
# add colorbar for bias # add colorbar for dem
axes = axes.flatten() axes = axes.flatten()
cbar_ax_bias = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0, cbar_ax_bias = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,
0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0]) 0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])
cbar_bias = fig.colorbar(im_dem, cax=cbar_ax_bias) cbar_bias = fig.colorbar(im_dem, cax=cbar_ax_bias)
cbar_bias.set_label(label='Elevation (m)', fontsize=14) cbar_bias.set_label(label='Elevation (m)', fontsize=14)
cbar_bias.ax.tick_params(labelsize=14, pad=10) cbar_bias.ax.tick_params(labelsize=14, pad=10)
# add colorbar for predictand # add colorbar for predictand
cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1, cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1,
axes[-1].get_position().x0 - axes[0].get_position().x0, axes[-1].get_position().x0 - axes[0].get_position().x0,
0.03]) 0.03])
cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal') cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal')
cbar_predictand.set_label(label='Relative mean error (%)', fontsize=14) cbar_predictand.set_label(label='Relative mean error (%)', fontsize=14)
cbar_predictand.ax.tick_params(labelsize=14, pad=10) cbar_predictand.ax.tick_params(labelsize=14, pad=10)
# add metrics: MAE and RMSE # add metrics: MAE and RMSE
#axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.1f}'.format(mae_avg[NAMES[PREDICTAND]].item()) + 'mm day$^{-1}$', fontsize=14, ha='right') #axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.1f}'.format(mae_avg[NAMES[PREDICTAND]].item()) + 'mm day$^{-1}$', fontsize=14, ha='right')
#axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.1f}'.format(rmse_avg[NAMES[PREDICTAND]].item()) + 'mm$^2$ day$^{-2}$', fontsize=14, ha='right') #axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.1f}'.format(rmse_avg[NAMES[PREDICTAND]].item()) + 'mm$^2$ day$^{-2}$', fontsize=14, ha='right')
# save figure # save figure
fig.savefig('../Notebooks/Figures/{}_rbias_ERA_vs_model.png'.format(PREDICTAND), dpi=300, bbox_inches='tight') fig.savefig('../Notebooks/Figures/{}_rbias_ERA_vs_model.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')
``` ```
%% Cell type:code id:f022fcbc-077f-4f76-8330-47fa72a1fcec tags: %% Cell type:code id:f022fcbc-077f-4f76-8330-47fa72a1fcec tags:
``` python ``` python
# plot average of observation, prediction, and bias # plot average of observation, prediction, and bias
vmin, vmax = 0, 5 vmin, vmax = 0, 5
fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True) fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)
axes = axes.flatten() axes = axes.flatten()
for ds, ax in zip([y_true_yearly_avg, y_pred_yearly_avg, bias_yearly_avg], axes): for ds, ax in zip([y_true_yearly_avg, y_pred_yearly_avg, bias_yearly_avg], axes):
if ds is bias_yearly_avg: if ds is bias_yearly_avg:
ds = ds.mean(dim='year') ds = ds.mean(dim='year')
im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40) im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
        ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')
    else:
        im1 = ax.imshow(ds.mean(dim='year').values, origin='lower', cmap='BuPu', vmin=vmin, vmax=vmax)

# set titles
axes[0].set_title('Observed', fontsize=16, pad=10);
axes[1].set_title('Predicted', fontsize=16, pad=10);
axes[2].set_title('Bias', fontsize=16, pad=10);

# adjust axes
for ax in axes.flat:
    ax.axes.get_xaxis().set_ticklabels([])
    ax.axes.get_xaxis().set_ticks([])
    ax.axes.get_yaxis().set_ticklabels([])
    ax.axes.get_yaxis().set_ticks([])
    ax.axes.axis('tight')
    ax.set_xlabel('')
    ax.set_ylabel('')

# adjust figure
fig.suptitle('Average {}: 1991 - 2010'.format(NAMES[PREDICTAND]), fontsize=20);
fig.subplots_adjust(hspace=0, wspace=0, top=0.85)

# add colorbar for bias
axes = axes.flatten()
cbar_ax_bias = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,
                             0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])
cbar_bias = fig.colorbar(im2, cax=cbar_ax_bias)
cbar_bias.set_label(label='Relative bias / (%)', fontsize=16)
cbar_bias.ax.tick_params(labelsize=14)

# add colorbar for predictand
cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1,
                                   axes[-1].get_position().x0 - axes[0].get_position().x0,
                                   0.05])
cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal')
cbar_predictand.set_label(label='{} / '.format(NAMES[PREDICTAND].capitalize()) + '(mm day$^{-1}$)', fontsize=16)
cbar_predictand.ax.tick_params(labelsize=14)

# add metrics: MAE and RMSE (note the space between value and unit)
axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.1f}'.format(mae_avg.mean().item()) + ' mm day$^{-1}$', fontsize=14, ha='right')
axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.1f}'.format(rmse_avg.mean().item()) + ' mm day$^{-1}$', fontsize=14, ha='right')

# save figure
fig.savefig('../Notebooks/Figures/{}_average_bias.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')
```
%% Cell type:markdown id:f9eccfb4-c8f6-41f1-9636-9a46221cf53e tags:
### Seasonal bias
%% Cell type:markdown id:fe07f855-a68a-4bdf-a923-ef334fb82ad3 tags:
Calculate seasonal bias:
%% Cell type:code id:65813f9f-094d-43c0-8261-3364294631f2 tags:
``` python
# group data by season: (DJF, MAM, JJA, SON)
y_true_snl = y_true.groupby('time.season').mean(dim='time')
y_pred_snl = y_pred_pr.groupby('time.season').mean(dim='time')
y_refe_snl = y_refe.groupby('time.season').mean(dim='time')
bias_snl = ((y_pred_snl - y_true_snl) / y_true_snl) * 100
bias_snl_ref = ((y_refe_snl - y_true_snl) / y_true_snl) * 100
```
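The seasonal relative bias computed above can be sketched on synthetic data. This is a minimal illustration using plain NumPy in place of the xarray `groupby('time.season')` call; the helper `seasonal_relative_bias`, the `SEASON_OF_MONTH` mapping, and the synthetic arrays are assumptions made for the example and are not part of the notebook.

```python
import numpy as np

# hypothetical mapping from calendar month to meteorological season
SEASON_OF_MONTH = {12: 'DJF', 1: 'DJF', 2: 'DJF',
                   3: 'MAM', 4: 'MAM', 5: 'MAM',
                   6: 'JJA', 7: 'JJA', 8: 'JJA',
                   9: 'SON', 10: 'SON', 11: 'SON'}


def seasonal_relative_bias(y_pred, y_true, months):
    """Relative bias (%) of the seasonal means for each grid cell."""
    seasons = np.array([SEASON_OF_MONTH[m] for m in months])
    bias = {}
    for season in ('DJF', 'MAM', 'JJA', 'SON'):
        sel = seasons == season
        if not sel.any():
            continue
        # average over time first, then compare the seasonal means
        mean_true = np.nanmean(y_true[sel], axis=0)
        mean_pred = np.nanmean(y_pred[sel], axis=0)
        bias[season] = (mean_pred - mean_true) / mean_true * 100
    return bias


# synthetic example: 12 monthly time steps on a 2 x 2 grid
months = np.arange(1, 13)
y_true = np.full((12, 2, 2), 2.0)
y_pred = y_true.copy()
y_pred[months == 7] = 3.0  # overestimate July only
bias = seasonal_relative_bias(y_pred, y_true, months)
```

Averaging over time before forming the ratio, as the cell above does, avoids dividing by individual dry days where the observed precipitation is zero.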
%% Cell type:code id:2a70e6ac-54fb-47b8-ae76-a1b51d9d1f17 tags:
``` python
# print average bias per season: ERA-5
for season in bias_snl_ref.season:
    print('(ERA-5) Average bias for season {}: {:.1f}%'.format(season.values.item(), bias_snl_ref.sel(season=season).mean().item()))
```
%% Cell type:code id:40fa07e4-2df1-4ac9-b27f-cd902e16ed30 tags:
``` python
# print average bias per season: model
for season in bias_snl.season:
    print('(Model) Average bias for season {}: {:.1f}%'.format(season.values.item(), bias_snl.sel(season=season).mean().item()))
```
%% Cell type:markdown id:a73b3ed4-eb44-4f3e-b240-7f6fe44bedd3 tags:
Plot seasonal differences; the approach follows the [xarray documentation](https://xarray.pydata.org/en/stable/examples/monthly-means.html).
%% Cell type:code id:2b2390e4-8bf1-41bd-b7d1-33e92fe8bc65 tags:
``` python
# plot seasonal differences
seasons = ('DJF', 'JJA')
fig, axes = plt.subplots(nrows=1, ncols=len(seasons) + 1, figsize=(24,8), sharex=True, sharey=True)
axes = axes.flatten()

# plot annual average bias
# keep a handle on the image so the colorbar below does not depend on a variable from an earlier cell
ds = bias_yearly_avg.mean(dim='year')
im = axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
axes[0].set_title('Annual', fontsize=16);
axes[0].text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')

# plot seasonal average bias
for ax, season in zip(axes[1:], seasons):
    ds = bias_snl.sel(season=season)
    ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
    ax.set_title(season, fontsize=16);
    ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')

# adjust axes
for ax in axes.flat:
    ax.axes.get_xaxis().set_ticklabels([])
    ax.axes.get_xaxis().set_ticks([])
    ax.axes.get_yaxis().set_ticklabels([])
    ax.axes.get_yaxis().set_ticks([])
    ax.axes.axis('tight')
    ax.set_xlabel('')
    ax.set_ylabel('')

# adjust figure
fig.suptitle('Average bias of {}: 1991 - 2010'.format(NAMES[PREDICTAND]), fontsize=20);
fig.subplots_adjust(hspace=0, wspace=0, top=0.85)

# add colorbar for bias
axes = axes.flatten()
cbar_ax = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,
                        0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.set_label(label='Relative bias / (%)', fontsize=16)
cbar.ax.tick_params(labelsize=14)

# save figure
fig.savefig('../Notebooks/Figures/{}_average_bias_seasonal.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')
```
%% Cell type:markdown id:ed9dee0e-9319-4ad5-a619-4ea30b116917 tags:
### Bias of extreme values
%% Cell type:code id:92f0ce1d-2cd3-46f6-be1b-86cfe5c8e9d7 tags:
``` python
# extreme quantile of interest
quantile = 0.98
```
%% Cell type:code id:bf98b3e0-7886-485f-a3bc-4bd64b7b2814 tags:
``` python
# calculate extreme quantile for each year
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    y_pred_ex = y_pred_pr.groupby('time.year').quantile(quantile, dim='time')
    y_true_ex = y_true.groupby('time.year').quantile(quantile, dim='time')
    y_refe_ex = y_refe.groupby('time.year').quantile(quantile, dim='time')
```
%% Cell type:code id:89893df3-987a-4794-a6e9-0893c925218c tags:
``` python
# calculate bias in extreme quantile for each year
bias_ex = ((y_pred_ex - y_true_ex) / y_true_ex) * 100
bias_ex_ref = ((y_refe_ex - y_true_ex) / y_true_ex) * 100
```
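The yearly quantile bias above can be sketched with NumPy alone. This is an illustration under stated assumptions: the helper `yearly_quantile`, the gamma-distributed synthetic precipitation, and the 20% underestimation are invented for the example, and `np.nanquantile` stands in for the xarray `groupby('time.year').quantile(...)` call.

```python
import numpy as np


def yearly_quantile(values, years, q):
    """Quantile q along the time axis for each calendar year, ignoring NaNs."""
    unique_years = np.unique(years)
    return np.stack([np.nanquantile(values[years == yr], q, axis=0)
                     for yr in unique_years])


# synthetic daily precipitation: two years on a 3 x 3 grid
rng = np.random.default_rng(0)
years = np.repeat([1991, 1992], 365)
y_true = rng.gamma(shape=0.5, scale=8.0, size=(730, 3, 3))
y_pred = 0.8 * y_true  # synthetic prediction that underestimates by 20%

q = 0.98
p_true = yearly_quantile(y_true, years, q)  # shape (year, y, x)
p_pred = yearly_quantile(y_pred, years, q)

# relative bias (%) of the extreme quantile for each year and grid cell
bias_ex = (p_pred - p_true) / p_true * 100
```

Because the synthetic prediction is a constant multiple of the observation, the relative bias of the quantile is the same in every year and grid cell, which makes the sketch easy to check by hand.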
%% Cell type:code id:9a610a70-bb94-47e6-b470-e01bf4a295c0 tags:
``` python
# bias of extreme quantile: ERA-5
print('(ERA-5) Yearly average bias for P{:.0f}: {:.1f}%'.format(quantile * 100, bias_ex_ref.mean().item()))
```
%% Cell type:code id:423962d7-453d-4c86-b93e-a4571bfec1c4 tags:
``` python
# bias of extreme quantile: Model
print('(Model) Yearly average bias for P{:.0f}: {:.1f}%'.format(quantile * 100, bias_ex.mean().item()))
```
%% Cell type:code id:180b723a-49b0-4692-8249-b457eda48d44 tags:
``` python
# mean absolute error in extreme quantile
mae_ex = np.abs(y_pred_ex - y_true_ex).mean()
mae_ex_ref = np.abs(y_refe_ex - y_true_ex).mean()
```
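For reference, the two error metrics reported in this section can be sketched as follows. The functions `mae` and `rmse` and the small arrays are hypothetical and only illustrate the definitions; in particular, the RMSE takes the square root of the mean squared error, so that it carries the same unit (mm per day) as the MAE.

```python
import numpy as np


def mae(pred, true):
    """Mean absolute error, ignoring NaN entries."""
    return np.nanmean(np.abs(pred - true))


def rmse(pred, true):
    """Root mean squared error, ignoring NaN entries."""
    return np.sqrt(np.nanmean((pred - true) ** 2))


# synthetic quantile fields (mm / day) with one missing observation
true = np.array([[10.0, 20.0], [30.0, np.nan]])
pred = np.array([[13.0, 16.0], [30.0, 25.0]])

mae_value = mae(pred, true)
rmse_value = rmse(pred, true)
```

The RMSE is never smaller than the MAE, and the gap between the two grows when a few grid cells carry large errors.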
%% Cell type:code id:7331acdd-60bb-414e-a0e7-e81956c3a1bf tags:
``` python
# mae of extreme quantile: ERA-5
print('(ERA-5) Yearly average MAE for P{:.0f}: {:.1f} mm / day'.format(quantile * 100, mae_ex_ref.item()))
```
%% Cell type:code id:d7bf1a6c-e111-4c72-8774-64a3939cbe50 tags:
``` python
# mae of extreme quantile: Model
print('(Model) Yearly average MAE for P{:.0f}: {:.1f} mm / day'.format(quantile * 100, mae_ex.item()))
```
%% Cell type:code id:e6309621-7d16-43b7-a9d4-c2857dbe8270 tags:
``` python
# root mean squared error in extreme quantile
# take the square root so the result is an RMSE in mm / day rather than a mean squared error
rmse_ex = np.sqrt(((y_pred_ex - y_true_ex) ** 2).mean())
rmse_ex_ref = np.sqrt(((y_refe_ex - y_true_ex) ** 2).mean())
```
%% Cell type:code id:712a2109-14b9-45a3-9194-e2b917c5ba3f tags:
``` python
# rmse of extreme quantile: ERA-5
print('(ERA-5) Yearly average RMSE for P{:.0f}: {:.1f} mm / day'.format(quantile * 100, rmse_ex_ref.item()))
```
%% Cell type:code id:1dfad687-c16f-4767-8ddd-96c3e56bd96a tags:
``` python
# rmse of extreme quantile: Model
print('(Model) Yearly average RMSE for P{:.0f}: {:.1f} mm / day'.format(quantile * 100, rmse_ex.item()))
```
%% Cell type:code id:95338821-a78e-4a1c-ae73-25031ba7f6ac tags:
``` python
# plot extremes of observation, prediction, and bias
vmin, vmax = 10, 40
fig, axes = plt.subplots(1, 3, figsize=(24, 6), sharex=True, sharey=True)
axes = axes.reshape(1, -1)
# the grid has a single row, so index row 0 explicitly
for ds, ax in zip([y_true_ex, y_pred_ex, bias_ex], axes[0, ...]):
    if ds is bias_ex:
        ds = ds.mean(dim='year')
        im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
        ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')
    else:
        im1 = ax.imshow(ds.mean(dim='year').values, origin='lower', cmap='BuPu', vmin=vmin, vmax=vmax)

# set titles
axes[0, 0].set_title('Observed', fontsize=16, pad=10);
axes[0, 1].set_title('Predicted', fontsize=16, pad=10);
axes[0, 2].set_title('Bias', fontsize=16, pad=10);

# adjust axes
for ax in axes.flat:
    ax.axes.get_xaxis().set_ticklabels([])
    ax.axes.get_xaxis().set_ticks([])
    ax.axes.get_yaxis().set_ticklabels([])
    ax.axes.get_yaxis().set_ticks([])
    ax.axes.axis('tight')
    ax.set_xlabel('')
    ax.set_ylabel('')

# adjust figure
fig.suptitle('Average P{:.0f} of {}: 1991 - 2010'.format(quantile * 100, NAMES[PREDICTAND]), fontsize=20);
fig.subplots_adjust(hspace=0, wspace=0, top=0.85)

# add colorbar for bias
axes = axes.flatten()
cbar_ax = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,
                        0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])
cbar = fig.colorbar(im2, cax=cbar_ax)
cbar.set_label(label='Relative bias / (%)', fontsize=16)
cbar.ax.tick_params(labelsize=14)

# add colorbar for predictand
cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1,
                                   axes[-1].get_position().x0 - axes[0].get_position().x0,
                                   0.05])
cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal')
cbar_predictand.set_label(label='{} / '.format(NAMES[PREDICTAND].capitalize()) + '(mm day$^{-1}$)', fontsize=16)
cbar_predictand.ax.tick_params(labelsize=14)

# add metrics: MAE and RMSE, both in mm / day
axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.1f}'.format(mae_ex.item()) + ' mm day$^{-1}$', fontsize=14, ha='right')
axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.1f}'.format(rmse_ex.item()) + ' mm day$^{-1}$', fontsize=14, ha='right')

# save figure
fig.savefig('../Notebooks/Figures/{}_average_bias_p{:.0f}.png'.format(PREDICTAND, quantile * 100), dpi=300, bbox_inches='tight')
```
%% Cell type:markdown id:f7758b71-e844-4a47-b741-25cc2d277814 tags:
### Bias of extremes: winter vs. summer
%% Cell type:code id:de16b327-dc4b-4645-b3a0-abc8f72c4225 tags:
``` python
# group data by season and compute extreme percentile
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    y_true_ex_snl = y_true.groupby('time.season').quantile(quantile, dim='time')
    y_pred_ex_snl = y_pred_pr.groupby('time.season').quantile(quantile, dim='time')
    y_refe_ex_snl = y_refe.groupby('time.season').quantile(quantile, dim='time')
```
%% Cell type:code id:185a375c-b7bc-42cf-a57d-08268e824f21 tags:
``` python
# compute relative bias in seasonal extremes
bias_ex_snl = ((y_pred_ex_snl - y_true_ex_snl) / y_true_ex_snl) * 100
bias_ex_snl_ref = ((y_refe_ex_snl - y_true_ex_snl) / y_true_ex_snl) * 100
```
%% Cell type:code id:e5f82b16-7ebf-4b73-80dd-0bc7e6a305a1 tags:
``` python
# print average bias in extreme per season: ERA-5
for season in bias_ex_snl_ref.season:
    print('(ERA-5) Average bias of P{:.0f} for season {}: {:.1f}%'.format(quantile * 100, season.values.item(), bias_ex_snl_ref.sel(season=season).mean().item()))
```
%% Cell type:code id:0671b4f3-52c6-4167-b35b-03902dbe11a3 tags:
``` python
# print average bias in extreme per season: Model
for season in bias_ex_snl.season:
    print('(Model) Average bias of P{:.0f} for season {}: {:.1f}%'.format(quantile * 100, season.values.item(), bias_ex_snl.sel(season=season).mean().item()))
```
%% Cell type:code id:322ee9c4-92db-4041-8a36-45aa8d2021b5 tags:
``` python
# plot seasonal differences
seasons = ('DJF', 'JJA')
fig, axes = plt.subplots(nrows=1, ncols=len(seasons) + 1, figsize=(24,8), sharex=True, sharey=True)
axes = axes.flatten()

# plot annual average bias of extreme
# keep a handle on the image so the colorbar below does not depend on a variable from an earlier cell
ds = bias_ex.mean(dim='year')
im = axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
axes[0].set_title('Annual', fontsize=16);
axes[0].text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')

# plot seasonal average bias of extreme
for ax, season in zip(axes[1:], seasons):
    ds = bias_ex_snl.sel(season=season)
    ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
    ax.set_title(season, fontsize=16);
    ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')

# adjust axes
for ax in axes.flat:
    ax.axes.get_xaxis().set_ticklabels([])
    ax.axes.get_xaxis().set_ticks([])
    ax.axes.get_yaxis().set_ticklabels([])
    ax.axes.get_yaxis().set_ticks([])
    ax.axes.axis('tight')
    ax.set_xlabel('')
    ax.set_ylabel('')

# adjust figure
fig.suptitle('Average bias of P{:.0f} of {}: 1991 - 2010'.format(quantile * 100, NAMES[PREDICTAND]), fontsize=20);
fig.subplots_adjust(hspace=0, wspace=0, top=0.85)

# add colorbar for bias
axes = axes.flatten()
cbar_ax = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,
                        0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.set_label(label='Relative bias / (%)', fontsize=16)
cbar.ax.tick_params(labelsize=14)

# save figure
fig.savefig('../Notebooks/Figures/{}_average_bias_seasonal_ex.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')
```
%% Cell type:markdown id:1f1a76c5-0bcb-4f79-832c-caf958e5703a tags:
### Frequency of wet days
%% Cell type:code id:4491395e-f69a-4749-ac73-c4cbce60c3bc tags:
``` python
# minimum precipitation (mm / day) defining a wet day
WET_DAY_THRESHOLD = float(WET_DAY_THRESHOLD)
```
%% Cell type:code id:40676a79-26a2-4714-900e-d530fda32f8f tags:
``` python
# true and predicted frequency of wet days
mask = (~np.isnan(y_true)) & (~np.isnan(y_pred_pr))
wet_days_true = (y_true >= WET_DAY_THRESHOLD).where(mask, other=np.nan).astype(np.float32)
wet_days_pred = (y_pred_pr >= WET_DAY_THRESHOLD).where(mask, other=np.nan).astype(np.float32)
```
%% Cell type:code id:1ad9ad08-ad1f-4a76-9020-c09a442385c9 tags:
``` python
# number of wet days in reference period: annual
n_wet_days_true = wet_days_true.sum(dim='time', skipna=False)
n_wet_days_pred = wet_days_pred.sum(dim='time', skipna=False)
```
%% Cell type:code id:54088572-1873-4949-9937-0ea099e9c2b6 tags:
``` python
# frequency of wet days in reference period: annual
f_wet_days_true = (n_wet_days_true / len(wet_days_true.time)) * 100
f_wet_days_pred = (n_wet_days_pred / len(wet_days_pred.time)) * 100
```
%% Cell type:code id:3534d1d9-f50d-40f3-b1f0-d34c4152a37c tags:
``` python
# frequency of wet days in reference period: seasonal
f_wet_days_true_snl = wet_days_true.groupby('time.season').mean(dim='time', skipna=False)
f_wet_days_pred_snl = wet_days_pred.groupby('time.season').mean(dim='time', skipna=False)
```
%% Cell type:code id:60ac0f93-f9ca-4f75-8786-e692efe3a556 tags:
``` python
# relative bias of frequency of wet days: annual
bias_wet = ((f_wet_days_pred - f_wet_days_true) / f_wet_days_true) * 100
# relative bias of frequency of wet days: seasonal
bias_wet_snl = ((f_wet_days_pred_snl - f_wet_days_true_snl) / f_wet_days_true_snl) * 100
```
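The wet-day frequency and its relative bias can be sketched on a tiny synthetic series. This is a minimal NumPy illustration, not the notebook's implementation: the threshold value `WET_DAY_THRESHOLD_MM`, the helper `wet_day_frequency`, and the arrays are assumptions. As in the cells above, a grid cell with any missing time step yields NaN, mirroring `skipna=False`.

```python
import numpy as np

WET_DAY_THRESHOLD_MM = 1.0  # assumed threshold (mm / day), for illustration only


def wet_day_frequency(precip, valid):
    """Percentage of wet days per grid cell; NaN if any time step is invalid."""
    # 1.0 for wet days, 0.0 for dry days, NaN where data are missing
    wet = np.where(valid, precip >= WET_DAY_THRESHOLD_MM, np.nan)
    # a plain sum propagates NaN, like xarray's sum(..., skipna=False)
    return wet.sum(axis=0) / precip.shape[0] * 100


# synthetic daily precipitation with shape (time, grid cell)
y_true = np.array([[0.0, 5.0], [2.0, 0.2], [3.0, np.nan], [0.5, 1.0]])
y_pred = np.array([[1.5, 4.0], [2.5, 0.0], [3.5, 2.0], [0.0, 1.2]])

# only compare time steps where both observation and prediction exist
valid = ~np.isnan(y_true) & ~np.isnan(y_pred)
f_true = wet_day_frequency(y_true, valid)
f_pred = wet_day_frequency(y_pred, valid)

# relative bias (%) of the wet-day frequency
bias_wet = (f_pred - f_true) / f_true * 100
```

In the first grid cell the observation has two wet days out of four and the prediction three, so the prediction overestimates the wet-day frequency by half; the second grid cell contains a missing observation and is therefore excluded.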
%% Cell type:code id:e8e9d30a-2fbc-4f29-bc82-bfbc155430a0 tags:
``` python
# plot relative bias of the frequency of wet days: annual and seasonal
fig, axes = plt.subplots(2, 3, figsize=(24, 12), sharex=True, sharey=True)
axes = axes.flatten()

# plot annual bias of the frequency of wet days
ds = bias_wet
im = axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
axes[0].set_title('Annual', fontsize=16);
axes[0].text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')

# plot seasonal bias of the frequency of wet days
for ax, season in zip(axes[1:], bias_wet_snl.season):
    ds = bias_wet_snl.sel(season=season)
    ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
    ax.set_title(season.item(), fontsize=16);
    ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')

# adjust axes
for ax in axes.flat:
    ax.axes.get_xaxis().set_ticklabels([])
    ax.axes.get_xaxis().set_ticks([])
    ax.axes.get_yaxis().set_ticklabels([])
    ax.axes.get_yaxis().set_ticks([])
    ax.axes.axis('tight')
    ax.set_xlabel('')
    ax.set_ylabel('')

# turn off last axis
axes[-1].set_visible(False)

# adjust figure
fig.suptitle('Frequency of wet days (>= {:.1f} mm): 1991 - 2010'.format(WET_DAY_THRESHOLD), fontsize=20);
fig.subplots_adjust(hspace=0.1, wspace=0, top=0.925)

# add colorbar
cbar_ax_predictand = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,
                                   0.01, axes[0].get_position().y1 - axes[-1].get_position().y0])
cbar_predictand = fig.colorbar(im, cax=cbar_ax_predictand)
cbar_predictand.set_label(label='Relative bias / (%)', fontsize=16)
cbar_predictand.ax.tick_params(labelsize=14)

# save figure
fig.savefig('../Notebooks/Figures/{}_bias_wet_days.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')
```
%% Cell type:markdown id:fde67674-8bed-45b2-b4ca-4a507084b5a8 tags: %% Cell type:markdown id:fde67674-8bed-45b2-b4ca-4a507084b5a8 tags:
### Mean wet day precipitation ### Mean wet day precipitation
%% Cell type:code id:aa2d2b5a-e192-4c34-975c-d4e61c969a15 tags: %% Cell type:code id:aa2d2b5a-e192-4c34-975c-d4e61c969a15 tags:
``` python ``` python
# calculate mean wet day precipitation # calculate mean wet day precipitation
dii_true = (y_true * wet_days_true).sum(dim='time', skipna=False) / n_wet_days_true dii_true = (y_true * wet_days_true).sum(dim='time', skipna=False) / n_wet_days_true
dii_pred = (y_pred_pr * wet_days_pred).sum(dim='time', skipna=False) / n_wet_days_pred dii_pred = (y_pred_pr * wet_days_pred).sum(dim='time', skipna=False) / n_wet_days_pred
``` ```
%% Cell type:code id:d6331bb6-afe8-4e62-9fa8-60bce9af42ec tags: %% Cell type:code id:d6331bb6-afe8-4e62-9fa8-60bce9af42ec tags:
``` python ``` python
# calculate relative bias of mean wet day precipitation # calculate relative bias of mean wet day precipitation
bias_dii = ((dii_pred - dii_true) / dii_true) * 100 bias_dii = ((dii_pred - dii_true) / dii_true) * 100
``` ```
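The two cells above reduce to a masked mean over time followed by a relative difference. A minimal NumPy sketch on toy data, assuming a wet day is defined as `pr >= WET_DAY_THRESHOLD` (as in the figure titles) and a threshold of 1 mm; the helper `mean_wet_day_precip` and the toy arrays are hypothetical and not part of the notebook:

```python
import numpy as np

WET_DAY_THRESHOLD = 1.0  # mm, assumed value for illustration only


def mean_wet_day_precip(pr, threshold=WET_DAY_THRESHOLD):
    """Mean precipitation over wet days along axis 0 (time)."""
    wet = pr >= threshold                        # wet-day mask
    n_wet = wet.sum(axis=0)                      # number of wet days per point
    total = np.where(wet, pr, 0.0).sum(axis=0)   # precipitation sum on wet days
    # points without any wet day get NaN instead of a division by zero
    return np.divide(total, n_wet, out=np.full(total.shape, np.nan),
                     where=n_wet > 0)


# toy data: 3 time steps, 2 grid points (mm per day)
obs = np.array([[0.0, 2.0], [4.0, 0.5], [6.0, 3.0]])
mod = np.array([[0.0, 3.0], [5.0, 0.0], [7.0, 3.0]])

dii_obs = mean_wet_day_precip(obs)
dii_mod = mean_wet_day_precip(mod)

# relative bias in percent: (predicted - observed) / observed
bias = (dii_mod - dii_obs) / dii_obs * 100
print(dii_obs, dii_mod, bias)
```

Dividing by the number of wet days rather than by the number of time steps is what separates this quantity from the plain time mean: dry days contribute neither to the sum nor to the count.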
%% Cell type:code id:36e3cbac-233d-4892-9756-941d91981534 tags: %% Cell type:code id:36e3cbac-233d-4892-9756-941d91981534 tags:
``` python ``` python
# plot average of observation, prediction, and bias # plot average of observation, prediction, and bias
fig, axes = plt.subplots(1, 3, figsize=(24, 6), sharex=True, sharey=True) fig, axes = plt.subplots(1, 3, figsize=(24, 6), sharex=True, sharey=True)
for ds, ax in zip([dii_true, dii_pred, bias_dii], axes): for ds, ax in zip([dii_true, dii_pred, bias_dii], axes):
if ds is bias_dii: if ds is bias_dii:
im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40) im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)
ax.text(x=ds.shape[1] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right') ax.text(x=ds.shape[1] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')
else: else:
im1 = ax.imshow(ds.values, origin='lower', cmap='BuPu', vmin=0, vmax=15) im1 = ax.imshow(ds.values, origin='lower', cmap='BuPu', vmin=0, vmax=15)
# set titles # set titles
axes[0].set_title('Observed', fontsize=16, pad=10); axes[0].set_title('Observed', fontsize=16, pad=10);
axes[1].set_title('Predicted', fontsize=16, pad=10); axes[1].set_title('Predicted', fontsize=16, pad=10);
axes[2].set_title('Bias', fontsize=16, pad=10); axes[2].set_title('Bias', fontsize=16, pad=10);
# adjust axes # adjust axes
for ax in axes.flat: for ax in axes.flat:
ax.axes.get_xaxis().set_ticklabels([]) ax.axes.get_xaxis().set_ticklabels([])
ax.axes.get_xaxis().set_ticks([]) ax.axes.get_xaxis().set_ticks([])
ax.axes.get_yaxis().set_ticklabels([]) ax.axes.get_yaxis().set_ticklabels([])
ax.axes.get_yaxis().set_ticks([]) ax.axes.get_yaxis().set_ticks([])
ax.axes.axis('tight') ax.axes.axis('tight')
ax.set_xlabel('') ax.set_xlabel('')
ax.set_ylabel('') ax.set_ylabel('')
# adjust figure # adjust figure
fig.suptitle('Mean wet day (>= {:.1f} mm) precipitation: 1991 - 2010'.format(WET_DAY_THRESHOLD), fontsize=20); fig.suptitle('Mean wet day (>= {:.1f} mm) precipitation: 1991 - 2010'.format(WET_DAY_THRESHOLD), fontsize=20);
fig.subplots_adjust(hspace=0, wspace=0, top=0.85) fig.subplots_adjust(hspace=0, wspace=0, top=0.85)
# add colorbar for bias # add colorbar for bias
axes = axes.flatten() axes = axes.flatten()
cbar_ax_bias = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0, cbar_ax_bias = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,
0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0]) 0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])
cbar_bias = fig.colorbar(im2, cax=cbar_ax_bias) cbar_bias = fig.colorbar(im2, cax=cbar_ax_bias)
cbar_bias.set_label(label='Relative bias / (%)', fontsize=16) cbar_bias.set_label(label='Relative bias / (%)', fontsize=16)
cbar_bias.ax.tick_params(labelsize=14) cbar_bias.ax.tick_params(labelsize=14)
# add colorbar for predictand # add colorbar for predictand
cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1, cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1,
axes[-1].get_position().x0 - axes[0].get_position().x0, axes[-1].get_position().x0 - axes[0].get_position().x0,
0.05]) 0.05])
cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal') cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal')
cbar_predictand.set_label(label='Mean wet day precipitation / (mm day$^{-1}$)', fontsize=16) cbar_predictand.set_label(label='Mean wet day precipitation / (mm day$^{-1}$)', fontsize=16)
cbar_predictand.ax.tick_params(labelsize=14) cbar_predictand.ax.tick_params(labelsize=14)
# save figure # save figure
fig.savefig('../Notebooks/Figures/{}_bias_wet_days_p.png'.format(PREDICTAND), dpi=300, bbox_inches='tight') fig.savefig('../Notebooks/Figures/{}_bias_wet_days_p.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')
``` ```
%% Cell type:markdown id:359e6d22-5f7f-4203-bb2a-fa51e731d867 tags: %% Cell type:markdown id:359e6d22-5f7f-4203-bb2a-fa51e731d867 tags:
## Model validation: precipitation probability ## Model validation: precipitation probability
%% Cell type:markdown id:5bd5c551-d426-45d7-9338-5b5003ca676e tags: %% Cell type:markdown id:5bd5c551-d426-45d7-9338-5b5003ca676e tags:
### ROC: Receiver operating characteristics ### ROC: Receiver operating characteristics
%% Cell type:code id:1bff2faf-c894-4917-8dcf-310106c153ad tags: %% Cell type:code id:1bff2faf-c894-4917-8dcf-310106c153ad tags:
``` python ``` python
# true and predicted probability of precipitation # true and predicted probability of precipitation
p_true = (y_true > float(WET_DAY_THRESHOLD)).astype(float).where(y_true.notnull()).values.flatten() p_true = (y_true > float(WET_DAY_THRESHOLD)).astype(float).where(y_true.notnull()).values.flatten()
p_pred = y_pred_prob.values.flatten() p_pred = y_pred_prob.values.flatten()
``` ```
%% Cell type:code id:6c4a7097-26af-4497-bf52-43601038be41 tags: %% Cell type:code id:6c4a7097-26af-4497-bf52-43601038be41 tags:
``` python ``` python
# apply mask of valid pixels # apply mask of valid pixels
mask = (~np.isnan(p_true) & ~np.isnan(p_pred)) mask = (~np.isnan(p_true) & ~np.isnan(p_pred))
p_pred = p_pred[mask] p_pred = p_pred[mask]
p_true = p_true[mask].astype(float) p_true = p_true[mask].astype(float)
``` ```
%% Cell type:code id:e4c35894-e31f-4879-806e-0b3935412730 tags: %% Cell type:code id:e4c35894-e31f-4879-806e-0b3935412730 tags:
``` python ``` python
# calculate ROC: false positive rate vs. true positive rate # calculate ROC: false positive rate vs. true positive rate
fpr, tpr, _ = roc_curve(p_true, p_pred) fpr, tpr, _ = roc_curve(p_true, p_pred)
area = auc(fpr, tpr) # area under ROC curve area = auc(fpr, tpr) # area under ROC curve
rocss = 2 * area - 1 # ROC skill score (cf. https://journals.ametsoc.org/view/journals/clim/16/24/1520-0442_2003_016_4145_otrsop_2.0.co_2.xml) rocss = 2 * area - 1 # ROC skill score (cf. https://journals.ametsoc.org/view/journals/clim/16/24/1520-0442_2003_016_4145_otrsop_2.0.co_2.xml)
``` ```
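The expression `rocss = 2 * area - 1` rescales the area under the ROC curve so that a random classifier (area 0.5) scores 0 and a perfect one (area 1) scores 1. A small NumPy-only sketch on toy data, using the equivalent rank interpretation of the area (the probability that a randomly chosen wet sample receives a higher predicted probability than a randomly chosen dry one). `roc_auc` is a hypothetical helper, not the scikit-learn routine used above, and its pairwise comparison is only practical for small samples:

```python
import numpy as np


def roc_auc(labels, scores):
    """Area under the ROC curve via pairwise comparison (ties count 0.5)."""
    labels = np.asarray(labels, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[labels], scores[~labels]
    # difference between every positive and every negative score
    diff = pos[:, None] - neg[None, :]
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size


# toy example: two dry (0) and two wet (1) samples
labels = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])

area = roc_auc(labels, scores)
rocss = 2 * area - 1  # ROC skill score
print(area, rocss)
```

Three of the four wet/dry pairs are ranked correctly here, so the area is 0.75 and the skill score is 0.5.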
%% Cell type:code id:ac418718-44bf-4378-a1da-7fb75e45b3a8 tags: %% Cell type:code id:ac418718-44bf-4378-a1da-7fb75e45b3a8 tags:
``` python ``` python
# plot ROC curve # plot ROC curve
fig, ax = plt.subplots(1, 1, figsize=(10, 10)) fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.plot(fpr, tpr, lw=2, label='Area={:.2f}, ROCSS={:.2f}'.format(area, rocss), color='k') ax.plot(fpr, tpr, lw=2, label='Area={:.2f}, ROCSS={:.2f}'.format(area, rocss), color='k')
# plot classifier with no skill # plot classifier with no skill
interval = np.arange(-0.05, 1.1, 0.05) interval = np.arange(-0.05, 1.1, 0.05)
ax.plot([0, 1], [0, 1], lw=2, linestyle='--', color='k') ax.plot([0, 1], [0, 1], lw=2, linestyle='--', color='k')
ax.text(0.95, 0.975, 'Random Classifier', ha='right', va='top', rotation=45, fontsize=12) ax.text(0.95, 0.975, 'Random Classifier', ha='right', va='top', rotation=45, fontsize=12)
# plot perfect classifier # plot perfect classifier
ax.plot(0, 1, '-o', markersize=5, markerfacecolor='k', markeredgecolor='none') ax.plot(0, 1, '-o', markersize=5, markerfacecolor='k', markeredgecolor='none')
ax.text(0.02, 1, 'Perfect classifier', va='center', fontsize=12) ax.text(0.02, 1, 'Perfect classifier', va='center', fontsize=12)
# plot direction of increase / decrease # plot direction of increase / decrease
ax.arrow(np.median(interval), np.median(interval), 0.1, -0.1, head_width=0.01, facecolor='k') ax.arrow(np.median(interval), np.median(interval), 0.1, -0.1, head_width=0.01, facecolor='k')
ax.arrow(np.median(interval), np.median(interval), -0.1, 0.1, head_width=0.01, facecolor='k') ax.arrow(np.median(interval), np.median(interval), -0.1, 0.1, head_width=0.01, facecolor='k')
ax.text(np.median(interval) + 0.05, np.median(interval) - 0.05, s='Worse', rotation=45, ha='left', fontsize=12) ax.text(np.median(interval) + 0.05, np.median(interval) - 0.05, s='Worse', rotation=45, ha='left', fontsize=12)
ax.text(np.median(interval) - 0.05, np.median(interval) + 0.05, s='Better', rotation=45, ha='left', fontsize=12) ax.text(np.median(interval) - 0.05, np.median(interval) + 0.05, s='Better', rotation=45, ha='left', fontsize=12)
# adjust axes # adjust axes
ax.set_xticks(np.arange(0, 1.1, 0.1)) ax.set_xticks(np.arange(0, 1.1, 0.1))
ax.set_xticklabels(['{:.2f}'.format(i) for i in np.arange(0, 1.1, 0.1)], fontsize=12) ax.set_xticklabels(['{:.2f}'.format(i) for i in np.arange(0, 1.1, 0.1)], fontsize=12)
ax.set_yticks(np.arange(0, 1.1, 0.1)) ax.set_yticks(np.arange(0, 1.1, 0.1))
ax.set_yticklabels(['{:.2f}'.format(i) for i in np.arange(0, 1.1, 0.1)], fontsize=12) ax.set_yticklabels(['{:.2f}'.format(i) for i in np.arange(0, 1.1, 0.1)], fontsize=12)
ax.set_xlim(interval[0], interval[-1]) ax.set_xlim(interval[0], interval[-1])
ax.set_ylim(interval[0], interval[-1]) ax.set_ylim(interval[0], interval[-1])
ax.set_xlabel('False Positive Rate', fontsize=14) ax.set_xlabel('False Positive Rate', fontsize=14)
ax.set_ylabel('True Positive Rate', fontsize=14) ax.set_ylabel('True Positive Rate', fontsize=14)
ax.set_title('ROC of precipitation probability: 1991 - 2010', fontsize=14, pad=10) ax.set_title('ROC of precipitation probability: 1991 - 2010', fontsize=14, pad=10)
ax.legend(frameon=False, loc='lower right', fontsize=14); ax.legend(frameon=False, loc='lower right', fontsize=14);
# save figure # save figure
fig.savefig('../Notebooks/Figures/{}_ROC.png'.format(PREDICTAND), dpi=300, bbox_inches='tight') fig.savefig('../Notebooks/Figures/{}_ROC.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')
``` ```
......
%% Cell type:markdown id:63805b4a-b30e-4c10-a948-bc59651ca7a6 tags: %% Cell type:markdown id:63805b4a-b30e-4c10-a948-bc59651ca7a6 tags:
### Imports ### Imports
%% Cell type:code id:28982ce9-bf0c-4eb1-8b9e-bec118359966 tags: %% Cell type:code id:28982ce9-bf0c-4eb1-8b9e-bec118359966 tags:
``` python ``` python
# builtins # builtins
import datetime import datetime
import warnings import warnings
import calendar import calendar
# externals # externals
import xarray as xr import xarray as xr
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns import seaborn as sns
import pandas as pd import pandas as pd
import scipy.stats as stats import scipy.stats as stats
from mpl_toolkits.axes_grid1.inset_locator import inset_axes from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import scipy.stats as stats import scipy.stats as stats
from IPython.display import Image from IPython.display import Image
from sklearn.metrics import r2_score, roc_curve, auc, classification_report from sklearn.metrics import r2_score, roc_curve, auc, classification_report
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
# locals # locals
from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH
from climax.main.config import CALIB_PERIOD, VALID_PERIOD from climax.main.config import CALIB_PERIOD, VALID_PERIOD
from pysegcnn.core.utils import search_files from pysegcnn.core.utils import search_files
from pysegcnn.core.graphics import plot_classification_report from pysegcnn.core.graphics import plot_classification_report
``` ```
%% Cell type:code id:de6ae734-3a6a-477e-a5a0-8b9ec5911369 tags: %% Cell type:code id:de6ae734-3a6a-477e-a5a0-8b9ec5911369 tags:
``` python ``` python
# entire reference period # entire reference period
REFERENCE_PERIOD = np.concatenate([CALIB_PERIOD, VALID_PERIOD], axis=0) REFERENCE_PERIOD = np.concatenate([CALIB_PERIOD, VALID_PERIOD], axis=0)
``` ```
%% Cell type:code id:534d9565-4b58-4959-bef3-edde969e2364 tags: %% Cell type:code id:534d9565-4b58-4959-bef3-edde969e2364 tags:
``` python ``` python
# empirical quantiles # empirical quantiles
quantiles = np.arange(0.01, 1, 0.005) quantiles = np.arange(0.01, 1, 0.005)
``` ```
%% Cell type:markdown id:12382efb-1a3a-4ede-a904-7f762bfe56c7 tags: %% Cell type:markdown id:12382efb-1a3a-4ede-a904-7f762bfe56c7 tags:
### Load observations ### Load observations
%% Cell type:code id:2373d894-e252-4f16-826b-88731e195259 tags: %% Cell type:code id:2373d894-e252-4f16-826b-88731e195259 tags:
``` python ``` python
# model predictions and observations NetCDF # model predictions and observations NetCDF
y_true = xr.open_dataset(search_files(OBS_PATH.joinpath('pr'), 'OBS_pr(.*).nc$').pop()) y_true = xr.open_dataset(search_files(OBS_PATH.joinpath('pr'), 'OBS_pr(.*).nc$').pop())
``` ```
%% Cell type:markdown id:5d30b543-aa3b-45f3-b8e8-90d72f4f6896 tags: %% Cell type:markdown id:5d30b543-aa3b-45f3-b8e8-90d72f4f6896 tags:
### Select time period ### Select time period
%% Cell type:code id:f902683a-a560-48f9-b2d1-ef9c341ca69a tags: %% Cell type:code id:f902683a-a560-48f9-b2d1-ef9c341ca69a tags:
``` python ``` python
# time period # time period
PERIOD = REFERENCE_PERIOD PERIOD = REFERENCE_PERIOD
``` ```
%% Cell type:code id:0c2c1912-a947-4afe-84a7-895726be5cfd tags: %% Cell type:code id:0c2c1912-a947-4afe-84a7-895726be5cfd tags:
``` python ``` python
# subset to time period # subset to time period
y = y_true.sel(time=PERIOD) y = y_true.sel(time=PERIOD)
``` ```
%% Cell type:markdown id:f6d01e1e-9dc2-4c31-a31a-a6c91abc7fb4 tags: %% Cell type:markdown id:f6d01e1e-9dc2-4c31-a31a-a6c91abc7fb4 tags:
### Fit distributions: annually ### Fit distributions: annually
%% Cell type:code id:0ffce851-50fc-4795-84b9-972e4f1a5169 tags: %% Cell type:code id:0ffce851-50fc-4795-84b9-972e4f1a5169 tags:
``` python ``` python
# helper function retrieving only valid observations # helper function retrieving only valid observations
def valid(ds): def valid(ds):
valid = ds.precipitation.values valid = ds.precipitation.values
valid = valid[~np.isnan(valid)] # mask missing values valid = valid[~np.isnan(valid)] # mask missing values
valid = valid[valid > 0] # only consider pr > 0 valid = valid[valid > 0] # only consider pr > 0
return valid return valid
``` ```
%% Cell type:code id:6f68803b-4dbc-4d43-99c0-a32e482b647a tags: %% Cell type:code id:6f68803b-4dbc-4d43-99c0-a32e482b647a tags:
``` python ``` python
# valid observations # valid observations
y_valid = valid(y) y_valid = valid(y)
``` ```
%% Cell type:code id:5de4933a-ef9d-4afe-8af6-ff68d91860ce tags: %% Cell type:code id:5de4933a-ef9d-4afe-8af6-ff68d91860ce tags:
``` python ``` python
# fit gamma distribution to data # fit gamma distribution to data
alpha, loc, beta = stats.gamma.fit(y_valid, floc=0) alpha, loc, beta = stats.gamma.fit(y_valid, floc=0)
gamma = stats.gamma(alpha, loc=loc, scale=beta) gamma = stats.gamma(alpha, loc=loc, scale=beta)
``` ```
%% Cell type:code id:dcd9bfeb-67dc-4b63-98fd-c86c3a07c2b0 tags:
``` python
# fit lognormal distribution
alpha, loc, beta = stats.lognorm.fit(y_valid, floc=0)
lognorm = stats.lognorm(alpha, loc=loc, scale=beta)
```
%% Cell type:code id:75b74f7c-c9d7-4d52-b140-e0ad9de17b69 tags: %% Cell type:code id:75b74f7c-c9d7-4d52-b140-e0ad9de17b69 tags:
``` python ``` python
# fit generalized pareto distribution to data # fit generalized pareto distribution to data
alpha, loc, beta = stats.genpareto.fit(y_valid, floc=0) alpha, loc, beta = stats.genpareto.fit(y_valid, floc=0)
genpareto = stats.genpareto(alpha, loc=loc, scale=beta) genpareto = stats.genpareto(alpha, loc=loc, scale=beta)
``` ```
%% Cell type:code id:d489a3e7-7ece-440e-bbd9-1cfd739d822c tags:
``` python
# fit exponential distribution to data
loc, beta = stats.expon.fit(y_valid, floc=0)
expon = stats.expon(loc=loc, scale=beta)
```
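With the location fixed at zero, the maximum-likelihood scale of the exponential distribution has a closed form: the sample mean. The sketch below illustrates this with NumPy only, on a synthetic sample; the sample, the coarser quantile grid, and the helper `expon_ppf` are assumptions for illustration and not part of the notebook:

```python
import numpy as np

rng = np.random.default_rng(0)
# synthetic wet-day amounts (mm), drawn from an exponential with scale 4
sample = rng.exponential(scale=4.0, size=10_000)

# maximum-likelihood estimate of the scale when loc is fixed at 0
beta = sample.mean()


def expon_ppf(q, scale):
    """Quantile function of the exponential distribution with loc=0."""
    return -scale * np.log1p(-np.asarray(q, dtype=float))


# theoretical vs. empirical quantiles on a coarse grid
q = np.linspace(0.01, 0.99, 99)
tq = expon_ppf(q, beta)
eq = np.quantile(sample, q)
print(beta, np.abs(tq - eq).max())
```

Because the exponential has a single free parameter, it cannot adapt its shape to the data; the gamma and Weibull fits both contain it as the special case of shape parameter 1.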
%% Cell type:code id:01d8c7d9-541e-481d-b0de-e8590c571ca5 tags:
``` python
# fit weibull distribution to data
alpha, loc, beta = stats.weibull_min.fit(y_valid, floc=0)
weibull = stats.weibull_min(alpha, loc=loc, scale=beta)
```
%% Cell type:code id:14ade547-443a-457a-bccd-d88d049b9d81 tags: %% Cell type:code id:14ade547-443a-457a-bccd-d88d049b9d81 tags:
``` python ``` python
# empirical quantiles and theoretical quantiles # empirical quantiles and theoretical quantiles
eq = np.quantile(y_valid, quantiles) eq = np.quantile(y_valid, quantiles)
tq_gamma = gamma.ppf(quantiles) tq_gamma = gamma.ppf(quantiles)
tq_genpareto = genpareto.ppf(quantiles) tq_genpareto = genpareto.ppf(quantiles)
tq_expon = expon.ppf(quantiles)
tq_lognorm = lognorm.ppf(quantiles)
tq_weibull = weibull.ppf(quantiles)
# Q-Q plot # Q-Q plot
RANGE = 40 RANGE = 40
fig, ax = plt.subplots(1, 1, figsize=(6, 6)) fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.scatter(eq, tq_gamma, color='grey', label='Gamma') ax.scatter(eq, tq_gamma, marker='*', color='k', label='Gamma')
ax.scatter(eq, tq_genpareto, color='k', label='GenPareto') ax.scatter(eq, tq_genpareto, marker='x', color='k', label='GenPareto')
ax.scatter(eq, tq_expon, marker='o', color='k', label='Expon')
ax.scatter(eq, tq_lognorm, marker='+', color='k', label='LogNorm')
ax.scatter(eq, tq_weibull, marker='^', color='k', label='Weibull')
ax.plot(np.arange(0, RANGE), np.arange(0, RANGE), '--k') ax.plot(np.arange(0, RANGE), np.arange(0, RANGE), '--k')
ax.set_xlim(0, RANGE) ax.set_xlim(0, RANGE)
ax.set_ylim(0, RANGE) ax.set_ylim(0, RANGE)
ax.set_xticks(np.arange(0, RANGE + 5, 5)) ax.set_xticks(np.arange(0, RANGE + 5, 5))
ax.set_yticks(np.arange(0, RANGE + 5, 5)) ax.set_yticks(np.arange(0, RANGE + 5, 5))
ax.set_xticklabels([str(t) for t in np.arange(0, RANGE + 5, 5)], fontsize=12) ax.set_xticklabels([str(t) for t in np.arange(0, RANGE + 5, 5)], fontsize=12)
ax.set_yticklabels([str(t) for t in np.arange(0, RANGE + 5, 5)], fontsize=12) ax.set_yticklabels([str(t) for t in np.arange(0, RANGE + 5, 5)], fontsize=12)
ax.set_ylabel('Theoretical quantiles', fontsize=14); ax.set_ylabel('Theoretical quantiles', fontsize=14);
ax.set_xlabel('Empirical quantiles', fontsize=14); ax.set_xlabel('Empirical quantiles', fontsize=14);
ax.legend(frameon=False, fontsize=14); ax.legend(frameon=False, fontsize=14);
ax.set_title('Reference period: {} - {}'.format(str(PERIOD[0]), str(PERIOD[-1])), fontsize=14) ax.set_title('Reference period: {} - {}'.format(str(PERIOD[0]), str(PERIOD[-1])), fontsize=14)
# save figure # save figure
fig.savefig('./Figures/pr_distribution.png', bbox_inches='tight', dpi=300) fig.savefig('./Figures/pr_distribution.png', bbox_inches='tight', dpi=300)
``` ```
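A Q-Q plot can be complemented by one number per candidate distribution, for example the root-mean-square difference between theoretical and empirical quantiles. The sketch below is not part of the notebook: it draws a synthetic gamma sample, fits two of the candidates with `floc=0` as in the cells above, and compares them. The sample parameters, the coarser quantile grid, and the `qq_rmse` helper are assumptions for illustration:

```python
import numpy as np
import scipy.stats as stats

rng = np.random.default_rng(42)
# synthetic stand-in for the valid observations (pr > 0), in mm
sample = rng.gamma(shape=0.6, scale=8.0, size=20_000)

# empirical quantiles on a coarse grid
q = np.linspace(0.01, 0.99, 99)
eq = np.quantile(sample, q)


def qq_rmse(dist):
    """RMSE between theoretical quantiles of a frozen distribution and eq."""
    return float(np.sqrt(np.mean((dist.ppf(q) - eq) ** 2)))


# fit candidates with the location fixed at zero
alpha, loc, beta = stats.gamma.fit(sample, floc=0)
gamma = stats.gamma(alpha, loc=loc, scale=beta)
loc, beta = stats.expon.fit(sample, floc=0)
expon = stats.expon(loc=loc, scale=beta)

rmse = {'gamma': qq_rmse(gamma), 'expon': qq_rmse(expon)}
print(rmse)
```

The score weights all quantiles equally, so a distribution that fits the bulk well but misses the upper tail can still rank first; restricting `q` to high quantiles is one way to focus the comparison on extremes.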
%% Cell type:markdown id:5fd0e9d8-759d-45ee-bb1f-9c749ac23e8e tags: %% Cell type:markdown id:5fd0e9d8-759d-45ee-bb1f-9c749ac23e8e tags:
### Fit distributions: monthly ### Fit distributions: monthly
%% Cell type:code id:156e5415-4065-4887-b759-0e665d671b38 tags: %% Cell type:code id:156e5415-4065-4887-b759-0e665d671b38 tags:
``` python ``` python
# get the indices of the observations for each month # get the indices of the observations for each month
month_idx = y.groupby('time.month').groups month_idx = y.groupby('time.month').groups
``` ```
%% Cell type:code id:092e865d-f033-4f60-8098-86ae5068e045 tags: %% Cell type:code id:092e865d-f033-4f60-8098-86ae5068e045 tags:
``` python ``` python
# fit distribution to observations for each month # fit distribution to observations for each month
month_gamma = {} month_gamma = {}
month_genpareto = {} month_genpareto = {}
month_expon = {}
month_lognorm = {}
month_weibull = {}
for month, idx in month_idx.items(): for month, idx in month_idx.items():
print('Month: {}'.format(calendar.month_name[month])) print('Month: {}'.format(calendar.month_name[month]))
# select the data of the current month # select the data of the current month
data = y.isel(time=idx) data = y.isel(time=idx)
data = valid(data) data = valid(data)
# fit distributions # fit distributions
# gamma # gamma
alpha, loc, beta = stats.gamma.fit(data, floc=0) alpha, loc, beta = stats.gamma.fit(data, floc=0)
gamma = stats.gamma(alpha, loc=loc, scale=beta) gamma = stats.gamma(alpha, loc=loc, scale=beta)
month_gamma[month] = gamma month_gamma[month] = gamma
# genpareto # genpareto
alpha, loc, beta = stats.genpareto.fit(data, floc=0) alpha, loc, beta = stats.genpareto.fit(data, floc=0)
genpareto = stats.genpareto(alpha, loc=loc, scale=beta) genpareto = stats.genpareto(alpha, loc=loc, scale=beta)
month_genpareto[month] = genpareto month_genpareto[month] = genpareto
# exponential
loc, beta = stats.expon.fit(data, floc=0)
expon = stats.expon(loc=loc, scale=beta)
month_expon[month] = expon
# lognormal
alpha, loc, beta = stats.lognorm.fit(data, floc=0)
lognorm = stats.lognorm(alpha, loc=loc, scale=beta)
month_lognorm[month] = lognorm
# weibull
alpha, loc, beta = stats.weibull_min.fit(data, floc=0)
weibull = stats.weibull_min(alpha, loc=loc, scale=beta)
month_weibull[month] = weibull
``` ```
%% Cell type:code id:396e5ee4-1632-4591-b93b-91fa6ac1d373 tags: %% Cell type:code id:396e5ee4-1632-4591-b93b-91fa6ac1d373 tags:
``` python ``` python
# plot empirical vs. theoretical quantiles for each month # plot empirical vs. theoretical quantiles for each month
fig, axes = plt.subplots(4, 3, figsize=(12, 12), sharex=True, sharey=True) fig, axes = plt.subplots(4, 3, figsize=(12, 12), sharex=True, sharey=True)
axes = axes.flatten() axes = axes.flatten()
RANGE = 40 RANGE = 40
for month, idx in month_idx.items(): for month, idx in month_idx.items():
# axis to plot # axis to plot
ax = axes[month - 1] ax = axes[month - 1]
# compute empirical quantiles # compute empirical quantiles
data = y.isel(time=idx) data = y.isel(time=idx)
data = valid(data) data = valid(data)
eq = np.quantile(data, quantiles) eq = np.quantile(data, quantiles)
# compute theoretical quantiles # compute theoretical quantiles
tq_gamma = month_gamma[month].ppf(quantiles) tq_gamma = month_gamma[month].ppf(quantiles)
tq_gpare = month_genpareto[month].ppf(quantiles) tq_gpare = month_genpareto[month].ppf(quantiles)
tq_expon = month_expon[month].ppf(quantiles)
tq_lognr = month_lognorm[month].ppf(quantiles)
tq_weibu = month_weibull[month].ppf(quantiles)
# plot empirical vs. theoretical quantiles # plot empirical vs. theoretical quantiles
ax.scatter(eq, tq_gamma, color='grey', label='Gamma') ax.scatter(eq, tq_gamma, marker='*', color='k', label='Gamma')
ax.scatter(eq, tq_gpare, color='k', label='GenPareto') ax.scatter(eq, tq_gpare, marker='x', color='k', label='GenPareto')
ax.scatter(eq, tq_expon, marker='o', color='k', label='Expon')
ax.scatter(eq, tq_lognr, marker='+', color='k', label='LogNorm')
ax.scatter(eq, tq_weibu, marker='^', color='k', label='Weibull')
ax.plot(np.arange(0, RANGE), np.arange(0, RANGE), '-k') ax.plot(np.arange(0, RANGE), np.arange(0, RANGE), '-k')
ax.set_title(calendar.month_name[month], fontsize=14) ax.set_title(calendar.month_name[month], fontsize=14)
ax.set_xlim(0, RANGE) ax.set_xlim(0, RANGE)
ax.set_ylim(0, RANGE) ax.set_ylim(0, RANGE)
ax.set_xticks(np.arange(0, RANGE + 5, 5)) ax.set_xticks(np.arange(0, RANGE + 5, 5))
ax.set_yticks(np.arange(0, RANGE + 5, 5)) ax.set_yticks(np.arange(0, RANGE + 5, 5))
ax.set_xticklabels([str(t) for t in np.arange(0, RANGE + 5, 5)], fontsize=12) ax.set_xticklabels([str(t) for t in np.arange(0, RANGE + 5, 5)], fontsize=12)
ax.set_yticklabels([str(t) for t in np.arange(0, RANGE + 5, 5)], fontsize=12) ax.set_yticklabels([str(t) for t in np.arange(0, RANGE + 5, 5)], fontsize=12)
# add legend # add legend
axes[0].legend(frameon=False, fontsize=12) axes[0].legend(frameon=False, fontsize=12, loc='upper left')
# add figure title # add figure title
fig.suptitle('Reference period: {} - {}'.format(str(PERIOD[0]), str(PERIOD[-1])), fontsize=14) fig.suptitle('Reference period: {} - {}'.format(str(PERIOD[0]), str(PERIOD[-1])), fontsize=14)
# adjust subplots # adjust subplots
fig.subplots_adjust(wspace=0.1) fig.subplots_adjust(wspace=0.1)
fig.savefig('./Figures/pr_distribution_m.png', bbox_inches='tight', dpi=300) fig.savefig('./Figures/pr_distribution_m.png', bbox_inches='tight', dpi=300)
``` ```
%% Cell type:markdown id:c0fea8ac-bac0-4096-bc81-90d799f8ab94 tags: %% Cell type:markdown id:c0fea8ac-bac0-4096-bc81-90d799f8ab94 tags:
### Empirical quantiles per grid point ### Empirical quantiles per grid point
%% Cell type:code id:a02c42e0-591c-4630-89b8-5dd8ef71a4a0 tags: %% Cell type:code id:a02c42e0-591c-4630-89b8-5dd8ef71a4a0 tags:
``` python ``` python
# compute empirical quantiles over time # compute empirical quantiles over time
equantiles = y.precipitation.quantile(quantiles, dim='time') equantiles = y.precipitation.quantile(quantiles, dim='time')
equantiles = equantiles.rename({'quantile': 'q'}) equantiles = equantiles.rename({'quantile': 'q'})
``` ```
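For a plain NumPy array, the per-grid-point quantiles over time can be sketched with `np.nanquantile`, which ignores missing values along the chosen axis. The toy array below, with dimensions `(time, y, x)`, is hypothetical:

```python
import numpy as np

# toy precipitation cube with dimensions (time, y, x); values are made up
pr = np.array([[[1.0, 0.0]],
               [[2.0, 4.0]],
               [[3.0, 8.0]],
               [[np.nan, 2.0]],
               [[np.nan, 6.0]]])

q = np.array([0.25, 0.5, 0.75])

# quantiles over the time axis; the result has dimensions (q, y, x)
equantiles = np.nanquantile(pr, q, axis=0)
print(equantiles.shape)
print(equantiles[:, 0, 0], equantiles[:, 0, 1])
```

The first grid point has only three valid time steps, so its quantiles are computed from those three values alone.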
%% Cell type:code id:966d2724-2628-4842-abc9-695711945347 tags: %% Cell type:code id:966d2724-2628-4842-abc9-695711945347 tags:
``` python ``` python
# iterate over the grid points # iterate over the grid points
gammaq = np.ones(shape=(len(equantiles.q), len(equantiles.y), len(equantiles.x))) * np.nan gammaq = np.ones(shape=(len(equantiles.q), len(equantiles.y), len(equantiles.x))) * np.nan
genpaq = np.ones(shape=(len(equantiles.q), len(equantiles.y), len(equantiles.x))) * np.nan genpaq = np.ones(shape=(len(equantiles.q), len(equantiles.y), len(equantiles.x))) * np.nan
exponq = np.ones(shape=(len(equantiles.q), len(equantiles.y), len(equantiles.x))) * np.nan
lognrq = np.ones(shape=(len(equantiles.q), len(equantiles.y), len(equantiles.x))) * np.nan
weibuq = np.ones(shape=(len(equantiles.q), len(equantiles.y), len(equantiles.x))) * np.nan
for i, _ in enumerate(y.x): for i, _ in enumerate(y.x):
print('Columns: {}/{}'.format(i + 1, len(y.x))) print('Columns: {}/{}'.format(i + 1, len(y.x)))
for j, _ in enumerate(y.y): for j, _ in enumerate(y.y):
# current grid point: xarray.Dataset, dimensions=(time) # current grid point: xarray.Dataset, dimensions=(time)
point = y.isel(x=i, y=j) point = y.isel(x=i, y=j)
point = valid(point) point = valid(point)
# check if the grid point is valid # check if the grid point is valid
if point.size < 1: if point.size < 1:
# move on to next grid point # move on to next grid point
continue continue
# fit Gamma distribution to grid point # fit Gamma distribution to grid point
alpha, loc, beta = stats.gamma.fit(point, floc=0) alpha, loc, beta = stats.gamma.fit(point, floc=0)
gamma = stats.gamma(alpha, loc=loc, scale=beta) gamma = stats.gamma(alpha, loc=loc, scale=beta)
# fit GenPareto distribution to grid point # fit GenPareto distribution to grid point
alpha, loc, beta = stats.genpareto.fit(point, floc=0) alpha, loc, beta = stats.genpareto.fit(point, floc=0)
genpa = stats.genpareto(alpha, loc=loc, scale=beta) genpa = stats.genpareto(alpha, loc=loc, scale=beta)
# fit Exponential distribution to grid point
loc, beta = stats.expon.fit(point, floc=0)
expon = stats.expon(loc=loc, scale=beta)
# fit LogNormal distribution
alpha, loc, beta = stats.lognorm.fit(point, floc=0)
lognr = stats.lognorm(alpha, loc=loc, scale=beta)
# fit Weibull distribution
alpha, loc, beta = stats.weibull_min.fit(point, floc=0)
weibu = stats.weibull_min(alpha, loc=loc, scale=beta)
# compute theoretical quantiles of fitted distributions # compute theoretical quantiles of fitted distributions
tq_gamma = gamma.ppf(quantiles) tq_gamma = gamma.ppf(quantiles)
tq_genpa = genpa.ppf(quantiles) tq_genpa = genpa.ppf(quantiles)
tq_expon = expon.ppf(quantiles)
tq_lognr = lognr.ppf(quantiles)
tq_weibu = weibu.ppf(quantiles)
# store theoretical quantiles for current grid point # store theoretical quantiles for current grid point
gammaq[:, j, i] = tq_gamma gammaq[:, j, i] = tq_gamma
genpaq[:, j, i] = tq_genpa genpaq[:, j, i] = tq_genpa
exponq[:, j, i] = tq_expon
lognrq[:, j, i] = tq_lognr
weibuq[:, j, i] = tq_weibu
# store theoretical quantiles in xarray.DataArray # store theoretical quantiles in xarray.DataArray
gammaq = xr.DataArray(data=gammaq, dims=['q', 'y', 'x'], coords=dict(q=quantiles, y=y.y, x=y.x), gammaq = xr.DataArray(data=gammaq, dims=['q', 'y', 'x'], coords=dict(q=quantiles, y=y.y, x=y.x),
name='precipitation') name='precipitation')
genpaq = xr.DataArray(data=genpaq, dims=['q', 'y', 'x'], coords=dict(q=quantiles, y=y.y, x=y.x), genpaq = xr.DataArray(data=genpaq, dims=['q', 'y', 'x'], coords=dict(q=quantiles, y=y.y, x=y.x),
name='precipitation') name='precipitation')
exponq = xr.DataArray(data=exponq, dims=['q', 'y', 'x'], coords=dict(q=quantiles, y=y.y, x=y.x),
name='precipitation')
lognrq = xr.DataArray(data=lognrq, dims=['q', 'y', 'x'], coords=dict(q=quantiles, y=y.y, x=y.x),
name='precipitation')
weibuq = xr.DataArray(data=weibuq, dims=['q', 'y', 'x'], coords=dict(q=quantiles, y=y.y, x=y.x),
name='precipitation')
``` ```
%% Cell type:code id:601de7cb-35f4-40e1-9b51-2dab23102659 tags: %% Cell type:code id:601de7cb-35f4-40e1-9b51-2dab23102659 tags:
``` python ``` python
# compute bias in theoretical quantiles # compute bias in theoretical quantiles
bias_gamma = gammaq - equantiles # predicted - observed bias_gamma = gammaq - equantiles # predicted - observed
bias_genpa = genpaq - equantiles bias_genpa = genpaq - equantiles
bias_expon = exponq - equantiles
bias_lognr = lognrq - equantiles
bias_weibu = weibuq - equantiles
```
%% Cell type:code id:23abd0d1-7c27-4f02-b7ae-9165c2dde0b6 tags:
``` python
# distributions
dists = {k: v for k, v in zip(['gamma', 'genpareto', 'expon', 'lognr', 'weibu'], [bias_gamma, bias_genpa, bias_expon, bias_lognr, bias_weibu])}
``` ```
%% Cell type:code id:b8089c11-a48d-4028-9d4b-e03101ff5e55 tags:
``` python
# plot spatial bias in different quantiles
plot_quantiles = quantiles[18::20]
fig, axes = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(12, 12))
axes = axes.flatten()
for dist, biasq in dists.items():
    # iterate over quantiles to plot
    for ax, q in zip(axes, plot_quantiles):
        im = ax.imshow(biasq.sel(q=q).values, origin='lower', vmin=0, vmax=5, cmap='viridis_r')
        ax.set_title(str('P{:.0f}'.format(q * 100)), fontsize=14)
    # adjust subplots
    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    # add colorbar for bias
    axes = axes.flatten()
    cbar_ax_bias = fig.add_axes([axes[2].get_position().x1 + 0.01, axes[2].get_position().y0,
                                 0.01, axes[2].get_position().y1 - axes[2].get_position().y0])
    cbar_bias = fig.colorbar(im, cax=cbar_ax_bias)
    cbar_bias.set_label(label='Bias (mm)', fontsize=14)
    cbar_bias.ax.tick_params(labelsize=14, pad=10)
    # save one figure per distribution
    fig.savefig('./Figures/pr_distribution_{}_grid.png'.format(dist), bbox_inches='tight', dpi=300)
```
......
%% Cell type:markdown id:d6b83379-c5a8-48c3-bb85-d00a341a37f4 tags:
### Imports
%% Cell type:code id:eeba7f9b-066a-4843-bd64-5b6326c0b536 tags:
``` python
# builtins
import datetime
import warnings
import calendar
# externals
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import scipy.stats as stats
from IPython.display import Image
from sklearn.metrics import r2_score, roc_curve, auc, classification_report
from sklearn.model_selection import train_test_split
# locals
from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH
from climax.main.config import CALIB_PERIOD
from pysegcnn.core.utils import search_files
from pysegcnn.core.graphics import plot_classification_report
```
%% Cell type:code id:e75b3217-26f7-4a4a-ae2a-4fbb92a9f2a2 tags:
``` python
# model predictions and observations NetCDF
y_true = xr.open_dataset(search_files(OBS_PATH.joinpath('pr'), 'OBS_pr(.*).nc$').pop())
```
%% Cell type:code id:3aa8466e-84a9-4c2e-ae19-403b6246e27f tags:
``` python
# subset to calibration period
y_true = y_true.sel(time=CALIB_PERIOD)
```
%% Cell type:code id:4f1a58a2-8c4c-4d73-a116-e64e68fdd507 tags:
``` python
# precipitation threshold (mm) defining a wet day
WET_DAY_THRESHOLD = 1
```
%% Cell type:code id:5e6696df-8660-4083-9a32-0dd282112948 tags:
``` python
# flag wet days in calibration period: 1 if the area-mean precipitation reaches the threshold, else 0
wet_days = (y_true.mean(dim=('y', 'x')) >= WET_DAY_THRESHOLD).astype(np.int16)
nwet_days = wet_days.to_array().values.squeeze()
```
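The wet-day flag above can be illustrated on synthetic data. This is a minimal sketch, not the notebook's own code: it uses plain NumPy instead of xarray, and the array `pr` and its gamma parameters are made up for illustration. Only `WET_DAY_THRESHOLD` and the area-mean criterion are taken from the cells above.

```python
import numpy as np

# hypothetical synthetic daily precipitation field (mm), shape (time, y, x)
rng = np.random.default_rng(0)
pr = rng.gamma(shape=0.5, scale=4.0, size=(365, 4, 5))

# wet day threshold in mm, same value as in the notebook
WET_DAY_THRESHOLD = 1

# a day is flagged as wet (1) when the area-mean precipitation reaches the
# threshold, otherwise as dry (0)
area_mean = pr.mean(axis=(1, 2))
wet = (area_mean >= WET_DAY_THRESHOLD).astype(np.int16)

n_wet = int(wet.sum())
n_dry = int(wet.size - n_wet)
print('wet days: {:d}, dry days: {:d}'.format(n_wet, n_dry))
```

Because the flag is computed from the spatial mean, a day with heavy but very local precipitation can still be counted as dry.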
%% Cell type:code id:b87accd6-d5e4-4dc6-9532-3ef8aa162d24 tags:
``` python
# split training/validation set chronologically
train, valid = train_test_split(CALIB_PERIOD, shuffle=False, test_size=0.25)
```
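With `shuffle=False`, `train_test_split` cuts the sequence at a single point, so every training date precedes every validation date. A small sketch with a hypothetical 20-day period standing in for `CALIB_PERIOD`:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# hypothetical 20-day period, standing in for CALIB_PERIOD
days = np.arange('1981-01-01', '1981-01-21', dtype='datetime64[D]')

# chronological split: the last 25% of the dates form the validation set
train, valid = train_test_split(days, shuffle=False, test_size=0.25)

print(train[0], train[-1], train.size)
print(valid[0], valid[-1], valid.size)
```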
%% Cell type:code id:559d1450-09db-4b2f-844a-d572485973e0 tags:
``` python
# split training/validation set stratified by wet/dry day flag
train_st, valid_st = train_test_split(CALIB_PERIOD, stratify=nwet_days, test_size=0.5)
train_st, valid_st = np.asarray(sorted(train_st)), np.asarray(sorted(valid_st))  # sort chronologically
```
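Stratifying on the wet/dry flag keeps the share of wet days nearly identical in both subsets, at the cost of interleaving training and validation dates in time. The sketch below checks this on synthetic labels; the date range, the 30% wet-day probability, and the fixed `random_state` are assumptions made for illustration and do not come from the notebook.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# hypothetical daily time axis and synthetic wet/dry flags (about 30% wet days)
rng = np.random.default_rng(42)
days = np.arange('1981-01-01', '1991-01-01', dtype='datetime64[D]')
wet = (rng.random(days.size) < 0.3).astype(np.int16)

# stratified split: both halves receive the same proportion of wet days
train_st, valid_st, wet_train, wet_valid = train_test_split(
    days, wet, stratify=wet, test_size=0.5, random_state=0)

# sort both subsets chronologically, keeping the labels aligned with the dates
order_train, order_valid = np.argsort(train_st), np.argsort(valid_st)
train_st, wet_train = train_st[order_train], wet_train[order_train]
valid_st, wet_valid = valid_st[order_valid], wet_valid[order_valid]

print('wet fraction (train): {:.3f}'.format(wet_train.mean()))
print('wet fraction (valid): {:.3f}'.format(wet_valid.mean()))
```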
%% Cell type:code id:7fd013f9-77d0-48de-8d5f-2c6a1cb3ed17 tags:
``` python
# plot distribution of wet days in calibration period
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10, 10))
axes = axes.flatten()
# not stratified
sns.countplot(x=wet_days.sel(time=train).to_array().values.squeeze(), ax=axes[0])
sns.countplot(x=wet_days.sel(time=valid).to_array().values.squeeze(), ax=axes[2])
# stratified
sns.countplot(x=wet_days.sel(time=train_st).to_array().values.squeeze(), ax=axes[1])
sns.countplot(x=wet_days.sel(time=valid_st).to_array().values.squeeze(), ax=axes[3])
# axes properties
for ax in axes:
    ax.set_ylabel('')
for ax in axes[2:]:
    ax.set_xticklabels(['Dry', 'Wet'])
for ax in [axes[0], axes[1]]:
    ax.text(1, ax.get_ylim()[-1] - 5, 'Training', ha='left', va='top', fontsize=12)
for ax in [axes[2], axes[3]]:
    ax.text(1, ax.get_ylim()[-1] - 5, 'Validation', ha='left', va='top', fontsize=12)
axes[0].set_title('Not stratified')
axes[1].set_title('Stratified')
# adjust subplot
fig.subplots_adjust(wspace=0.1, hspace=0.1)
fig.suptitle('Stratified sampling: wet day threshold {:0d} mm'.format(WET_DAY_THRESHOLD));
```
%% Output
......