From 9c912abe91b87844093de026bf0a1b1c70dbb22b Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Tue, 21 Dec 2021 14:56:12 +0100
Subject: [PATCH] Set learning and weight decay rates for supersampling task.

---
 climax/main/config.py | 37 ++++++++++++++++++++-----------------
 1 file changed, 20 insertions(+), 17 deletions(-)

diff --git a/climax/main/config.py b/climax/main/config.py
index 6e7816e..4c796a3 100644
--- a/climax/main/config.py
+++ b/climax/main/config.py
@@ -22,14 +22,14 @@ from climax.core.loss import (BernoulliGammaLoss, MSELoss, L1Loss,
 # -----------------------------------------------------------------------------
 
 # ERA5 predictor variables on pressure levels
-ERA5_P_PREDICTORS = ['geopotential', 'temperature', 'u_component_of_wind',
-                     'v_component_of_wind', 'specific_humidity']
-# ERA5_P_PREDICTORS = []
+# ERA5_P_PREDICTORS = ['geopotential', 'temperature', 'u_component_of_wind',
+#                      'v_component_of_wind', 'specific_humidity']
+ERA5_P_PREDICTORS = []
 assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS])
 
 # ERA5 predictor variables on single levels
-ERA5_S_PREDICTORS = ['surface_pressure']
-# ERA5_S_PREDICTORS = ['total_precipitation']
+# ERA5_S_PREDICTORS = ['surface_pressure']
+ERA5_S_PREDICTORS = ['total_precipitation']
 assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS])
 
 # ERA5 predictor variables
@@ -50,10 +50,10 @@ CHUNKS = {'time': 365}
 # -----------------------------------------------------------------------------
 
 # include day of year as predictor
-DOY = True
+DOY = False
 
 # use digital elevation model instead of model orography
-DEM = True
+DEM = False
 if DEM:
     # remove model orography when using DEM
     if 'orography' in ERA5_S_PREDICTORS:
@@ -91,14 +91,14 @@ VALID_SIZE = 0.2
 CV = 5
 
 # number of bootstrapped model trainings
-BOOTSTRAP = 10
+BOOTSTRAP=10
 
 # -----------------------------------------------------------------------------
 # Observations ----------------------------------------------------------------
 # -----------------------------------------------------------------------------
 
 # target variable: check if target variable is valid
-PREDICTAND = 'pr'
+PREDICTAND='pr'
 assert PREDICTAND in PREDICTANDS
 
 # threshold  defining the minimum amount of precipitation (mm) for a wet day
@@ -122,10 +122,10 @@ FILTERS = [32, 64, 128, 256]
 # for precipitation:
 #     BernoulliGammaLoss (NLL of Bernoulli-Gamma distribution)
 #     BernoulliWeibullLoss (NLL of Bernoulli-Weibull distribution)
-# LOSS = L1Loss()
-LOSS = MSELoss()
-# LOSS = BernoulliGammaLoss(min_amount=1)
-# LOSS = BernoulliWeibullLoss(min_amount=1)
+# LOSS=MSELoss()
+LOSS=MSELoss()
+# LOSS=MSELoss()
+# LOSS=MSELoss()
 
 # batch size: number of time steps processed by the net in each iteration
 BATCH_SIZE = 16
@@ -164,17 +164,20 @@ if PREDICTAND is 'pr':
         WEIGHT_DECAY = 1e-5
     if isinstance(LOSS, MSELoss):
         MAX_LR = 0.0004
-        WEIGHT_DECAY = 1e-3
+        # weight decay rates for supersampling task
+        if not any(ERA5_P_PREDICTORS) and ERA5_S_PREDICTORS == 'pr':
+            WEIGHT_DECAY = 1e-5
+        else:
+            WEIGHT_DECAY = 1e-3
     if isinstance(LOSS, BernoulliGammaLoss):
 
         # learning rates for supersampling task
         if not any(ERA5_P_PREDICTORS) and ERA5_S_PREDICTORS == 'pr':
             MAX_LR = 0.001
+            WEIGHT_DECAY = 1e-4
         else:
             MAX_LR = 0.0005 if OPTIM is torch.optim.Adam else 0.001
-
-        # weight decay
-        WEIGHT_DECAY = 1e-2
+            WEIGHT_DECAY = 1e-2
 
 # base learning rate: MAX_LR / 4 (Smith L. (2017))
 BASE_LR = MAX_LR / 4
-- 
GitLab