From 711f04c3387914ce4e2db4fa93e556927360d5f9 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 21 Oct 2021 11:44:43 +0200
Subject: [PATCH] Implemented CyclicLR schedule.

---
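Notes: below is a minimal sketch of how the new constants are meant to
be consumed by the training loop; the model and the random batch are
dummy stand-ins (the real loop lives elsewhere in climax). Unlike
MultiStepLR, CyclicLR is a per-iteration policy, so the scheduler is
stepped once per batch rather than once per epoch:

    import torch
    from climax.main.config import (OPTIM, OPTIM_PARAMS, LR_SCHEDULER,
                                    LR_SCHEDULER_PARAMS, BATCH_SIZE)

    # dummy stand-in for the real network
    model = torch.nn.Linear(8, 1)
    optimizer = OPTIM(model.parameters(), **OPTIM_PARAMS)
    # CyclicLR cycles momentum by default; enabling it with Adam requires
    # 'cycle_momentum': False in LR_SCHEDULER_PARAMS
    scheduler = (LR_SCHEDULER(optimizer, **LR_SCHEDULER_PARAMS)
                 if LR_SCHEDULER is not None else None)

    for step in range(1200):
        optimizer.zero_grad()
        loss = model(torch.randn(BATCH_SIZE, 8)).mean()  # dummy batch
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()  # advance the cyclic schedule per batch

With mode='triangular' and step_size_up=400, the learning rate climbs
linearly from BASE_LR=1e-4 to MAX_LR=1e-3 over the first 400 batches,
falls back symmetrically over the next 400, and repeats with no
amplitude decay between cycles.
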
 climax/main/config.py | 42 ++++++++++++++++++++++++++----------------
 1 file changed, 26 insertions(+), 16 deletions(-)

diff --git a/climax/main/config.py b/climax/main/config.py
index a86c8d9..46a7ff8 100644
--- a/climax/main/config.py
+++ b/climax/main/config.py
@@ -24,12 +24,14 @@ from climax.core.loss import (BernoulliGammaLoss, MSELoss, L1Loss,
 # ERA5 predictor variables on pressure levels
 ERA5_P_PREDICTORS = ['geopotential', 'temperature', 'u_component_of_wind',
                      'v_component_of_wind', 'specific_humidity']
+# ERA5_P_PREDICTORS = []
 assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS])
 
 # ERA5 predictor variables on single levels
 # ERA5_S_PREDICTORS = ['mean_sea_level_pressure', 'orography', '2m_temperature']
 # ERA5_S_PREDICTORS = ['mean_sea_level_pressure']
 ERA5_S_PREDICTORS = ['surface_pressure']
+# ERA5_S_PREDICTORS = ['total_precipitation']
 assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS])
 
 # ERA5 predictor variables
@@ -83,8 +85,8 @@ REFERENCE_PERIOD = np.concatenate([CALIB_PERIOD, VALID_PERIOD], axis=0)
 STRATIFY = False
 
 # size of the validation set w.r.t. the training set
-# e.g., VALID_SIZE = 0.1 means: 90% of CALIB_PERIOD for training
-#                               10% of CALIB_PERIOD for validation
+# e.g., VALID_SIZE = 0.2 means: 80% of CALIB_PERIOD for training
+#                               20% of CALIB_PERIOD for validation
 VALID_SIZE = 0.2
 
 # number of folds for training with KFold cross-validation
@@ -117,23 +119,34 @@ FILTERS = [32, 64, 128, 256]
 #     BernoulliGammaLoss (NLL of Bernoulli-Gamma distribution)
 #     BernoulliWeibullLoss (NLL of Bernoulli-Weibull distribution)
 # LOSS = L1Loss()
-# LOSS = MSELoss()
-LOSS = BernoulliGammaLoss(min_amount=1)
+LOSS = MSELoss()
+# LOSS = BernoulliGammaLoss(min_amount=1)
 # LOSS = BernoulliWeibullLoss(min_amount=1)
 
+# batch size: number of time steps processed by the net in each iteration
+BATCH_SIZE = 16
+
+# base learning rate: constant LR, or lower bound of the CyclicLR policy
+BASE_LR = 1e-4
+
+# maximum learning rate: upper bound of the CyclicLR policy
+MAX_LR = 1e-3
+
 # stochastic optimization algorithm
-OPTIM = torch.optim.SGD
-# OPTIM = torch.optim.Adam
-OPTIM_PARAMS = {'lr': 1e-3, # learning rate
-                'weight_decay': 0  # regularization rate
-                }
+# OPTIM = torch.optim.SGD
+OPTIM = torch.optim.Adam
+OPTIM_PARAMS = {'lr': BASE_LR, 'weight_decay': 0}
 if OPTIM == torch.optim.SGD:
-    OPTIM_PARAMS['momentum'] = 0.99
+    OPTIM_PARAMS['momentum'] = 0.99  # SGD with momentum
 
-# learning rate scheduler
-# LR_SCHEDULER = torch.optim.lr_scheduler.MultiStepLR
+# learning rate scheduler: CyclicLR policy (None disables scheduling)
 LR_SCHEDULER = None
-LR_SCHEDULER_PARAMS = {'gamma': 0.25, 'milestones': [1, 3]}
+# LR_SCHEDULER = torch.optim.lr_scheduler.CyclicLR
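+# NOTE: 'triangular' ramps the learning rate linearly from BASE_LR to
+# MAX_LR over step_size_up iterations, then back down. CyclicLR cycles
+# momentum by default, which Adam lacks: pass 'cycle_momentum': False.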
+LR_SCHEDULER_PARAMS = {'base_lr': BASE_LR, 'max_lr': MAX_LR,
+                       'mode': 'triangular', 'step_size_up': 400}
 
 # whether to randomly shuffle time steps or to conserve time series for model
 # training
@@ -142,9 +155,6 @@ SHUFFLE = True
 # whether to normalize the training data to [0, 1]
 NORM = True
 
-# batch size: number of time steps processed by the net in each iteration
-BATCH_SIZE = 16
-
 # network training configuration
 TRAIN_CONFIG = {
     'checkpoint_state': {},
-- 
GitLab