Skip to content
Snippets Groups Projects
Commit 56e8d37c authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Clean version.

parent 5998cb2b
No related branches found
No related tags found
No related merge requests found
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# builtins # builtins
import sys
import time import time
import logging import logging
from datetime import timedelta from datetime import timedelta
...@@ -61,107 +62,106 @@ if __name__ == '__main__': ...@@ -61,107 +62,106 @@ if __name__ == '__main__':
if state_file.exists() and not OVERWRITE: if state_file.exists() and not OVERWRITE:
# load pretrained network # load pretrained network
net, _ = Network.load_pretrained_model(state_file, NET) net, _ = Network.load_pretrained_model(state_file, NET)
sys.exit()
# ----- predictor data: ERA5 reanalysis ---------------------------------------
# NOTE(review): per the diff header this block runs at script level inside the
# `if __name__ == '__main__':` guard; ERA5_PATH, ERA5_PREDICTORS, ERA5_PLEVELS,
# OBS_PATH and PREDICTAND are config constants imported elsewhere in the file.
LogConfig.init_log('Initializing ERA5 predictors.')
# build the ERA5 predictor dataset and merge all predictor variables into a
# single dataset (chunks=-1: presumably one chunk per variable — confirm in
# ERA5Dataset.merge)
Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
                   plevels=ERA5_PLEVELS)
Era5_ds = Era5.merge(chunks=-1)
# ----- predictand data: gridded in-situ observations -------------------------
LogConfig.init_log('Initializing observations for predictand: {}'
                   .format(PREDICTAND))
# check whether to jointly train tasmin and tasmax
if PREDICTAND == 'tas':
    # 'tas' is trained on both temperature extremes: read the tasmax and
    # tasmin files (search_files(...).pop() takes the single matching .nc
    # file) and combine them into one dataset
    tasmax = xr.open_dataset(
        search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
    tasmin = xr.open_dataset(
        search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
    Obs_ds = xr.merge([tasmax, tasmin])
else: else:
# initialize ERA5 predictor dataset # read in-situ gridded observations
LogConfig.init_log('Initializing ERA5 predictors.') Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS, Obs_ds = xr.open_dataset(Obs_ds)
plevels=ERA5_PLEVELS)
Era5_ds = Era5.merge(chunks=-1) # whether to use digital elevation model
if DEM:
# initialize OBS predictand dataset # digital elevation model: Copernicus EU-Dem v1.1
LogConfig.init_log('Initializing observations for predictand: {}' dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
.format(PREDICTAND))
# read elevation and compute slope and aspect
# check whether to jointly train tasmin and tasmax dem = ERA5Dataset.dem_features(
if PREDICTAND == 'tas': dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
# read both tasmax and tasmin add_coord={'time': Era5_ds.time})
tasmax = xr.open_dataset(
search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop()) # check whether to use slope and aspect
tasmin = xr.open_dataset( if not DEM_FEATURES:
search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop()) dem = dem.drop_vars(['slope', 'aspect'])
Obs_ds = xr.merge([tasmax, tasmin])
else: # add dem to set of predictor variables
# read in-situ gridded observations Era5_ds = xr.merge([Era5_ds, dem])
Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
Obs_ds = xr.open_dataset(Obs_ds) # initialize network and optimizer
LogConfig.init_log('Initializing network and optimizer.')
# whether to use digital elevation model
if DEM: # define number of output fields
# digital elevation model: Copernicus EU-Dem v1.1 # check whether modelling pr with probabilistic approach
dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop() outputs = len(Obs_ds.data_vars)
if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or
# read elevation and compute slope and aspect isinstance(LOSS, BernoulliGenParetoLoss)):
dem = ERA5Dataset.dem_features( outputs = 3
dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
add_coord={'time': Era5_ds.time}) # instanciate network
inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
# check whether to use slope and aspect net = NET(state_file, inputs, outputs, filters=FILTERS)
if not DEM_FEATURES:
dem = dem.drop_vars(['slope', 'aspect']) # initialize optimizer
# optimizer = torch.optim.Adam(net.parameters(), lr=LR,
# add dem to set of predictor variables # weight_decay=LAMBDA)
Era5_ds = xr.merge([Era5_ds, dem]) optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9,
weight_decay=LAMBDA)
# initialize network and optimizer
LogConfig.init_log('Initializing network and optimizer.') # initialize training data
LogConfig.init_log('Initializing training data.')
# define number of output fields
# check whether modelling pr with probabilistic approach # split calibration period into training and validation period
outputs = len(Obs_ds.data_vars) if PREDICTAND == 'pr' and STRATIFY:
if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or # stratify training and validation dataset by number of
isinstance(LOSS, BernoulliGenParetoLoss)): # observed wet days for precipitation
outputs = 3 wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
>= WET_DAY_THRESHOLD).to_array().values.squeeze()
# instantiate network train, valid = train_test_split(
inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars) CALIB_PERIOD, stratify=wet_days, test_size=0.1)
net = NET(state_file, inputs, outputs, filters=FILTERS)
# sort chronologically
# initialize optimizer train, valid = sorted(train), sorted(valid)
# optimizer = torch.optim.Adam(net.parameters(), lr=LR, else:
# weight_decay=LAMBDA) train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9, test_size=0.1)
weight_decay=LAMBDA)
# training and validation dataset
# initialize training data Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
LogConfig.init_log('Initializing training data.') Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
# split calibration period into training and validation period # create PyTorch compliant dataset and dataloader instances for model
if PREDICTAND == 'pr' and STRATIFY: # training
# stratify training and validation dataset by number of train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, doy=DOY)
# observed wet days for precipitation valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, doy=DOY)
wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x')) train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
>= WET_DAY_THRESHOLD).to_array().values.squeeze() drop_last=False)
train, valid = train_test_split( valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
CALIB_PERIOD, stratify=wet_days, test_size=0.1) drop_last=False)
# sort chronologically # initialize network trainer
train, valid = sorted(train), sorted(valid) trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
else: valid_dl, loss_function=LOSS,
train, valid = train_test_split(CALIB_PERIOD, shuffle=False, **TRAIN_CONFIG)
test_size=0.1)
# train model
# training and validation dataset state = trainer.train()
Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
# create PyTorch compliant dataset and dataloader instances for model
# training
train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
doy=DOY)
valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
doy=DOY)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
drop_last=False)
valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
drop_last=False)
# initialize network trainer
trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
valid_dl, loss_function=LOSS,
**TRAIN_CONFIG)
# train model
state = trainer.train()
# log execution time of script # log execution time of script
LogConfig.init_log('Execution time of script {}: {}' LogConfig.init_log('Execution time of script {}: {}'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment