Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
C
Climax
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
earth_observation_public
Climax
Commits
49f8a7f8
Commit
49f8a7f8
authored
3 years ago
by
Frisinghelli Daniel
Browse files
Options
Downloads
Patches
Plain Diff
Fixed formatting for benchmark simulation.
parent
b78ca653
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
Notebooks/eval_bootstrap.ipynb
+4
-4
4 additions, 4 deletions
Notebooks/eval_bootstrap.ipynb
with
4 additions
and
4 deletions
Notebooks/eval_bootstrap.ipynb
+
4
−
4
View file @
49f8a7f8
...
...
@@ -82,7 +82,7 @@
"outputs": [],
"source": [
"# predictand to evaluate\n",
"PREDICTAND = '
tasmin
'"
"PREDICTAND = '
pr
'"
]
},
{
...
...
@@ -93,7 +93,7 @@
"outputs": [],
"source": [
"# whether only precipitation was used as predictor\n",
"PR_ONLY =
Fals
e"
"PR_ONLY =
Tru
e"
]
},
{
...
...
@@ -117,8 +117,8 @@
"source": [
"# model to evaluate\n",
"if PREDICTAND == 'pr' and PR_ONLY:\n",
" models = ['USegNet_pr_pr_1mm_{}_{}'.format(
PREDICTAND,
loss, OPTIM) if loss == 'BernoulliGammaLoss' else\n",
" 'USegNet_pr_pr_{}_{}'.format(
PREDICTAND,
loss, OPTIM) for loss in LOSS]\n",
" models = ['USegNet_pr_pr_1mm_{}_{}'.format(loss, OPTIM) if loss == 'BernoulliGammaLoss' else\n",
" 'USegNet_pr_pr_{}_{}'.format(loss, OPTIM) for loss in LOSS]\n",
"else:\n",
" models = ['USegNet_{}_ztuvq_500_850_p_dem_doy_1mm_{}_{}'.format(PREDICTAND, loss, OPTIM) if loss == 'BernoulliGammaLoss' else\n",
" 'USegNet_{}_ztuvq_500_850_p_dem_doy_{}_{}'.format(PREDICTAND, loss, OPTIM) for loss in LOSS]"
...
...
%% Cell type:markdown id:fde8874d-299f-4f48-a10a-9fb6a00b43b9 tags:
# Evaluate bootstrapped model results
%% Cell type:markdown id:969d063b-5262-4324-901f-0a48630c4f27 tags:
## Imports and constants
%% Cell type:code id:8af00ae4-4aeb-4ff8-a46a-65966b28c440 tags:
```
python
# builtins
import
pathlib
import
warnings
# externals
import
numpy
as
np
import
xarray
as
xr
import
pandas
as
pd
from
sklearn.metrics
import
r2_score
,
auc
,
roc_curve
# locals
from
climax.core.dataset
import
ERA5Dataset
from
climax.main.io
import
OBS_PATH
,
ERA5_PATH
from
climax.main.config
import
VALID_PERIOD
from
pysegcnn.core.utils
import
search_files
```
%% Cell type:code id:5bc74835-dc59-46ed-849b-3ff614e53eee tags:
```
python
# long variable names, keyed by the short predictand name
NAMES = dict(
    tasmin='minimum temperature',
    tasmax='maximum temperature',
    pr='precipitation',
)
```
%% Cell type:code id:c8a63ef3-35ef-4ffa-b1f3-5c2986eb7eb1 tags:
```
python
# directory containing the bootstrapped model results
RESULTS = pathlib.Path(
    '/mnt/CEPH_PROJECTS/FACT_CLIMAX/ERA5_PRED/bootstrap'
)
```
%% Cell type:markdown id:7eae545b-4d8a-4689-a6c0-4aba2cb9104e tags:
## Search model configurations
%% Cell type:code id:3b83c9f3-7081-4cec-8f23-c4de007839d7 tags:
```
python
# predictand to evaluate (short name, e.g. 'tasmin' or 'pr')
# NOTE: the superseded "PREDICTAND = 'tasmin'" assignment was diff residue that
# was overwritten immediately and has been dropped.
PREDICTAND = 'pr'
```
%% Cell type:code id:49e03dc8-e709-4877-922a-4914e61d7636 tags:
```
python
# whether only precipitation was used as predictor
# NOTE: the superseded "PR_ONLY = False" assignment was diff residue that was
# overwritten immediately and has been dropped.
PR_ONLY = True
```
%% Cell type:code id:3e856f80-14fd-405f-a44e-cc77863f8e5b tags:
```
python
# loss functions and optimizer of the models to evaluate;
# BernoulliGammaLoss is only considered for precipitation
if PREDICTAND == 'pr':
    LOSS = ['L1Loss', 'MSELoss', 'BernoulliGammaLoss']
else:
    LOSS = ['L1Loss', 'MSELoss']
OPTIM = 'Adam'
```
%% Cell type:code id:011b792d-7349-44ad-997d-11f236472a11 tags:
```
python
# model names to evaluate: one name per entry of LOSS, in the same order.
# Models trained with BernoulliGammaLoss carry an additional '1mm' tag.
# NOTE: the superseded assignment that passed PREDICTAND as an extra argument to
# the two-placeholder 'USegNet_pr_pr_*' templates (which would have put the
# predictand into the loss slot) was dead code and has been removed.
if PREDICTAND == 'pr' and PR_ONLY:
    # precipitation-only predictors: the name has no predictand field
    models = [
        ('USegNet_pr_pr_1mm_{}_{}' if loss == 'BernoulliGammaLoss' else
         'USegNet_pr_pr_{}_{}').format(loss, OPTIM)
        for loss in LOSS]
else:
    models = [
        ('USegNet_{}_ztuvq_500_850_p_dem_doy_1mm_{}_{}' if loss == 'BernoulliGammaLoss' else
         'USegNet_{}_ztuvq_500_850_p_dem_doy_{}_{}').format(PREDICTAND, loss, OPTIM)
        for loss in LOSS]
```
%% Cell type:code id:dc4ca6f0-5490-4522-8661-e36bd1be11b7 tags:
```
python
# get bootstrapped models: for every loss, the files found for the matching
# model name, ordered by the integer after the last underscore of the file stem
# NOTE(review): the '.' in the pattern is unescaped; whether search_files treats
# the pattern as a regular expression is not visible here -- confirm.
models = {
    loss: sorted(
        search_files(RESULTS.joinpath(PREDICTAND), model + '(.*).nc$'),
        key=lambda path: int(path.stem.rsplit('_', 1)[-1]))
    for loss, model in zip(LOSS, models)
}
models
```
%% Cell type:markdown id:5a64795a-6e5c-409a-8b3b-c738a96fa255 tags:
## Load datasets
%% Cell type:markdown id:e790ed9f-451c-4368-849d-06d9c50f797c tags:
### Load observations
%% Cell type:code id:0862e0c8-06df-45d6-bc1b-002ffb6e9915 tags:
```
python
# load observations
y_true = xr.open_dataset(
    OBS_PATH.joinpath(PREDICTAND, 'OBS_{}_1980_2018.nc'.format(PREDICTAND)),
    chunks={'time': 365})

# subset to time period covered by predictions
y_true = y_true.sel(time=VALID_PERIOD)

# for precipitation, rename the variable from its long name to the short name
if PREDICTAND == 'pr':
    y_true = y_true.rename({NAMES[PREDICTAND]: PREDICTAND})
```
%% Cell type:code id:aba38642-85d1-404a-81f3-65d23985fb7a tags:
```
python
# boolean mask that is True wherever the observations are NaN
missing = np.isnan(
    y_true[PREDICTAND]
)
```
%% Cell type:markdown id:d4512ed2-d503-4bc1-ae76-84560c101a14 tags:
### Load reference data
%% Cell type:code id:f90f6abf-5fd6-49c0-a1ad-f62242b3d3a0 tags:
```
python
# ERA-5 reference dataset
if PREDICTAND == 'pr':
    y_refe = xr.open_dataset(
        search_files(ERA5_PATH.joinpath('ERA5', 'total_precipitation'), '.nc$').pop(),
        chunks={'time': 365})
    y_refe = y_refe.rename({'tp': 'pr'})
else:
    # Remove the literal 'tas' prefix ('tasmin' -> 'min', 'tasmax' -> 'max').
    # str.lstrip('tas') strips any leading run of the *characters* 't', 'a', 's'
    # and only produced the right directory name by coincidence.
    y_refe = xr.open_dataset(
        search_files(
            ERA5_PATH.joinpath(
                'ERA5',
                '2m_{}_temperature'.format(
                    PREDICTAND[len('tas'):] if PREDICTAND.startswith('tas') else PREDICTAND)),
            '.nc$').pop(),
        chunks={'time': 365})
    y_refe = y_refe - 273.15  # convert to °C
    y_refe = y_refe.rename({'t2m': PREDICTAND})
```
%% Cell type:code id:ea6d5f56-4f39-4e9a-976d-00ff28fce95c tags:
```
python
# subset to time period covered by predictions
y_refe = y_refe.sel(time=VALID_PERIOD)
y_refe = y_refe.drop_vars('lambert_azimuthal_equal_area')

# change order of dimensions
y_refe = y_refe.transpose('time', 'y', 'x')
```
%% Cell type:markdown id:d37702de-da5f-4306-acc1-e569471c1f12 tags:
### Load QM-adjusted reference data
%% Cell type:code id:fffbd267-d08b-44f4-869c-7056c4f19c28 tags:
```
python
# QM-adjusted ERA-5 reference dataset, with dimensions ordered (time, y, x)
y_refe_qm = xr.open_dataset(
    ERA5_PATH.joinpath('QM_ERA5_{}_day_19912010.nc'.format(PREDICTAND)),
    chunks={'time': 365},
).transpose('time', 'y', 'x')
```
%% Cell type:code id:16fa580e-27a7-4758-9164-7f607df7179d tags:
```
python
# center hours at 00:00:00 rather than 12:00:00: truncate every time stamp
# to day precision
y_refe_qm['time'] = y_refe_qm.time.values.astype('datetime64[D]')
```
%% Cell type:code id:6789791f-006b-49b3-aa04-34e4ed8e1571 tags:
```
python
# subset to time period covered by predictions
y_refe_qm = y_refe_qm.sel(time=VALID_PERIOD)
y_refe_qm = y_refe_qm.drop_vars('lambert_azimuthal_equal_area')
```
%% Cell type:code id:b51cfb3f-caa8-413e-a12d-47bbafcef1df tags:
```
python
# align datasets: keep only the predictand variable of each dataset
y_true, y_refe, y_refe_qm = xr.align(
    y_true[PREDICTAND],
    y_refe[PREDICTAND],
    y_refe_qm[PREDICTAND],
    join='override')

# mask missing values: set both references to NaN where observations are missing
y_refe, y_refe_qm = (
    ds.where(~missing, other=np.nan) for ds in (y_refe, y_refe_qm))
```
%% Cell type:markdown id:b4a6c286-6b88-487d-866c-3cb633686dac tags:
### Load model predictions
%% Cell type:code id:eb889059-17e4-4d8c-b796-e8b1e2d0bf8c tags:
```
python
# open every bootstrapped prediction, grouped by loss function
y_pred_raw = {
    loss: [xr.open_dataset(path, chunks={'time': 365}) for path in paths]
    for loss, paths in models.items()}

if PREDICTAND == 'pr':
    # BernoulliGammaLoss predictions are renamed from the long variable name;
    # for the other losses the rename maps the variable onto itself
    y_pred_raw = {
        loss: [
            ds.rename({(NAMES[PREDICTAND] if loss == 'BernoulliGammaLoss'
                        else PREDICTAND): PREDICTAND})
            for ds in sims]
        for loss, sims in y_pred_raw.items()}

# change order of dimensions
y_pred_raw = {
    loss: [ds.transpose('time', 'y', 'x') for ds in sims]
    for loss, sims in y_pred_raw.items()}
```
%% Cell type:code id:534e020d-96b2-403c-b8e4-86de98fbbe3b tags:
```
python
# align datasets and mask missing values
y_prob, y_pred = {}, {}
for loss, sim in y_pred_raw.items():
    y_pred[loss] = []
    y_prob[loss] = []
    for y_p in sim:
        # more than one data variable: the prediction also carries 'prob'
        has_probability = len(y_p.data_vars) > 1
        if not has_probability:
            _, _, y_p = xr.align(y_true, y_refe, y_p[PREDICTAND], join='override')
        else:
            _, _, y_p, y_p_prob = xr.align(
                y_true, y_refe, y_p[PREDICTAND], y_p.prob, join='override')

            # mask missing values of the probability
            y_prob[loss].append(y_p_prob.where(~missing, other=np.nan))

        # mask missing values of the prediction
        # NOTE(review): reconstructed as applying to both branches -- confirm
        y_pred[loss].append(y_p.where(~missing, other=np.nan))
```
%% Cell type:markdown id:6a718ea3-54d3-400a-8c89-76d04347de2d tags:
## Ensemble predictions
%% Cell type:code id:5a6c0bfe-c1d2-4e43-9f8e-35c63c46bb10 tags:
```
python
# create ensemble arrays: stack the members of each loss along a new 'members'
# dimension labelled 'Member-<i>'; losses without any prediction are skipped
ensemble = {
    loss: xr.Dataset(
        {'Member-{}'.format(i): member for i, member in enumerate(members)}
    ).to_array('members')
    for loss, members in y_pred.items()
    if members}
```
%% Cell type:code id:0e526227-cd4c-4a1c-ab72-51b72a4f821f tags:
```
python
# full ensemble mean prediction and standard deviation over all members
ensemble_mean_full = {
    loss: members.mean(dim='members') for loss, members in ensemble.items()}
ensemble_std_full = {
    loss: members.std(dim='members') for loss, members in ensemble.items()}
```
%% Cell type:markdown id:f8b31e39-d4b9-4347-953f-87af04c0dd7a tags:
# Model validation
%% Cell type:code id:e8adcb5e-c7b4-4156-85b3-4751020160e6 tags:
```
python
# extreme quantile of interest: lower tail for 'tasmin', upper tail otherwise
if PREDICTAND == 'tasmin':
    quantile = 0.02
else:
    quantile = 0.98
```
%% Cell type:code id:8aa8d57d-8e41-4c6e-a43f-063650ac8e4b tags:
```
python
def r2(y_pred, y_true, precipitation=False):
    """Coefficient of determination on monthly values and on daily anomalies.

    Parameters
    ----------
    y_pred :
        Predictions. Must support ``groupby``/``resample`` along ``time`` and
        expose ``.values`` (presumably an ``xarray.DataArray`` -- confirm).
    y_true :
        Observations, same kind of object as ``y_pred``.
    precipitation : bool
        If True, the monthly values are sums from ``resample(time='1M')``;
        otherwise they are means grouped by ``time.month``.

    Returns
    -------
    tuple
        ``(r2_mm, r2_anom)``: the score on the monthly values and the score on
        the daily anomalies. Both are also printed.
    """
    def drop_invalid(pred, true):
        # keep only the positions where neither array is NaN
        keep = ~(np.isnan(pred) | np.isnan(true))
        return pred[keep], true[keep]

    # daily anomalies as returned by ERA5Dataset.anomalies for 'time.month'
    y_pred_av, y_true_av = drop_invalid(
        ERA5Dataset.anomalies(y_pred, timescale='time.month').values.flatten(),
        ERA5Dataset.anomalies(y_true, timescale='time.month').values.flatten())

    # predicted and observed monthly sums/means
    if not precipitation:
        monthly_pred = y_pred.groupby('time.month').mean(dim=('time'))
        monthly_true = y_true.groupby('time.month').mean(dim=('time'))
    else:
        monthly_pred = y_pred.resample(time='1M').sum(skipna=False)
        monthly_true = y_true.resample(time='1M').sum(skipna=False)
    y_pred_mv, y_true_mv = drop_invalid(
        monthly_pred.values.flatten(), monthly_true.values.flatten())

    # coefficient of determination on monthly sums/means
    r2_mm = r2_score(y_true_mv, y_pred_mv)
    print('R2 on monthly means: {:.2f}'.format(r2_mm))

    # coefficient of determination on daily anomalies
    r2_anom = r2_score(y_true_av, y_pred_av)
    print('R2 on daily anomalies: {:.2f}'.format(r2_anom))

    return r2_mm, r2_anom
```
%% Cell type:code id:074d7405-a01b-4368-b98b-06d8d46f1ce6 tags:
```
python
def bias(y_pred, y_true, relative=False):
    """Mean difference between predictions and observations.

    Parameters
    ----------
    y_pred, y_true :
        Array-like objects supporting arithmetic and ``.mean().values.item()``
        (presumably ``xarray.DataArray`` -- confirm).
    relative : bool
        If True, the difference is divided by ``y_true`` and multiplied by 100
        before averaging.

    Returns
    -------
    The averaged (relative) difference as a Python scalar.
    """
    error = y_pred - y_true
    if relative:
        error = (error / y_true) * 100
    return error.mean().values.item()
```
%% Cell type:code id:2fc13939-e517-47bd-aa7d-0addd6715538 tags:
```
python
def mae(y_pred, y_true):
    """Mean absolute difference between predictions and observations.

    Both arguments must support subtraction and ``.mean().values.item()``
    (presumably ``xarray.DataArray`` -- confirm). Returns a Python scalar.
    """
    absolute_error = np.abs(y_pred - y_true)
    return absolute_error.mean().values.item()
```
%% Cell type:code id:c93f497e-760e-4484-aeb6-ce54f561a7f6 tags:
```
python
def rmse(y_pred, y_true):
    """Root of the mean squared difference between predictions and observations.

    Both arguments must support subtraction and ``.mean().values.item()``
    (presumably ``xarray.DataArray`` -- confirm). The square root is taken with
    ``np.sqrt`` on the scalar mean.
    """
    squared_error = (y_pred - y_true) ** 2
    mean_squared_error = squared_error.mean().values.item()
    return np.sqrt(mean_squared_error)
```
%% Cell type:markdown id:3e6ecc98-f32f-42f7-9971-64b270aa5453 tags:
## R2, Bias, MAE, and RMSE for reference data
%% Cell type:markdown id:671cd3c0-8d6c-41c1-bf8e-93f5943bf9aa tags:
### Metrics for mean values
%% Cell type:code id:7939a4d2-4eff-4507-86f8-dba7c0b635df tags:
```
python
# yearly average values over validation period
y_refe_yearly_avg, y_refe_qm_yearly_avg, y_true_yearly_avg = (
    ds.groupby('time.year').mean(dim='time')
    for ds in (y_refe, y_refe_qm, y_true))
```
%% Cell type:code id:64e29db7-998d-4952-84b0-1c79016ab9a9 tags:
```
python
# yearly average r2, bias, mae, and rmse for ERA-5
# The precipitation flag is forwarded to r2 so that the monthly score of the
# reference is computed the same way (monthly sums for 'pr') as the score of
# the model predictions it is compared against.
r2_refe_mm, r2_refe_anom = r2(y_refe, y_true, precipitation=PREDICTAND == 'pr')

# the bias is relative (in percent) for precipitation, absolute otherwise
bias_refe = bias(y_refe_yearly_avg, y_true_yearly_avg, relative=PREDICTAND == 'pr')
mae_refe = mae(y_refe_yearly_avg, y_true_yearly_avg)
rmse_refe = rmse(y_refe_yearly_avg, y_true_yearly_avg)
```
%% Cell type:code id:d0d4c974-876f-45e6-85cc-df91501ead20 tags:
```
python
# yearly average r2, bias, mae, and rmse for QM-Adjusted ERA-5
# The precipitation flag is forwarded to r2 so that the monthly score of the
# reference is computed the same way (monthly sums for 'pr') as the score of
# the model predictions it is compared against.
r2_refe_qm_mm, r2_refe_qm_anom = r2(y_refe_qm, y_true, precipitation=PREDICTAND == 'pr')

# the bias is relative (in percent) for precipitation, absolute otherwise
bias_refe_qm = bias(y_refe_qm_yearly_avg, y_true_yearly_avg, relative=PREDICTAND == 'pr')
mae_refe_qm = mae(y_refe_qm_yearly_avg, y_true_yearly_avg)
rmse_refe_qm = rmse(y_refe_qm_yearly_avg, y_true_yearly_avg)
```
%% Cell type:markdown id:c07684d1-76c0-4088-bdd7-7e1a6ccc4716 tags:
### Metrics for extreme values
%% Cell type:code id:343aad59-4b0a-4eec-9ac3-86e5f9d06fc6 tags:
```
python
# calculate extreme quantile for each year; RuntimeWarnings are silenced
# while the quantiles are set up
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    y_true_ex, y_refe_ex, y_refe_qm_ex = (
        # a single chunk along time (-1) before taking the quantile over time
        ds.chunk({'time': -1}).groupby('time.year').quantile(quantile, dim='time')
        for ds in (y_true, y_refe, y_refe_qm))
```
%% Cell type:code id:fbbc648a-82a7-4137-b8ea-5dccb56a65c7 tags:
```
python
# bias in extreme quantile: relative for precipitation, absolute otherwise
bias_ex_refe, bias_ex_refe_qm = (
    bias(extreme, y_true_ex, relative=PREDICTAND == 'pr')
    for extreme in (y_refe_ex, y_refe_qm_ex))
```
%% Cell type:code id:44a3a0e7-ca39-49ce-b569-51b0022161ed tags:
```
python
# mean absolute error in extreme quantile
mae_ex_refe, mae_ex_refe_qm = (
    mae(extreme, y_true_ex) for extreme in (y_refe_ex, y_refe_qm_ex))
```
%% Cell type:code id:a90ce1dc-cf94-4081-9add-1d26195f2302 tags:
```
python
# root mean squared error in extreme quantile
rmse_ex_refe, rmse_ex_refe_qm = (
    rmse(extreme, y_true_ex) for extreme in (y_refe_ex, y_refe_qm_ex))
```
%% Cell type:code id:d6efe5b9-3a6d-41ea-9f26-295b167cf0af tags:
```
python
# compute validation metrics for reference datasets
filename = RESULTS.joinpath(PREDICTAND, 'reference.csv')
if filename.exists():
    # validation metrics for the reference already exist: reuse them
    df_refe = pd.read_csv(filename)
else:
    # One row per reference product. The frame is built in a single call:
    # DataFrame.append was deprecated in pandas 1.4 and removed in pandas 2.0.
    df_refe = pd.DataFrame(
        [[r2_refe_mm, r2_refe_anom, bias_refe, mae_refe, rmse_refe,
          bias_ex_refe, mae_ex_refe, rmse_ex_refe, 'Era-5'],
         [r2_refe_qm_mm, r2_refe_qm_anom, bias_refe_qm, mae_refe_qm, rmse_refe_qm,
          bias_ex_refe_qm, mae_ex_refe_qm, rmse_ex_refe_qm, 'Era-5 QM']],
        columns=['r2_mm', 'r2_anom', 'bias', 'mae', 'rmse',
                 'bias_ex', 'mae_ex', 'rmse_ex', 'product'])

    # save metrics to disk
    df_refe.to_csv(filename, index=False)
```
%% Cell type:markdown id:258cb3c6-c2fc-457d-885e-28eaf48f1d5b tags:
## R2, Bias, MAE, and RMSE for model predictions
%% Cell type:markdown id:630ce1c5-b018-437f-a7cf-8c8d99cd8f84 tags:
### Metrics for mean values
%% Cell type:code id:6980833a-3848-43ca-bcca-d759b4fd9f69 tags:
```
python
# yearly average of each ensemble, per loss function
y_pred_yearly_avg = {
    k: v.groupby('time.year').mean(dim='time') for k, v in ensemble.items()}

# yearly average bias, mae, and rmse for each ensemble member
# Members are selected by name along 'members' rather than by position, so the
# result does not depend on where the reduction places that dimension.
bias_pred = {
    k: [bias(v.isel(members=i), y_true_yearly_avg, relative=PREDICTAND == 'pr')
        for i in range(v.sizes['members'])]
    for k, v in y_pred_yearly_avg.items()}
mae_pred = {
    k: [mae(v.isel(members=i), y_true_yearly_avg)
        for i in range(v.sizes['members'])]
    for k, v in y_pred_yearly_avg.items()}
rmse_pred = {
    k: [rmse(v.isel(members=i), y_true_yearly_avg)
        for i in range(v.sizes['members'])]
    for k, v in y_pred_yearly_avg.items()}
```
%% Cell type:markdown id:122e84f8-211d-4816-9b03-2e1abc24eb9e tags:
### Metrics for extreme values
%% Cell type:code id:8b48a065-1a0b-4457-a97a-642c26d56c51 tags:
```
python
# calculate extreme quantile for each year and each ensemble; RuntimeWarnings
# are silenced while the quantiles are set up
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    y_pred_ex = {
        loss: members.chunk({'time': -1}).groupby('time.year').quantile(
            quantile, dim='time')
        for loss, members in ensemble.items()}
```
%% Cell type:code id:2e6893da-271b-4b0e-bbc1-46b5eb9ecee3 tags:
```
python
# bias, mae, and rmse of the yearly extreme quantile for each ensemble member
# Members are selected by name along 'members' rather than by position, so the
# result does not depend on where the quantile places that dimension.
bias_pred_ex = {
    k: [bias(v.isel(members=i), y_true_ex, relative=PREDICTAND == 'pr')
        for i in range(v.sizes['members'])]
    for k, v in y_pred_ex.items()}
mae_pred_ex = {
    k: [mae(v.isel(members=i), y_true_ex)
        for i in range(v.sizes['members'])]
    for k, v in y_pred_ex.items()}
rmse_pred_ex = {
    k: [rmse(v.isel(members=i), y_true_ex)
        for i in range(v.sizes['members'])]
    for k, v in y_pred_ex.items()}
```
%% Cell type:code id:64f7a0b9-a772-4a03-9160-7839a48e56cd tags:
```
python
# compute validation metrics for model predictions
filename = RESULTS.joinpath(
    PREDICTAND,
    'prediction_pr-only.csv' if PREDICTAND == 'pr' and PR_ONLY else 'prediction.csv')

if filename.exists():
    # validation metrics for predictions already exist: reuse them
    df_pred = pd.read_csv(filename)
else:
    # Rows are collected in a list and turned into a frame once at the end:
    # DataFrame.append was deprecated in pandas 1.4 and removed in pandas 2.0.
    columns = ['r2_mm', 'r2_anom', 'bias', 'mae', 'rmse',
               'bias_ex', 'mae_ex', 'rmse_ex', 'product', 'loss']
    is_precipitation = PREDICTAND == 'pr'
    rows = []

    # validation metrics for each ensemble member
    for k in y_pred_yearly_avg:
        for i in range(ensemble[k].sizes['members']):
            # r2 scores; the member is selected by name along 'members'
            r2_mm, r2_anom = r2(ensemble[k].isel(members=i), y_true,
                                precipitation=is_precipitation)
            rows.append([r2_mm, r2_anom,
                         bias_pred[k][i], mae_pred[k][i], rmse_pred[k][i],
                         bias_pred_ex[k][i], mae_pred_ex[k][i], rmse_pred_ex[k][i],
                         'Member-{:d}'.format(i), k])

    # validation metrics for ensemble
    for k, v in ensemble_mean_full.items():
        # metrics for mean values
        means = v.groupby('time.year').mean(dim='time')
        bias_mean = bias(means, y_true_yearly_avg, relative=is_precipitation)
        mae_mean = mae(means, y_true_yearly_avg)
        rmse_mean = rmse(means, y_true_yearly_avg)

        # metrics for extreme values
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=RuntimeWarning)
            extremes = v.chunk(dict(time=-1)).groupby('time.year').quantile(
                quantile, dim='time')
            bias_ex = bias(extremes, y_true_ex, relative=is_precipitation)
            mae_ex = mae(extremes, y_true_ex)
            rmse_ex = rmse(extremes, y_true_ex)

        # r2 scores
        r2_mm, r2_anom = r2(v, y_true, precipitation=is_precipitation)
        rows.append([r2_mm, r2_anom, bias_mean, mae_mean, rmse_mean,
                     bias_ex, mae_ex, rmse_ex,
                     'Ensemble-{:d}'.format(ensemble[k].sizes['members']), k])

    df_pred = pd.DataFrame(rows, columns=columns)

    # save metrics to disk
    df_pred.to_csv(filename, index=False)
```
%% Cell type:markdown id:da948e96-4a8c-4a56-9177-846851fe8ef8 tags:
### AUC and ROCSS for precipitation
%% Cell type:code id:0b7a824b-418a-4499-a3c0-627190e00941 tags:
```
python
def auc_rocss(p_pred, y_true, wet_day_threshold=1):
    """Area under the ROC curve and ROC skill score for wet-day probability.

    Parameters
    ----------
    p_pred :
        Predicted probability of precipitation; must expose ``.values``.
    y_true :
        Observed precipitation; must expose ``.values``.
    wet_day_threshold :
        Observations greater than or equal to this value count as wet days.

    Returns
    -------
    tuple
        ``(area, rocss)`` with ``rocss = 2 * area - 1``.
    """
    # observed amounts and predicted probability of precipitation
    observed = y_true.values.flatten()
    p_pred = p_pred.values.flatten()

    # Mask of valid pixels. It must be derived from the observed amounts
    # themselves: after thresholding, NaN compares as False, so a NaN check on
    # the boolean wet-day array never fires and missing observations would be
    # scored as dry days.
    mask = ~np.isnan(observed) & ~np.isnan(p_pred)
    p_pred = p_pred[mask]
    p_true = (observed[mask] >= float(wet_day_threshold)).astype(float)

    # calculate ROC: false positive rate vs. true positive rate
    fpr, tpr, _ = roc_curve(p_true, p_pred)
    area = auc(fpr, tpr)  # area under ROC curve

    # ROC skill score (cf. https://journals.ametsoc.org/view/journals/clim/16/24/1520-0442_2003_016_4145_otrsop_2.0.co_2.xml)
    rocss = 2 * area - 1

    return area, rocss
```
%% Cell type:code id:43138ade-148a-4d8d-be48-f1280d40e5b0 tags:
```
python
if PREDICTAND == 'pr':
    # precipitation threshold to consider as wet day
    WET_DAY_THRESHOLD = 1

    # ensemble prediction for precipitation probability: members stacked along
    # a new leading 'members' dimension labelled 'Member-<i>'
    ensemble_prob = xr.Dataset(
        {'Member-{}'.format(i): member
         for i, member in enumerate(y_prob['BernoulliGammaLoss'])}).to_array('members')
    ensemble_mean_prob = ensemble_prob.mean(dim='members')

    # filename for probability metrics (PREDICTAND == 'pr' already holds here)
    filename = RESULTS.joinpath(
        PREDICTAND, 'probability_pr-only.csv' if PR_ONLY else 'probability.csv')

    if filename.exists():
        # validation metrics for probabilities already exist: reuse them
        df_prob = pd.read_csv(filename)
    else:
        # Rows are collected in a list and turned into a frame once at the end:
        # DataFrame.append was deprecated in pandas 1.4 and removed in pandas 2.0.
        rows = []

        # AUC and ROCSS for each ensemble member
        for member in ensemble_prob:
            auc_score, rocss = auc_rocss(member, y_true,
                                         wet_day_threshold=WET_DAY_THRESHOLD)
            rows.append([auc_score, rocss, member.members.item(), 'BernoulliGammaLoss'])

        # AUC and ROCSS for ensemble mean
        auc_score, rocss = auc_rocss(ensemble_mean_prob, y_true,
                                     wet_day_threshold=WET_DAY_THRESHOLD)
        rows.append([auc_score, rocss,
                     'Ensemble-{:d}'.format(len(ensemble_prob)), 'BernoulliGammaLoss'])

        df_prob = pd.DataFrame(rows, columns=['auc', 'rocss', 'product', 'loss'])

        # save metrics to disk
        df_prob.to_csv(filename, index=False)
```
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment