From bb163db5d6c6a844de163ca5c91a4931789d874b Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 17 Nov 2021 10:28:54 +0100
Subject: [PATCH] Optimal hyperparameters from sensitivity analysis.

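Set the maximum learning rate (MAX_LR) and the weight decay
(WEIGHT_DECAY) for each combination of predictand and loss function to
the optimal values found by the sensitivity analysis, and pass
WEIGHT_DECAY to the optimizer through OPTIM_PARAMS. Move BATCH_SIZE
above the optimizer definition and drop two stale commented-out
predictor lists. Replace the identity comparisons of PREDICTAND with
string literals ('is') by equality checks ('==') and compare
ERA5_S_PREDICTORS against a list rather than a bare string.

Rough sketch of how these settings are consumed (not the actual climax
training loop; CyclicLR is an assumption based on the Smith (2017)
citation for BASE_LR = MAX_LR / 4):

    import torch
    from torch.optim.lr_scheduler import CyclicLR

    from climax.main.config import BASE_LR, MAX_LR, OPTIM, OPTIM_PARAMS

    model = torch.nn.Linear(8, 1)  # stand-in for the actual network

    # OPTIM_PARAMS carries the base learning rate and the weight decay
    optimizer = OPTIM(model.parameters(), **OPTIM_PARAMS)

    # cycle the learning rate between BASE_LR and MAX_LR;
    # cycle_momentum=False since OPTIM is Adam here (no 'momentum' param)
    scheduler = CyclicLR(optimizer, base_lr=BASE_LR, max_lr=MAX_LR,
                         cycle_momentum=False)

    # call scheduler.step() once per mini-batch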
---
 climax/main/config.py | 37 ++++++++++++++++++++++++++-----------
 1 file changed, 26 insertions(+), 11 deletions(-)

diff --git a/climax/main/config.py b/climax/main/config.py
index a8e78c2..6e7816e 100644
--- a/climax/main/config.py
+++ b/climax/main/config.py
@@ -28,8 +28,6 @@ ERA5_P_PREDICTORS = ['geopotential', 'temperature', 'u_component_of_wind',
 assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS])
 
 # ERA5 predictor variables on single levels
-# ERA5_S_PREDICTORS = ['mean_sea_level_pressure', 'orography', '2m_temperature']
-# ERA5_S_PREDICTORS = ['mean_sea_level_pressure']
 ERA5_S_PREDICTORS = ['surface_pressure']
 # ERA5_S_PREDICTORS = ['total_precipitation']
 assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS])
@@ -129,44 +127,61 @@ LOSS = MSELoss()
 # LOSS = BernoulliGammaLoss(min_amount=1)
 # LOSS = BernoulliWeibullLoss(min_amount=1)
 
+# batch size: number of time steps processed by the net in each iteration
+BATCH_SIZE = 16
+
 # stochastic optimization algorithm
-# OPTIM = torch.optim.SGD
 OPTIM = torch.optim.Adam
 
-# batch size: number of time steps processed by the net in each iteration
-BATCH_SIZE = 16
+# optimization hyperparameters determined from the sensitivity analysis
 
-# maximum learning rate determined from learning rate range test
+# minimum temperature
-if PREDICTAND is 'tasmin':
+if PREDICTAND == 'tasmin':
+
+    # learning rate and weight decay from the sensitivity analysis
     if isinstance(LOSS, L1Loss):
         MAX_LR = 0.001 if OPTIM is torch.optim.Adam else 0.004
+        WEIGHT_DECAY = 1e-3
     if isinstance(LOSS, MSELoss):
         MAX_LR = 0.001 if OPTIM is torch.optim.Adam else 0.002
+        WEIGHT_DECAY = 0
 
+# maximum temperature
-if PREDICTAND is 'tasmax':
+if PREDICTAND == 'tasmax':
+
     if isinstance(LOSS, L1Loss):
         MAX_LR = 0.001
+        WEIGHT_DECAY = 1e-3
     if isinstance(LOSS, MSELoss):
         MAX_LR = 0.001 if OPTIM is torch.optim.Adam else 0.004
+        WEIGHT_DECAY = 1e-2
 
+# precipitation
-if PREDICTAND is 'pr':
+if PREDICTAND == 'pr':
+
     if isinstance(LOSS, L1Loss):
         MAX_LR = 0.001
+        WEIGHT_DECAY = 1e-5
     if isinstance(LOSS, MSELoss):
         MAX_LR = 0.0004
+        WEIGHT_DECAY = 1e-3
     if isinstance(LOSS, BernoulliGammaLoss):
+
         # learning rates for supersampling task
-        if not any(ERA5_P_PREDICTORS) and ERA5_S_PREDICTORS == 'pr':
+        if not any(ERA5_P_PREDICTORS) and ERA5_S_PREDICTORS == ['pr']:
             MAX_LR = 0.001
         else:
             MAX_LR = 0.0005 if OPTIM is torch.optim.Adam else 0.001
 
+        # weight decay
+        WEIGHT_DECAY = 1e-2
+
 # base learning rate: MAX_LR / 4 (Smith L. (2017))
 BASE_LR = MAX_LR / 4
 
 # optimization parameters
 OPTIM_PARAMS = {'lr': BASE_LR,
-                'weight_decay': 0
+                'weight_decay': WEIGHT_DECAY
                 }
 if OPTIM is torch.optim.SGD:
     OPTIM_PARAMS['momentum'] = 0.99  # SGD with momentum
-- 
GitLab