diff --git a/climax/core/dataset.py b/climax/core/dataset.py
index 1c58328da88931939644fa2ab69b26e9cec3834d..d90a72c929173f84cae69875de2c47776ddfcdfd 100644
--- a/climax/core/dataset.py
+++ b/climax/core/dataset.py
@@ -104,7 +104,7 @@ class EoDataset(torch.utils.data.Dataset):
     @staticmethod
     def state_file(model, predictand, predictors, plevels, dem=False,
                    dem_features=False, doy=False, loss=None, cv=None,
-                   season=None, anomalies=False, decay=None):
+                   season=None, anomalies=False, decay=None, optim=None):
 
         # naming convention:
         # <model>_<predictand>_<Ppredictors>_<plevels>_<Spredictors>.pt
@@ -134,6 +134,10 @@ class EoDataset(torch.utils.data.Dataset):
         # add name of loss function to state file
         state_file = '_'.join([state_file, repr(loss).strip('()')])
 
+        # add suffix for optimizer
+        state_file = ('_'.join([state_file, optim.__name__]) if optim is not
+                      None else state_file)
+
         # add suffix for weight decay values
         state_file = ('_'.join([state_file, 'd{:.0e}'.format(decay)]) if decay
                       is not None else state_file)
diff --git a/climax/main/config.py b/climax/main/config.py
index f246cb544d03311e01224c5e2f9e9015cc6e5530..2ab3734969a1e0df3f0990bce2dd882aa279e224 100644
--- a/climax/main/config.py
+++ b/climax/main/config.py
@@ -30,7 +30,6 @@ assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS])
 # ERA5_S_PREDICTORS = ['mean_sea_level_pressure', 'orography', '2m_temperature']
 # ERA5_S_PREDICTORS = ['mean_sea_level_pressure']
 ERA5_S_PREDICTORS = ['surface_pressure']
-# ERA5_S_PREDICTORS = ['total_precipitation']
 assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS])
 
 # ERA5 predictor variables
@@ -43,6 +42,9 @@ ERA5_PLEVELS = [500, 850]
 # Anomaly = (time_series - mean(time_series)) / (std(time_series))
 ANOMALIES = False
 
+# Dask chunk size for loading the training data
+CHUNKS = {'time': 365}
+
 # -----------------------------------------------------------------------------
 # Auxiliary predictors --------------------------------------------------------
 # -----------------------------------------------------------------------------
@@ -74,6 +76,9 @@ VALID_PERIOD = np.arange(
     datetime.datetime.strptime('1991-01-01', '%Y-%m-%d').date(),
     datetime.datetime.strptime('2011-01-01', '%Y-%m-%d').date())
 
+# entire reference period
+REFERENCE_PERIOD = np.concatenate([CALIB_PERIOD, VALID_PERIOD], axis=0)
+
 # stratify training/validation set for precipitation by number of wet days
 STRATIFY = False
 
@@ -82,7 +87,7 @@ STRATIFY = False
 # 10% of CALIB_PERIOD for validation
 VALID_SIZE = 0.1
 
-# whether to train using cross-validation
+# number of folds for training with KFold cross-validation
 CV = 5
 
 # -----------------------------------------------------------------------------
@@ -118,23 +123,23 @@ LOSS = MSELoss()
 
 # stochastic optimization algorithm
 OPTIM = torch.optim.SGD
-OPTIM_PARAMS = {'lr': 0.005,  # learning rate
+# OPTIM = torch.optim.Adam
+OPTIM_PARAMS = {'lr': 1e-1,  # learning rate
                 'weight_decay': 1e-6  # regularization rate
                 }
 if OPTIM == torch.optim.SGD:
     OPTIM_PARAMS['momentum'] = 0.9
 
 # learning rate scheduler
-# LR_SCHEDULER = torch.optim.lr_scheduler.ExponentialLR
+# LR_SCHEDULER = torch.optim.lr_scheduler.MultiStepLR
 LR_SCHEDULER = None
-LR_SCHEDULER_PARAMS = {'gamma': 0.9}
+LR_SCHEDULER_PARAMS = {'gamma': 0.25, 'milestones': [1, 3]}
 
 # whether to randomly shuffle time steps or to conserve time series for model
 # training
 SHUFFLE = True
 
-# whether to normalize the training data to [0, 1] (True) or to standardize to
-# mean=0, std=1 (False)
+# whether to normalize the training data to [0, 1]
 NORM = True
 
 # batch size: number of time steps processed by the net in each iteration
@@ -143,11 +148,11 @@ BATCH_SIZE = 16
 # network training configuration
 TRAIN_CONFIG = {
     'checkpoint_state': {},
-    'epochs': 50,
+    'epochs': 250,
     'save': True,
     'save_loaders': False,
     'early_stop': True,
-    'patience': 5,
+    'patience': 25,
     'multi_gpu': True,
     'classification': False,
     'clip_gradients': True
diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py
index eb6cf2eaa02481aee6de5f5cb69dd2df972619f5..beff9e71ec02d17880b435c90e13cb6c3265963b 100644
--- a/climax/main/downscale_infer.py
+++ b/climax/main/downscale_infer.py
@@ -23,7 +23,7 @@ from climax.core.predict import predict_ERA5
 from climax.core.utils import split_date_range
 from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET,
                                 VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS,
-                                DEM, DEM_FEATURES, LOSS, ANOMALIES,
+                                DEM, DEM_FEATURES, LOSS, ANOMALIES, OPTIM,
                                 OPTIM_PARAMS)
 from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH
@@ -40,7 +40,7 @@ if __name__ == '__main__':
     state_file = ERA5Dataset.state_file(
         NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
         dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=OPTIM_PARAMS['weight_decay'])
+        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM)
 
     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index 74390e55bdd7787d77a799c03abd68a111fd2731..65c15300f7fb67c756a6357cf74bba17c25f21ae 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -44,7 +44,7 @@ if __name__ == '__main__':
     state_file = ERA5Dataset.state_file(
         NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
         dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=OPTIM_PARAMS['weight_decay'])
+        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM)
 
     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)