diff --git a/climax/main/config.py b/climax/main/config.py
index a86c8d9ad6e80ba2d0b2ab6400ac18f99502a704..46a7ff89beb8a0d29be7079c4e4cbe8176108ab0 100644
--- a/climax/main/config.py
+++ b/climax/main/config.py
@@ -24,12 +24,14 @@ from climax.core.loss import (BernoulliGammaLoss, MSELoss, L1Loss,
 # ERA5 predictor variables on pressure levels
 ERA5_P_PREDICTORS = ['geopotential', 'temperature', 'u_component_of_wind',
                      'v_component_of_wind', 'specific_humidity']
+# ERA5_P_PREDICTORS = []
 assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS])
 
 # ERA5 predictor variables on single levels
 # ERA5_S_PREDICTORS = ['mean_sea_level_pressure', 'orography', '2m_temperature']
 # ERA5_S_PREDICTORS = ['mean_sea_level_pressure']
 ERA5_S_PREDICTORS = ['surface_pressure']
+# ERA5_S_PREDICTORS = ['total_precipitation']
 assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS])
 
 # ERA5 predictor variables
@@ -83,8 +85,8 @@ REFERENCE_PERIOD = np.concatenate([CALIB_PERIOD, VALID_PERIOD], axis=0)
 STRATIFY = False
 
 # size of the validation set w.r.t. the training set
-# e.g., VALID_SIZE = 0.1 means: 90% of CALIB_PERIOD for training
-#                               10% of CALIB_PERIOD for validation
+# e.g., VALID_SIZE = 0.2 means: 80% of CALIB_PERIOD for training
+#                               20% of CALIB_PERIOD for validation
 VALID_SIZE = 0.2
 
 # number of folds for training with KFold cross-validation
@@ -117,23 +119,31 @@ FILTERS = [32, 64, 128, 256]
 # BernoulliGammaLoss (NLL of Bernoulli-Gamma distribution)
 # BernoulliWeibullLoss (NLL of Bernoulli-Weibull distribution)
 # LOSS = L1Loss()
-# LOSS = MSELoss()
-LOSS = BernoulliGammaLoss(min_amount=1)
+LOSS = MSELoss()
+# LOSS = BernoulliGammaLoss(min_amount=1)
 # LOSS = BernoulliWeibullLoss(min_amount=1)
 
+# batch size: number of time steps processed by the net in each iteration
+BATCH_SIZE = 16
+
+# base learning rate: constant or CyclicLR policy
+BASE_LR = 1e-4
+
+# maximum learning rate for CyclicLR policy
+MAX_LR = 1e-3
+
 # stochastic optimization algorithm
-OPTIM = torch.optim.SGD
-# OPTIM = torch.optim.Adam
-OPTIM_PARAMS = {'lr': 1e-3,  # learning rate
-                'weight_decay': 0  # regularization rate
-                }
+# OPTIM = torch.optim.SGD
+OPTIM = torch.optim.Adam
+OPTIM_PARAMS = {'lr': BASE_LR, 'weight_decay': 0}
 if OPTIM == torch.optim.SGD:
-    OPTIM_PARAMS['momentum'] = 0.99
+    OPTIM_PARAMS['momentum'] = 0.99  # SGD with momentum
 
-# learning rate scheduler
-# LR_SCHEDULER = torch.optim.lr_scheduler.MultiStepLR
+# learning rate scheduler: CyclicLR policy
 LR_SCHEDULER = None
-LR_SCHEDULER_PARAMS = {'gamma': 0.25, 'milestones': [1, 3]}
+# LR_SCHEDULER = torch.optim.lr_scheduler.CyclicLR
+LR_SCHEDULER_PARAMS = {'base_lr': BASE_LR, 'max_lr': MAX_LR,
+                       'mode': 'triangular', 'step_size_up': 400}
 
 # whether to randomly shuffle time steps or to conserve time series for model
 # training
@@ -142,9 +152,6 @@ SHUFFLE = True
 # whether to normalize the training data to [0, 1]
 NORM = True
 
-# batch size: number of time steps processed by the net in each iteration
-BATCH_SIZE = 16
-
 # network training configuration
 TRAIN_CONFIG = {
     'checkpoint_state': {},
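
A minimal sketch of how a training script might consume the optimizer and scheduler settings introduced above (the toy model is a placeholder, not part of this repository; only OPTIM, OPTIM_PARAMS, LR_SCHEDULER, and LR_SCHEDULER_PARAMS are taken from the config):

import torch
from climax.main.config import (OPTIM, OPTIM_PARAMS, LR_SCHEDULER,
                                LR_SCHEDULER_PARAMS)

# placeholder model: any torch.nn.Module with trainable parameters
model = torch.nn.Linear(10, 1)

# Adam with lr=BASE_LR and weight_decay=0, per OPTIM_PARAMS above
optimizer = OPTIM(model.parameters(), **OPTIM_PARAMS)

# with LR_SCHEDULER = None (as in this diff) the learning rate stays
# constant; uncommenting CyclicLR in the config enables the cyclic policy
scheduler = None
if LR_SCHEDULER is not None:
    # note: CyclicLR cycles momentum by default, which Adam does not
    # expose; depending on the PyTorch version this may require adding
    # cycle_momentum=False to LR_SCHEDULER_PARAMS
    scheduler = LR_SCHEDULER(optimizer, **LR_SCHEDULER_PARAMS)

With the CyclicLR policy, scheduler.step() is intended to be called after every batch rather than every epoch, so 'step_size_up': 400 ramps the learning rate from BASE_LR up to MAX_LR over 400 iterations before the triangular cycle descends again.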