From bb163db5d6c6a844de163ca5c91a4931789d874b Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Wed, 17 Nov 2021 10:28:54 +0100 Subject: [PATCH] Optimal hyperparameters from sensitivity analysis. --- climax/main/config.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/climax/main/config.py b/climax/main/config.py index a8e78c2..6e7816e 100644 --- a/climax/main/config.py +++ b/climax/main/config.py @@ -28,8 +28,6 @@ ERA5_P_PREDICTORS = ['geopotential', 'temperature', 'u_component_of_wind', assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS]) # ERA5 predictor variables on single levels -# ERA5_S_PREDICTORS = ['mean_sea_level_pressure', 'orography', '2m_temperature'] -# ERA5_S_PREDICTORS = ['mean_sea_level_pressure'] ERA5_S_PREDICTORS = ['surface_pressure'] # ERA5_S_PREDICTORS = ['total_precipitation'] assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS]) @@ -129,44 +127,61 @@ LOSS = MSELoss() # LOSS = BernoulliGammaLoss(min_amount=1) # LOSS = BernoulliWeibullLoss(min_amount=1) +# batch size: number of time steps processed by the net in each iteration +BATCH_SIZE = 16 + # stochastic optimization algorithm -# OPTIM = torch.optim.SGD OPTIM = torch.optim.Adam -# batch size: number of time steps processed by the net in each iteration -BATCH_SIZE = 16 +# stochastic hyperparameters determined from sensitivity analysis -# maximum learning rate determined from learning rate range test +# minimum temperature if PREDICTAND == 'tasmin': + + # learning rate and weight decay: based on sensitivity analysis if isinstance(LOSS, L1Loss): MAX_LR = 0.001 if OPTIM is torch.optim.Adam else 0.004 + WEIGHT_DECAY = 1e-3 if isinstance(LOSS, MSELoss): MAX_LR = 0.001 if OPTIM is torch.optim.Adam else 0.002 + WEIGHT_DECAY = 0 +# maximum temperature if PREDICTAND == 'tasmax': + if isinstance(LOSS, L1Loss): MAX_LR = 0.001 + WEIGHT_DECAY = 1e-3 if isinstance(LOSS, MSELoss): MAX_LR = 
0.001 if OPTIM is torch.optim.Adam else 0.004 + WEIGHT_DECAY = 1e-2 +# precipitation if PREDICTAND == 'pr': + if isinstance(LOSS, L1Loss): MAX_LR = 0.001 + WEIGHT_DECAY = 1e-5 if isinstance(LOSS, MSELoss): MAX_LR = 0.0004 + WEIGHT_DECAY = 1e-3 if isinstance(LOSS, BernoulliGammaLoss): + # learning rates for supersampling task if not any(ERA5_P_PREDICTORS) and ERA5_S_PREDICTORS == ['pr']: MAX_LR = 0.001 else: MAX_LR = 0.0005 if OPTIM is torch.optim.Adam else 0.001 + # weight decay + WEIGHT_DECAY = 1e-2 + # base learning rate: MAX_LR / 4 (Smith L. (2017)) BASE_LR = MAX_LR / 4 # optimization parameters OPTIM_PARAMS = {'lr': BASE_LR, - 'weight_decay': 0 + 'weight_decay': WEIGHT_DECAY } if OPTIM is torch.optim.SGD: OPTIM_PARAMS['momentum'] = 0.99 # SGD with momentum -- GitLab