Commit bb163db5 authored by Frisinghelli Daniel

Optimal hyperparameters from sensitivity analysis.

parent c8a39ab6
@@ -28,8 +28,6 @@ ERA5_P_PREDICTORS = ['geopotential', 'temperature', 'u_component_of_wind',
assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS])
# ERA5 predictor variables on single levels
# ERA5_S_PREDICTORS = ['mean_sea_level_pressure', 'orography', '2m_temperature']
# ERA5_S_PREDICTORS = ['mean_sea_level_pressure']
ERA5_S_PREDICTORS = ['surface_pressure']
# ERA5_S_PREDICTORS = ['total_precipitation']
assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS])
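The two assert statements fail fast on typos in the predictor lists. For illustration only, and not part of this commit, the sketch below shows how such lists are typically used to subset an ERA5 dataset; it assumes the data lives in an xarray.Dataset, and select_predictors is a hypothetical helper:

import xarray as xr


def select_predictors(ds: xr.Dataset, predictors: list) -> xr.Dataset:
    # hypothetical sketch: fail early on typos, mirroring the asserts
    # above, then return only the configured predictor variables
    missing = [var for var in predictors if var not in ds.data_vars]
    assert not missing, 'variables not in dataset: {}'.format(missing)
    return ds[predictors]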
@@ -129,44 +127,61 @@ LOSS = MSELoss()
# LOSS = BernoulliGammaLoss(min_amount=1)
# LOSS = BernoulliWeibullLoss(min_amount=1)
# stochastic optimization algorithm
# OPTIM = torch.optim.SGD
OPTIM = torch.optim.Adam
# batch size: number of time steps processed by the net in each iteration
BATCH_SIZE = 16
# optimization hyperparameters determined from the sensitivity analysis
# maximum learning rate determined from learning rate range test
# minimum temperature
if PREDICTAND == 'tasmin':
    # learning rate and weight decay: based on sensitivity analysis
    if isinstance(LOSS, L1Loss):
        MAX_LR = 0.001 if OPTIM is torch.optim.Adam else 0.004
        WEIGHT_DECAY = 1e-3
    if isinstance(LOSS, MSELoss):
        MAX_LR = 0.001 if OPTIM is torch.optim.Adam else 0.002
        WEIGHT_DECAY = 0
# maximum temperature
if PREDICTAND == 'tasmax':
    if isinstance(LOSS, L1Loss):
        MAX_LR = 0.001
        WEIGHT_DECAY = 1e-3
    if isinstance(LOSS, MSELoss):
        MAX_LR = 0.001 if OPTIM is torch.optim.Adam else 0.004
        WEIGHT_DECAY = 1e-2
# precipitation
if PREDICTAND == 'pr':
    if isinstance(LOSS, L1Loss):
        MAX_LR = 0.001
        WEIGHT_DECAY = 1e-5
    if isinstance(LOSS, MSELoss):
        MAX_LR = 0.0004
        WEIGHT_DECAY = 1e-3
    if isinstance(LOSS, BernoulliGammaLoss):
        # learning rate for the supersampling task, i.e. precipitation
        # is the only predictor
        if not ERA5_P_PREDICTORS and ERA5_S_PREDICTORS == ['total_precipitation']:
            MAX_LR = 0.001
        else:
            MAX_LR = 0.0005 if OPTIM is torch.optim.Adam else 0.001
        # weight decay
        WEIGHT_DECAY = 1e-2
# base learning rate: MAX_LR / 4 (Smith, 2017)
BASE_LR = MAX_LR / 4
# optimization parameters
OPTIM_PARAMS = {'lr': BASE_LR,
                'weight_decay': WEIGHT_DECAY
                }
if OPTIM is torch.optim.SGD:
    OPTIM_PARAMS['momentum'] = 0.99  # SGD with momentum
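BASE_LR and MAX_LR suggest a cyclical learning rate schedule after Smith (2017), with the maximum rate chosen by a range test: increase the rate over a few epochs and keep the largest value reached before the loss diverges. As a hypothetical sketch, not part of this commit, the settings above could be wired into PyTorch as follows; model and steps_per_epoch are placeholders for the actual network and the length of the training data loader:

import torch

model = torch.nn.Linear(10, 1)  # placeholder for the actual network
steps_per_epoch = 100  # placeholder; len(train_dataloader) in practice
optimizer = OPTIM(model.parameters(), **OPTIM_PARAMS)
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer, base_lr=BASE_LR, max_lr=MAX_LR,
    step_size_up=4 * steps_per_epoch,  # assumed half-cycle length
    cycle_momentum=OPTIM is torch.optim.SGD)  # only SGD has momentum here
for step in range(steps_per_epoch):
    # loss.backward() would precede these calls in the real training loop
    optimizer.step()
    scheduler.step()

CyclicLR is only one way to realize the BASE_LR/MAX_LR pair; the actual training code may use a different schedule built on the same values.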