diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py index 3868528fa3051ebee34c19ef153d08b6e1b6acf4..aa78bb14fcd75cd44c0fd645bdbf6b415fe51dcd 100644 --- a/climax/main/downscale_train.py +++ b/climax/main/downscale_train.py @@ -4,6 +4,7 @@ # -*- coding: utf-8 -*- # builtins +import sys import time import logging from datetime import timedelta @@ -61,107 +62,106 @@ if __name__ == '__main__': if state_file.exists() and not OVERWRITE: # load pretrained network net, _ = Network.load_pretrained_model(state_file, NET) + sys.exit() + + # initialize ERA5 predictor dataset + LogConfig.init_log('Initializing ERA5 predictors.') + Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS, + plevels=ERA5_PLEVELS) + Era5_ds = Era5.merge(chunks=-1) + + # initialize OBS predictand dataset + LogConfig.init_log('Initializing observations for predictand: {}' + .format(PREDICTAND)) + + # check whether to joinlty train tasmin and tasmax + if PREDICTAND == 'tas': + # read both tasmax and tasmin + tasmax = xr.open_dataset( + search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop()) + tasmin = xr.open_dataset( + search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop()) + Obs_ds = xr.merge([tasmax, tasmin]) else: - # initialize ERA5 predictor dataset - LogConfig.init_log('Initializing ERA5 predictors.') - Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS, - plevels=ERA5_PLEVELS) - Era5_ds = Era5.merge(chunks=-1) - - # initialize OBS predictand dataset - LogConfig.init_log('Initializing observations for predictand: {}' - .format(PREDICTAND)) - - # check whether to joinlty train tasmin and tasmax - if PREDICTAND == 'tas': - # read both tasmax and tasmin - tasmax = xr.open_dataset( - search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop()) - tasmin = xr.open_dataset( - search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop()) - Obs_ds = xr.merge([tasmax, tasmin]) - else: - # read in-situ gridded observations - Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop() - Obs_ds = xr.open_dataset(Obs_ds) - - # whether to use digital elevation model - if DEM: - # digital elevation model: Copernicus EU-Dem v1.1 - dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop() - - # read elevation and compute slope and aspect - dem = ERA5Dataset.dem_features( - dem, {'y': Era5_ds.y, 'x': Era5_ds.x}, - add_coord={'time': Era5_ds.time}) - - # check whether to use slope and aspect - if not DEM_FEATURES: - dem = dem.drop_vars(['slope', 'aspect']) - - # add dem to set of predictor variables - Era5_ds = xr.merge([Era5_ds, dem]) - - # initialize network and optimizer - LogConfig.init_log('Initializing network and optimizer.') - - # define number of output fields - # check whether modelling pr with probabilistic approach - outputs = len(Obs_ds.data_vars) - if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or - isinstance(LOSS, BernoulliGenParetoLoss)): - outputs = 3 - - # instanciate network - inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars) - net = NET(state_file, inputs, outputs, filters=FILTERS) - - # initialize optimizer - # optimizer = torch.optim.Adam(net.parameters(), lr=LR, - # weight_decay=LAMBDA) - optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9, - weight_decay=LAMBDA) - - # initialize training data - LogConfig.init_log('Initializing training data.') - - # split calibration period into training and validation period - if PREDICTAND == 'pr' and STRATIFY: - # stratify training and validation dataset by number of - # observed wet days for precipitation - wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x')) - >= WET_DAY_THRESHOLD).to_array().values.squeeze() - train, valid = train_test_split( - CALIB_PERIOD, stratify=wet_days, test_size=0.1) - - # sort chronologically - train, valid = sorted(train), sorted(valid) - else: - train, valid = train_test_split(CALIB_PERIOD, shuffle=False, - test_size=0.1) - - # training and validation dataset - Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train) - Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid) - - # create PyTorch compliant dataset and dataloader instances for model - # training - train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, - doy=DOY) - valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, - doy=DOY) - train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE, - drop_last=False) - valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE, - drop_last=False) - - # initialize network trainer - trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl, - valid_dl, loss_function=LOSS, - **TRAIN_CONFIG) - - # train model - state = trainer.train() + # read in-situ gridded observations + Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop() + Obs_ds = xr.open_dataset(Obs_ds) + + # whether to use digital elevation model + if DEM: + # digital elevation model: Copernicus EU-Dem v1.1 + dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop() + + # read elevation and compute slope and aspect + dem = ERA5Dataset.dem_features( + dem, {'y': Era5_ds.y, 'x': Era5_ds.x}, + add_coord={'time': Era5_ds.time}) + + # check whether to use slope and aspect + if not DEM_FEATURES: + dem = dem.drop_vars(['slope', 'aspect']) + + # add dem to set of predictor variables + Era5_ds = xr.merge([Era5_ds, dem]) + + # initialize network and optimizer + LogConfig.init_log('Initializing network and optimizer.') + + # define number of output fields + # check whether modelling pr with probabilistic approach + outputs = len(Obs_ds.data_vars) + if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or + isinstance(LOSS, BernoulliGenParetoLoss)): + outputs = 3 + + # instanciate network + inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars) + net = NET(state_file, inputs, outputs, filters=FILTERS) + + # initialize optimizer + # optimizer = torch.optim.Adam(net.parameters(), lr=LR, + # weight_decay=LAMBDA) + optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9, + weight_decay=LAMBDA) + + # initialize training data + LogConfig.init_log('Initializing training data.') + + # split calibration period into training and validation period + if PREDICTAND == 'pr' and STRATIFY: + # stratify training and validation dataset by number of + # observed wet days for precipitation + wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x')) + >= WET_DAY_THRESHOLD).to_array().values.squeeze() + train, valid = train_test_split( + CALIB_PERIOD, stratify=wet_days, test_size=0.1) + + # sort chronologically + train, valid = sorted(train), sorted(valid) + else: + train, valid = train_test_split(CALIB_PERIOD, shuffle=False, + test_size=0.1) + + # training and validation dataset + Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train) + Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid) + + # create PyTorch compliant dataset and dataloader instances for model + # training + train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, doy=DOY) + valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, doy=DOY) + train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE, + drop_last=False) + valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE, + drop_last=False) + + # initialize network trainer + trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl, + valid_dl, loss_function=LOSS, + **TRAIN_CONFIG) + + # train model + state = trainer.train() # log execution time of script LogConfig.init_log('Execution time of script {}: {}'