diff --git a/climax/core/dataset.py b/climax/core/dataset.py
index 1c58328da88931939644fa2ab69b26e9cec3834d..d90a72c929173f84cae69875de2c47776ddfcdfd 100644
--- a/climax/core/dataset.py
+++ b/climax/core/dataset.py
@@ -104,7 +104,7 @@ class EoDataset(torch.utils.data.Dataset):
     @staticmethod
     def state_file(model, predictand, predictors, plevels, dem=False,
                    dem_features=False, doy=False, loss=None, cv=None,
-                   season=None, anomalies=False, decay=None):
+                   season=None, anomalies=False, decay=None, optim=None):
 
         # naming convention:
         # <model>_<predictand>_<Ppredictors>_<plevels>_<Spredictors>.pt
@@ -134,6 +134,10 @@ class EoDataset(torch.utils.data.Dataset):
             # add name of loss function to state file
             state_file = '_'.join([state_file, repr(loss).strip('()')])
 
+        # add name of the optimizer (optim.__name__) as suffix, if specified
+        state_file = ('_'.join([state_file, optim.__name__]) if optim is not
+                      None else state_file)
+
         # add suffix for weight decay values
         state_file = ('_'.join([state_file, 'd{:.0e}'.format(decay)]) if decay
                       is not None else state_file)
diff --git a/climax/main/config.py b/climax/main/config.py
index f246cb544d03311e01224c5e2f9e9015cc6e5530..2ab3734969a1e0df3f0990bce2dd882aa279e224 100644
--- a/climax/main/config.py
+++ b/climax/main/config.py
@@ -30,7 +30,6 @@ assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS])
 # ERA5_S_PREDICTORS = ['mean_sea_level_pressure', 'orography', '2m_temperature']
 # ERA5_S_PREDICTORS = ['mean_sea_level_pressure']
 ERA5_S_PREDICTORS = ['surface_pressure']
-# ERA5_S_PREDICTORS = ['total_precipitation']
 assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS])
 
 # ERA5 predictor variables
@@ -43,6 +42,9 @@ ERA5_PLEVELS = [500, 850]
 # Anomaly = (time_series  - mean(time_series)) / (std(time_series))
 ANOMALIES = False
 
+# Dask chunk size for loading the training data
+CHUNKS = {'time': 365}
+
 # -----------------------------------------------------------------------------
 # Auxiliary predictors --------------------------------------------------------
 # -----------------------------------------------------------------------------
@@ -74,6 +76,9 @@ VALID_PERIOD = np.arange(
     datetime.datetime.strptime('1991-01-01', '%Y-%m-%d').date(),
     datetime.datetime.strptime('2011-01-01', '%Y-%m-%d').date())
 
+# entire reference period: calibration period followed by validation period
+REFERENCE_PERIOD = np.concatenate([CALIB_PERIOD, VALID_PERIOD], axis=0)
+
 # stratify training/validation set for precipitation by number of wet days
 STRATIFY = False
 
@@ -82,7 +87,7 @@ STRATIFY = False
 #                               10% of CALIB_PERIOD for validation
 VALID_SIZE = 0.1
 
-# whether to train using cross-validation
+# number of folds for training with KFold cross-validation
 CV = 5
 
 # -----------------------------------------------------------------------------
@@ -118,23 +123,23 @@ LOSS = MSELoss()
 
 # stochastic optimization algorithm
 OPTIM = torch.optim.SGD
-OPTIM_PARAMS = {'lr': 0.005, # learning rate
+# OPTIM = torch.optim.Adam
+OPTIM_PARAMS = {'lr': 1e-1, # learning rate
                 'weight_decay': 1e-6  # regularization rate
                 }
 if OPTIM == torch.optim.SGD:
     OPTIM_PARAMS['momentum'] = 0.9
 
 # learning rate scheduler
-# LR_SCHEDULER = torch.optim.lr_scheduler.ExponentialLR
+# LR_SCHEDULER = torch.optim.lr_scheduler.MultiStepLR
 LR_SCHEDULER = None
-LR_SCHEDULER_PARAMS = {'gamma': 0.9}
+LR_SCHEDULER_PARAMS = {'gamma': 0.25, 'milestones': [1, 3]}
 
 # whether to randomly shuffle time steps or to conserve time series for model
 # training
 SHUFFLE = True
 
-# whether to normalize the training data to [0, 1] (True) or to standardize to
-# mean=0, std=1 (False)
+# whether to normalize the training data to [0, 1]
 NORM = True
 
 # batch size: number of time steps processed by the net in each iteration
@@ -143,11 +148,11 @@ BATCH_SIZE = 16
 # network training configuration
 TRAIN_CONFIG = {
     'checkpoint_state': {},
-    'epochs': 50,
+    'epochs': 250,
     'save': True,
     'save_loaders': False,
     'early_stop': True,
-    'patience': 5,
+    'patience': 25,
     'multi_gpu': True,
     'classification': False,
     'clip_gradients': True
diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py
index eb6cf2eaa02481aee6de5f5cb69dd2df972619f5..beff9e71ec02d17880b435c90e13cb6c3265963b 100644
--- a/climax/main/downscale_infer.py
+++ b/climax/main/downscale_infer.py
@@ -23,7 +23,7 @@ from climax.core.predict import predict_ERA5
 from climax.core.utils import split_date_range
 from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET,
                                 VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS,
-                                DEM, DEM_FEATURES, LOSS, ANOMALIES,
+                                DEM, DEM_FEATURES, LOSS, ANOMALIES, OPTIM,
                                 OPTIM_PARAMS)
 from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH
 
@@ -40,7 +40,7 @@ if __name__ == '__main__':
     state_file = ERA5Dataset.state_file(
         NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
         dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=OPTIM_PARAMS['weight_decay'])
+        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM)
 
     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index 74390e55bdd7787d77a799c03abd68a111fd2731..65c15300f7fb67c756a6357cf74bba17c25f21ae 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -44,7 +44,7 @@ if __name__ == '__main__':
     state_file = ERA5Dataset.state_file(
         NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
         dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=OPTIM_PARAMS['weight_decay'])
+        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM)
 
     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)