diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py
index 2918217a762e200471737489b773bf45dbe36972..47f104d55a76f755496c4d8ff7eb284f009681b7 100644
--- a/climax/main/downscale_infer.py
+++ b/climax/main/downscale_infer.py
@@ -24,7 +24,7 @@ from climax.core.utils import split_date_range
 from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND,
                                 NET, VALID_PERIOD, BATCH_SIZE, NORM, DOY,
                                 NYEARS, DEM, DEM_FEATURES, LOSS, ANOMALIES, OPTIM,
-                                OPTIM_PARAMS)
+                                OPTIM_PARAMS, CHUNKS)
 from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH
 
 # module level logger
@@ -58,7 +58,7 @@ if __name__ == '__main__':
     LogConfig.init_log('Initializing ERA5 predictors.')
     Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
                        plevels=ERA5_PLEVELS)
-    Era5_ds = Era5.merge(chunks=-1)
+    Era5_ds = Era5.merge(chunks=CHUNKS)
 
     # whether to use digital elevation model
     if DEM:
diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index 1535759bd3d24c586ef3228e1eda4d394a7cca81..770750a672fc4c5dc23218ebf899cab7abb5fc2e 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -28,7 +28,7 @@ from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
                                 OVERWRITE, DEM, DEM_FEATURES, STRATIFY,
                                 WET_DAY_THRESHOLD, VALID_SIZE, ANOMALIES,
                                 OPTIM_PARAMS, LR_SCHEDULER,
-                                LR_SCHEDULER_PARAMS)
+                                LR_SCHEDULER_PARAMS, CHUNKS)
 from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
 
 # module level logger
@@ -70,7 +70,7 @@ if __name__ == '__main__':
     LogConfig.init_log('Initializing ERA5 predictors.')
     Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
                        plevels=ERA5_PLEVELS)
-    Era5_ds = Era5.merge(chunks={'time': 365})
+    Era5_ds = Era5.merge(chunks=CHUNKS)
 
     # initialize OBS predictand dataset
     LogConfig.init_log('Initializing observations for predictand: {}'
diff --git a/climax/main/lr_range_test.py b/climax/main/lr_range_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..99e716e6da74cbf1f1fd3c5b4d5db4b327c4c8ff
--- /dev/null
+++ b/climax/main/lr_range_test.py
@@ -0,0 +1,197 @@
+"""Learning rate range test."""
+
+# !/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# builtins
+import sys
+import time
+import logging
+from datetime import timedelta
+from logging.config import dictConfig
+
+# externals
+import torch
+import xarray as xr
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader
+
+# locals
+from pysegcnn.core.utils import search_files
+from pysegcnn.core.trainer import NetworkTrainer, LogConfig
+from pysegcnn.core.models import Network
+from pysegcnn.core.logging import log_conf
+from climax.core.dataset import ERA5Dataset, NetCDFDataset
+from climax.core.loss import MSELoss, L1Loss
+from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
+                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, OPTIM,
+                                NORM, NET, LOSS, FILTERS, DEM, DEM_FEATURES,
+                                STRATIFY, WET_DAY_THRESHOLD, VALID_SIZE,
+                                ANOMALIES, CHUNKS)
+from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
+
+# module level logger
+LOGGER = logging.getLogger(__name__)
+
+# network training configuration
+TRAIN_CONFIG = {
+    'checkpoint_state': {},
+    'epochs': 50,
+    'save': True,
+    'save_loaders': False,
+    'early_stop': True,
+    'patience': 100,
+    'multi_gpu': True,
+    'classification': False,
+    'clip_gradients': False
+    }
+
+# minimum learning rate: starting point of the learning rate range test
+MIN_LR = 1e-4
+
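+# NOTE: ExponentialLR with gamma > 1 grows the learning rate each epoch; with
+# gamma = 1.15 over the 50 configured epochs, the test sweeps learning rates
+# from MIN_LR = 1e-4 up to roughly MIN_LR * 1.15 ** 50 ~ 0.1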
+# learning rate scheduler: increase lr each epoch
+LR_SCHEDULER = torch.optim.lr_scheduler.ExponentialLR
+LR_SCHEDULER_PARAMS = {'gamma': 1.15}
+
+
+if __name__ == '__main__':
+
+    # initialize timing
+    start_time = time.monotonic()
+
+    # initialize network filename
+    state_file = ERA5Dataset.state_file(
+        NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
+        dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
+        optim=OPTIM)
+
+    # indicate lr range test
+    state_file = state_file.replace('.pt', '_lr_test.pt')
+
+    # path to model state
+    state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
+
+    # initialize logging
+    log_file = MODEL_PATH.joinpath(PREDICTAND,
+                                   state_file.name.replace('.pt', '_log.txt'))
+    if log_file.exists():
+        log_file.unlink()
+    dictConfig(log_conf(log_file))
+
+    # initialize downscaling
+    LogConfig.init_log('Initializing learning rate test for period: {}'.format(
+        ' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))
+
+    # initialize ERA5 predictor dataset
+    LogConfig.init_log('Initializing ERA5 predictors.')
+    Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
+                       plevels=ERA5_PLEVELS)
+    Era5_ds = Era5.merge(chunks=CHUNKS)
+
+    # initialize OBS predictand dataset
+    LogConfig.init_log('Initializing observations for predictand: {}'
+                       .format(PREDICTAND))
+
+    # check whether to jointly train tasmin and tasmax
+    if PREDICTAND == 'tas':
+        # read both tasmax and tasmin
+        tasmax = xr.open_dataset(
+            search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
+        tasmin = xr.open_dataset(
+            search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
+        Obs_ds = xr.merge([tasmax, tasmin])
+    else:
+        # read in-situ gridded observations
+        Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
+        Obs_ds = xr.open_dataset(Obs_ds)
+
+    # whether to use digital elevation model
+    if DEM:
+        # digital elevation model: Copernicus EU-Dem v1.1
+        dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
+
+        # read elevation and compute slope and aspect
+        dem = ERA5Dataset.dem_features(
+            dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
+            add_coord={'time': Era5_ds.time})
+
+        # check whether to use slope and aspect
+        if not DEM_FEATURES:
+            dem = dem.drop_vars(['slope', 'aspect']).chunk(Era5_ds.chunks)
+
+        # add dem to set of predictor variables
+        Era5_ds = xr.merge([Era5_ds, dem])
+
+    # initialize network and optimizer
+    LogConfig.init_log('Initializing network and optimizer.')
+
+    # define number of output fields
+    # check whether modelling pr with probabilistic approach
+    outputs = len(Obs_ds.data_vars)
+    if PREDICTAND == 'pr':
+        outputs = (1 if (isinstance(LOSS, MSELoss) or isinstance(LOSS, L1Loss))
+                   else 3)
+
+    # instantiate network
+    inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
+    net = NET(state_file, inputs, outputs, filters=FILTERS)
+
+    # initialize optimizer
+    if OPTIM == torch.optim.SGD:
+        optimizer = OPTIM(net.parameters(), lr=MIN_LR, weight_decay=0,
+                          momentum=0.99)
+    else:
+        optimizer = OPTIM(net.parameters(), lr=MIN_LR, weight_decay=0)
+
+    # initialize learning rate scheduler
+    if LR_SCHEDULER is not None:
+        LR_SCHEDULER = LR_SCHEDULER(optimizer, **LR_SCHEDULER_PARAMS)
+
+    # initialize training data
+    LogConfig.init_log('Initializing training data.')
+
+    # split calibration period into training and validation period
+    if PREDICTAND == 'pr' and STRATIFY:
+        # stratify training and validation dataset by number of
+        # observed wet days for precipitation
+        wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
+                    >= WET_DAY_THRESHOLD).to_array().values.squeeze()
+        train, valid = train_test_split(
+            CALIB_PERIOD, stratify=wet_days, test_size=VALID_SIZE)
+
+        # sort chronologically
+        train, valid = sorted(train), sorted(valid)
+    else:
+        train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
+                                        test_size=VALID_SIZE)
+
+    # training and validation dataset
+    Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
+    Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
+
+    # create PyTorch compliant dataset and dataloader instances for model
+    # training
+    train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, doy=DOY,
+                             anomalies=ANOMALIES)
+    valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, doy=DOY,
+                             anomalies=ANOMALIES)
+    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
+                          drop_last=False)
+    valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
+                          drop_last=False)
+
+    # initialize network trainer
+    trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
+                             valid_dl, loss_function=LOSS,
+                             lr_scheduler=LR_SCHEDULER, **TRAIN_CONFIG)
+
+    # train model
+    state = trainer.train()
+
+    # log execution time of script
+    LogConfig.init_log('Execution time of script {}: {}'
+                       .format(__file__, timedelta(seconds=time.monotonic() -
+                                                   start_time)))
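
The schedule above implements a learning rate range test (Smith, 2017, "Cyclical Learning Rates for Training Neural Networks"): the model is trained for a few epochs while the learning rate grows exponentially, and the base learning rate is then chosen as the largest value at which the loss still decreases. A minimal sketch of how the resulting curve could be inspected is shown below; it is not part of the patch, and the 'train_loss' key is an assumption about the state dictionary returned by NetworkTrainer.train(), so adapt it to the actual trainer output.

    # plot_lr_range_test.py: minimal sketch, not part of this patch
    import numpy as np
    import matplotlib.pyplot as plt

    # learning rate schedule of the range test: lr grows by GAMMA each epoch,
    # mirroring MIN_LR and LR_SCHEDULER_PARAMS above
    MIN_LR, GAMMA, EPOCHS = 1e-4, 1.15, 50
    lrs = MIN_LR * GAMMA ** np.arange(EPOCHS)

    # per-epoch training losses: placeholder values here; replace with the
    # losses recorded by the trainer, e.g. state['train_loss'] (assumed key)
    losses = np.random.rand(EPOCHS)

    # plot loss against learning rate on a logarithmic axis: a good base
    # learning rate lies where the loss decreases fastest, before it diverges
    fig, ax = plt.subplots()
    ax.plot(lrs, losses)
    ax.set_xscale('log')
    ax.set_xlabel('learning rate')
    ax.set_ylabel('training loss')
    fig.savefig('lr_range_test.png', dpi=300, bbox_inches='tight')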