diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index d1bcb5b75f0f021422158d4bc80abeb1d0f7c912..18076cba5b001f76699d48253bd64418b253a15b 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -21,6 +21,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig
 from pysegcnn.core.models import Network
 from pysegcnn.core.logging import log_conf
 from climax.core.dataset import ERA5Dataset, NetCDFDataset
+from climax.core.config import WET_DAY_THRESHOLD
 from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
                                 CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
                                 LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
@@ -101,7 +102,16 @@ if __name__ == '__main__':
         LogConfig.init_log('Initializing training data.')
 
         # split calibration period into training and validation period
-        train, valid = train_test_split(CALIB_PERIOD, shuffle=False)
+        if PREDICTAND == 'pr':
+            # stratify by a per-time-step wet/dry flag (spatial mean of the
+            # observations >= WET_DAY_THRESHOLD), squeezed to 1-d labels
+            wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x')) >=
+                        WET_DAY_THRESHOLD).to_array().values.squeeze()
+            train, valid = train_test_split(
+                CALIB_PERIOD, stratify=wet_days, test_size=0.5)
+            train, valid = sorted(train), sorted(valid)  # sort chronologically
+        else:
+            train, valid = train_test_split(CALIB_PERIOD, shuffle=False)
 
         # training and validation dataset
         Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
@@ -133,7 +143,6 @@ if __name__ == '__main__':
         # train model
         state = trainer.train()
 
-
     # log execution time of script
     LogConfig.init_log('Execution time of script {}: {}'
                        .format(__file__, timedelta(seconds=time.monotonic() -