Skip to content
Snippets Groups Projects
Commit 56e8d37c authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Clean version.

parent 5998cb2b
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,7 @@
# -*- coding: utf-8 -*-
# builtins
import sys
import time
import logging
from datetime import timedelta
......@@ -61,107 +62,106 @@ if __name__ == '__main__':
if state_file.exists() and not OVERWRITE:
# load pretrained network
net, _ = Network.load_pretrained_model(state_file, NET)
sys.exit()
# initialize ERA5 predictor dataset
LogConfig.init_log('Initializing ERA5 predictors.')
Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
plevels=ERA5_PLEVELS)
Era5_ds = Era5.merge(chunks=-1)
# initialize OBS predictand dataset
LogConfig.init_log('Initializing observations for predictand: {}'
.format(PREDICTAND))
# check whether to jointly train tasmin and tasmax
if PREDICTAND == 'tas':
# read both tasmax and tasmin
tasmax = xr.open_dataset(
search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
tasmin = xr.open_dataset(
search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
Obs_ds = xr.merge([tasmax, tasmin])
else:
# initialize ERA5 predictor dataset
LogConfig.init_log('Initializing ERA5 predictors.')
Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
plevels=ERA5_PLEVELS)
Era5_ds = Era5.merge(chunks=-1)
# initialize OBS predictand dataset
LogConfig.init_log('Initializing observations for predictand: {}'
.format(PREDICTAND))
# check whether to jointly train tasmin and tasmax
if PREDICTAND == 'tas':
# read both tasmax and tasmin
tasmax = xr.open_dataset(
search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
tasmin = xr.open_dataset(
search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
Obs_ds = xr.merge([tasmax, tasmin])
else:
# read in-situ gridded observations
Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
Obs_ds = xr.open_dataset(Obs_ds)
# whether to use digital elevation model
if DEM:
# digital elevation model: Copernicus EU-Dem v1.1
dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
# read elevation and compute slope and aspect
dem = ERA5Dataset.dem_features(
dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
add_coord={'time': Era5_ds.time})
# check whether to use slope and aspect
if not DEM_FEATURES:
dem = dem.drop_vars(['slope', 'aspect'])
# add dem to set of predictor variables
Era5_ds = xr.merge([Era5_ds, dem])
# initialize network and optimizer
LogConfig.init_log('Initializing network and optimizer.')
# define number of output fields
# check whether modelling pr with probabilistic approach
outputs = len(Obs_ds.data_vars)
if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or
isinstance(LOSS, BernoulliGenParetoLoss)):
outputs = 3
# instantiate network
inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
net = NET(state_file, inputs, outputs, filters=FILTERS)
# initialize optimizer
# optimizer = torch.optim.Adam(net.parameters(), lr=LR,
# weight_decay=LAMBDA)
optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9,
weight_decay=LAMBDA)
# initialize training data
LogConfig.init_log('Initializing training data.')
# split calibration period into training and validation period
if PREDICTAND == 'pr' and STRATIFY:
# stratify training and validation dataset by number of
# observed wet days for precipitation
wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
>= WET_DAY_THRESHOLD).to_array().values.squeeze()
train, valid = train_test_split(
CALIB_PERIOD, stratify=wet_days, test_size=0.1)
# sort chronologically
train, valid = sorted(train), sorted(valid)
else:
train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
test_size=0.1)
# training and validation dataset
Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
# create PyTorch compliant dataset and dataloader instances for model
# training
train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
doy=DOY)
valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
doy=DOY)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
drop_last=False)
valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
drop_last=False)
# initialize network trainer
trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
valid_dl, loss_function=LOSS,
**TRAIN_CONFIG)
# train model
state = trainer.train()
# read in-situ gridded observations
Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
Obs_ds = xr.open_dataset(Obs_ds)
# whether to use digital elevation model
if DEM:
# digital elevation model: Copernicus EU-Dem v1.1
dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
# read elevation and compute slope and aspect
dem = ERA5Dataset.dem_features(
dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
add_coord={'time': Era5_ds.time})
# check whether to use slope and aspect
if not DEM_FEATURES:
dem = dem.drop_vars(['slope', 'aspect'])
# add dem to set of predictor variables
Era5_ds = xr.merge([Era5_ds, dem])
# initialize network and optimizer
LogConfig.init_log('Initializing network and optimizer.')
# define number of output fields
# check whether modelling pr with probabilistic approach
outputs = len(Obs_ds.data_vars)
if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or
isinstance(LOSS, BernoulliGenParetoLoss)):
outputs = 3
# instantiate network
inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
net = NET(state_file, inputs, outputs, filters=FILTERS)
# initialize optimizer
# optimizer = torch.optim.Adam(net.parameters(), lr=LR,
# weight_decay=LAMBDA)
optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9,
weight_decay=LAMBDA)
# initialize training data
LogConfig.init_log('Initializing training data.')
# split calibration period into training and validation period
if PREDICTAND == 'pr' and STRATIFY:
# stratify training and validation dataset by number of
# observed wet days for precipitation
wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
>= WET_DAY_THRESHOLD).to_array().values.squeeze()
train, valid = train_test_split(
CALIB_PERIOD, stratify=wet_days, test_size=0.1)
# sort chronologically
train, valid = sorted(train), sorted(valid)
else:
train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
test_size=0.1)
# training and validation dataset
Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
# create PyTorch compliant dataset and dataloader instances for model
# training
train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, doy=DOY)
valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, doy=DOY)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
drop_last=False)
valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
drop_last=False)
# initialize network trainer
trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
valid_dl, loss_function=LOSS,
**TRAIN_CONFIG)
# train model
state = trainer.train()
# log execution time of script
LogConfig.init_log('Execution time of script {}: {}'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment