Commit 711f04c3 authored by Frisinghelli Daniel

Implemented CyclicLR schedule.

parent f15763b2
@@ -24,12 +24,14 @@ from climax.core.loss import (BernoulliGammaLoss, MSELoss, L1Loss,
# ERA5 predictor variables on pressure levels
ERA5_P_PREDICTORS = ['geopotential', 'temperature', 'u_component_of_wind',
'v_component_of_wind', 'specific_humidity']
# ERA5_P_PREDICTORS = []
assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS])
# ERA5 predictor variables on single levels
# ERA5_S_PREDICTORS = ['mean_sea_level_pressure', 'orography', '2m_temperature']
# ERA5_S_PREDICTORS = ['mean_sea_level_pressure']
ERA5_S_PREDICTORS = ['surface_pressure']
# ERA5_S_PREDICTORS = ['total_precipitation']
assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS])
# ERA5 predictor variables
@@ -83,8 +85,8 @@ REFERENCE_PERIOD = np.concatenate([CALIB_PERIOD, VALID_PERIOD], axis=0)
STRATIFY = False
# size of the validation set w.r.t. the training set
# e.g., VALID_SIZE = 0.1 means: 90% of CALIB_PERIOD for training
# 10% of CALIB_PERIOD for validation
# e.g., VALID_SIZE = 0.2 means: 80% of CALIB_PERIOD for training
# 20% of CALIB_PERIOD for validation
VALID_SIZE = 0.2
# number of folds for training with KFold cross-validation
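For illustration, a minimal sketch of how a VALID_SIZE-style split of CALIB_PERIOD could be realized; the helper name and the random seed are hypothetical and not taken from this repository:

import numpy as np

def split_calibration_period(calib_period, valid_size=0.2, shuffle=True, seed=0):
    # hold out a random fraction of the calibration period for validation
    indices = np.arange(len(calib_period))
    if shuffle:
        np.random.default_rng(seed).shuffle(indices)
    n_valid = int(round(valid_size * len(calib_period)))
    return calib_period[indices[n_valid:]], calib_period[indices[:n_valid]]
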
@@ -117,23 +119,31 @@ FILTERS = [32, 64, 128, 256]
# BernoulliGammaLoss (NLL of Bernoulli-Gamma distribution)
# BernoulliWeibullLoss (NLL of Bernoulli-Weibull distribution)
# LOSS = L1Loss()
# LOSS = MSELoss()
LOSS = BernoulliGammaLoss(min_amount=1)
LOSS = MSELoss()
# LOSS = BernoulliGammaLoss(min_amount=1)
# LOSS = BernoulliWeibullLoss(min_amount=1)
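As background on the BernoulliGammaLoss option above, a sketch of a Bernoulli-Gamma negative log-likelihood as commonly used for daily precipitation; the actual implementation in climax.core.loss may differ in parameterization, masking, and numerical safeguards:

import torch

def bernoulli_gamma_nll(p, shape, scale, y, min_amount=1.0, eps=1e-8):
    # p: probability of precipitation, shape/scale: Gamma parameters,
    # y: observed precipitation; days with y >= min_amount count as wet
    wet = (y >= min_amount).float()
    log_lik = ((1 - wet) * torch.log(1 - p + eps) +
               wet * (torch.log(p + eps) +
                      (shape - 1) * torch.log(y + eps) -
                      y / (scale + eps) -
                      shape * torch.log(scale + eps) -
                      torch.lgamma(shape)))
    return -log_lik.mean()
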
# batch size: number of time steps processed by the net in each iteration
BATCH_SIZE = 16
# base learning rate: constant or CyclicLR policy
BASE_LR = 1e-4
# maximum learning rate for CyclicLR policy
MAX_LR = 1e-3
# stochastic optimization algorithm
OPTIM = torch.optim.SGD
# OPTIM = torch.optim.Adam
OPTIM_PARAMS = {'lr': 1e-3, # learning rate
'weight_decay': 0 # regularization rate
}
# OPTIM = torch.optim.SGD
OPTIM = torch.optim.Adam
OPTIM_PARAMS = {'lr': BASE_LR, 'weight_decay': 0}
if OPTIM == torch.optim.SGD:
OPTIM_PARAMS['momentum'] = 0.99
OPTIM_PARAMS['momentum'] = 0.99 # SGD with momentum
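A sketch of how the OPTIM / OPTIM_PARAMS pair is typically consumed; model stands in for the network defined elsewhere in the package:

model = torch.nn.Linear(10, 1)  # placeholder for the actual network
optimizer = OPTIM(model.parameters(), **OPTIM_PARAMS)
# with the settings above this is equivalent to
# torch.optim.Adam(model.parameters(), lr=BASE_LR, weight_decay=0)
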
# learning rate scheduler
# LR_SCHEDULER = torch.optim.lr_scheduler.MultiStepLR
# learning rate scheduler: CyclicLR policy
LR_SCHEDULER = None
LR_SCHEDULER_PARAMS = {'gamma': 0.25, 'milestones': [1, 3]}
# LR_SCHEDULER = torch.optim.lr_scheduler.CyclicLR
LR_SCHEDULER_PARAMS = {'base_lr': BASE_LR, 'max_lr': MAX_LR,
'mode': 'triangular', 'step_size_up': 400}
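A sketch of how the CyclicLR configuration above is wired into a per-batch training loop; model and train_dl are hypothetical placeholders, LOSS is assumed to follow the standard PyTorch loss call convention, and this is not the package's actual training loop:

optimizer = OPTIM(model.parameters(), **OPTIM_PARAMS)
# cycle_momentum=False is required with Adam, which exposes no momentum parameter
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer, cycle_momentum=False, **LR_SCHEDULER_PARAMS)

for inputs, targets in train_dl:  # train_dl: DataLoader built with BATCH_SIZE
    optimizer.zero_grad()
    loss = LOSS(model(inputs), targets)
    loss.backward()
    optimizer.step()
    # the learning rate ramps linearly from BASE_LR to MAX_LR over
    # step_size_up=400 batches and back down again (triangular mode)
    scheduler.step()
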
# whether to randomly shuffle time steps or to conserve time series for model
# training
@@ -142,9 +152,6 @@ SHUFFLE = True
# whether to normalize the training data to [0, 1]
NORM = True
# batch size: number of time steps processed by the net in each iteration
BATCH_SIZE = 16
# network training configuration
TRAIN_CONFIG = {
'checkpoint_state': {},
......