diff --git a/climax/main/config.py b/climax/main/config.py
index bba8d4d5831531cf5f86e63bbf2f0d37e25d3483..068f82a4d946ae577594e090c986eff3aa4d9fb7 100644
--- a/climax/main/config.py
+++ b/climax/main/config.py
@@ -7,6 +7,7 @@
 import datetime
 
 # externals
+import torch
 import numpy as np
 
 # locals
@@ -29,6 +30,7 @@ assert all([var in ERA5_P_VARIABLES for var in ERA5_P_PREDICTORS])
 # ERA5_S_PREDICTORS = ['mean_sea_level_pressure', 'orography', '2m_temperature']
 # ERA5_S_PREDICTORS = ['mean_sea_level_pressure']
 ERA5_S_PREDICTORS = ['surface_pressure']
+# ERA5_S_PREDICTORS = ['total_precipitation']
 assert all([var in ERA5_S_VARIABLES for var in ERA5_S_PREDICTORS])
 
 # ERA5 predictor variables
@@ -81,8 +83,7 @@ STRATIFY = False
 VALID_SIZE = 0.1
 
 # whether to train using cross-validation
-# TODO: define number of folds, description
-CV = False
+CV = 5  # number of folds; set to False to disable cross-validation
 
 # -----------------------------------------------------------------------------
 # Observations ----------------------------------------------------------------
@@ -115,6 +116,25 @@ LOSS = MSELoss()
 # LOSS = BernoulliGammaLoss(min_amount=1)
 # LOSS = BernoulliWeibullLoss(min_amount=1)
 
+# stochastic optimization algorithm
+OPTIM = torch.optim.SGD
+OPTIM_PARAMS = {'lr': 0.005,  # learning rate
+                'weight_decay': 1e-6,  # regularization rate
+                }
+if OPTIM == torch.optim.SGD:
+    OPTIM_PARAMS['momentum'] = 0.9
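+# e.g. the Adam variant previously commented out in downscale_train.py
+# (the momentum entry above only applies to SGD):
+# OPTIM = torch.optim.Adam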
+
+# learning rate scheduler
+# LR_SCHEDULER = torch.optim.lr_scheduler.ExponentialLR
+LR_SCHEDULER = None
+LR_SCHEDULER_PARAMS = {'gamma': 0.9}  # ExponentialLR: lr *= gamma per epoch
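+# another option would be step-wise decay, e.g.:
+# LR_SCHEDULER = torch.optim.lr_scheduler.MultiStepLR
+# LR_SCHEDULER_PARAMS = {'milestones': [25, 40], 'gamma': 0.1}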
+
 # whether to randomly shuffle time steps or to conserve time series for model
 # training
 SHUFFLE = True
@@ -126,23 +146,18 @@ NORM = True
 # batch size: number of time steps processed by the net in each iteration
 BATCH_SIZE = 16
 
-# learning rate
-LR = 0.0005
-
-# regularization rate
-LAMBDA = 0.05
-
 # network training configuration
 TRAIN_CONFIG = {
     'checkpoint_state': {},
-    'epochs': 250,
+    'epochs': 50,
     'save': True,
     'save_loaders': False,
     'early_stop': True,
-    'patience': 50,
+    'patience': 5,
     'multi_gpu': True,
     'classification': False,
-    'clip_gradients': True
+    'clip_gradients': True,
+    # 'lr_scheduler': torch.optim.lr_scheduler.ExponentialLR
     }
 
 # whether to overwrite existing models
diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py
index 1f07b341e361e1560ad56df9f80fff02f1dcce5e..eb6cf2eaa02481aee6de5f5cb69dd2df972619f5 100644
--- a/climax/main/downscale_infer.py
+++ b/climax/main/downscale_infer.py
@@ -23,7 +23,8 @@ from climax.core.predict import predict_ERA5
 from climax.core.utils import split_date_range
 from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET,
                                 VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS,
-                                DEM, DEM_FEATURES, LOSS, ANOMALIES, LAMBDA)
+                                DEM, DEM_FEATURES, LOSS, ANOMALIES,
+                                OPTIM_PARAMS)
 from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH
 
 # module level logger
@@ -39,7 +40,7 @@ if __name__ == '__main__':
     state_file = ERA5Dataset.state_file(
         NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
         dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=LAMBDA)
+        decay=OPTIM_PARAMS['weight_decay'])
 
     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index 749257c602a450a29051157a889e2fbacbfb43ee..314ae378e8b7367cf1aa4627961e4b68a9e42a4a 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -11,7 +11,6 @@ from datetime import timedelta
 from logging.config import dictConfig
 
 # externals
-import torch
 import xarray as xr
 from sklearn.model_selection import train_test_split
 from torch.utils.data import DataLoader
@@ -24,10 +23,12 @@ from pysegcnn.core.logging import log_conf
 from climax.core.dataset import ERA5Dataset, NetCDFDataset
 from climax.core.loss import MSELoss, L1Loss
 from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
-                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
-                                LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
+                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, OPTIM,
+                                NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
                                 OVERWRITE, DEM, DEM_FEATURES, STRATIFY,
-                                WET_DAY_THRESHOLD, VALID_SIZE, ANOMALIES)
+                                WET_DAY_THRESHOLD, VALID_SIZE, ANOMALIES,
+                                OPTIM_PARAMS, LR_SCHEDULER,
+                                LR_SCHEDULER_PARAMS)
 from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
 
 # module level logger
@@ -43,7 +44,7 @@ if __name__ == '__main__':
     state_file = ERA5Dataset.state_file(
         NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
         dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=LAMBDA)
+        decay=OPTIM_PARAMS['weight_decay'])
 
     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
@@ -119,11 +120,15 @@ if __name__ == '__main__':
     inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
     net = NET(state_file, inputs, outputs, filters=FILTERS)
 
-    	# initialize optimizer
-    # optimizer = torch.optim.Adam(net.parameters(), lr=LR,
-    #                              weight_decay=LAMBDA)
-    optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9,
-                                weight_decay=LAMBDA)
+    # initialize optimizer
+    optimizer = OPTIM(net.parameters(), **OPTIM_PARAMS)
+
+    # initialize learning rate scheduler instance
+    lr_scheduler = None
+    if LR_SCHEDULER is not None:
+        lr_scheduler = LR_SCHEDULER(optimizer, **LR_SCHEDULER_PARAMS)
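+    # note: assumes the training loop steps the scheduler once per epoch,
+    # e.g. lr_scheduler.step() after each pass over the training data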
 
     # initialize training data
     LogConfig.init_log('Initializing training data.')