diff --git a/climax/main/config.py b/climax/main/config.py index 6e7816e9456ee039bdc478fc35bddcc9b0871d0f..4c796a3ffa104330d7ddbe6e2766677786c09645 100644 --- a/climax/main/config.py +++ b/climax/main/config.py @@ -22,14 +22,14 @@ from climax.core.loss import (BernoulliGammaLoss, MSELoss, L1Loss, # ----------------------------------------------------------------------------- # ERA5 predictor variables on pressure levels -ERA5_P_PREDICTORS = ['geopotential', 'temperature', 'u_component_of_wind', - 'v_component_of_wind', 'specific_humidity'] -# ERA5_P_PREDICTORS = [] +# ERA5_P_PREDICTORS = ['geopotential', 'temperature', 'u_component_of_wind', +# 'v_component_of_wind', 'specific_humidity'] +ERA5_P_PREDICTORS = [] assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS]) # ERA5 predictor variables on single levels -ERA5_S_PREDICTORS = ['surface_pressure'] -# ERA5_S_PREDICTORS = ['total_precipitation'] +# ERA5_S_PREDICTORS = ['surface_pressure'] +ERA5_S_PREDICTORS = ['total_precipitation'] assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS]) # ERA5 predictor variables @@ -50,10 +50,10 @@ CHUNKS = {'time': 365} # ----------------------------------------------------------------------------- # include day of year as predictor -DOY = True +DOY = False # use digital elevation model instead of model orography -DEM = True +DEM = False if DEM: # remove model orography when using DEM if 'orography' in ERA5_S_PREDICTORS: @@ -91,14 +91,14 @@ VALID_SIZE = 0.2 CV = 5 # number of bootstrapped model trainings -BOOTSTRAP = 10 +BOOTSTRAP=10 # ----------------------------------------------------------------------------- # Observations ---------------------------------------------------------------- # ----------------------------------------------------------------------------- # target variable: check if target variable is valid -PREDICTAND = 'pr' +PREDICTAND='pr' assert PREDICTAND in PREDICTANDS # threshold defining the minimum amount of precipitation (mm) for a wet day @@ -122,10 +122,10 @@ FILTERS = [32, 64, 128, 256] # for precipitation: # BernoulliGammaLoss (NLL of Bernoulli-Gamma distribution) # BernoulliWeibullLoss (NLL of Bernoulli-Weibull distribution) -# LOSS = L1Loss() -LOSS = MSELoss() -# LOSS = BernoulliGammaLoss(min_amount=1) -# LOSS = BernoulliWeibullLoss(min_amount=1) +# LOSS=MSELoss() +LOSS=MSELoss() +# LOSS=MSELoss() +# LOSS=MSELoss() # batch size: number of time steps processed by the net in each iteration BATCH_SIZE = 16 @@ -164,17 +164,20 @@ if PREDICTAND is 'pr': WEIGHT_DECAY = 1e-5 if isinstance(LOSS, MSELoss): MAX_LR = 0.0004 - WEIGHT_DECAY = 1e-3 + # weight decay rates for supersampling task + if not any(ERA5_P_PREDICTORS) and ERA5_S_PREDICTORS == 'pr': + WEIGHT_DECAY = 1e-5 + else: + WEIGHT_DECAY = 1e-3 if isinstance(LOSS, BernoulliGammaLoss): # learning rates for supersampling task if not any(ERA5_P_PREDICTORS) and ERA5_S_PREDICTORS == 'pr': MAX_LR = 0.001 + WEIGHT_DECAY = 1e-4 else: MAX_LR = 0.0005 if OPTIM is torch.optim.Adam else 0.001 - - # weight decay - WEIGHT_DECAY = 1e-2 + WEIGHT_DECAY = 1e-2 # base learning rate: MAX_LR / 4 (Smith L. (2017)) BASE_LR = MAX_LR / 4