From b06b70dc7124971611a9e35cdc432616a019bc2f Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Mon, 27 Sep 2021 14:38:26 +0200
Subject: [PATCH] Added stratified sampling for precipitation.

---
 climax/main/downscale_train.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index d1bcb5b..18076cb 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -21,6 +21,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig
 from pysegcnn.core.models import Network
 from pysegcnn.core.logging import log_conf
 from climax.core.dataset import ERA5Dataset, NetCDFDataset
+from climax.core.config import WET_DAY_THRESHOLD
 from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
                                 CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
                                 LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
@@ -101,7 +102,16 @@ if __name__ == '__main__':
         LogConfig.init_log('Initializing training data.')
 
         # split calibration period into training and validation period
-        train, valid = train_test_split(CALIB_PERIOD, shuffle=False)
+        if PREDICTAND == 'pr':
+            # stratify training and validation dataset by number of observed
+            # wet days for precipitation
+            wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x')) >=
+                        WET_DAY_THRESHOLD).to_array().values
+            train, valid = train_test_split(
+                CALIB_PERIOD, stratify=wet_days, test_size=0.5)
+            train, valid = sorted(train), sorted(valid)  # sort chronologically
+        else:
+            train, valid = train_test_split(CALIB_PERIOD, shuffle=False)
 
         # training and validation dataset
         Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
@@ -133,7 +143,6 @@ if __name__ == '__main__':
         # train model
         state = trainer.train()
 
-
     # log execution time of script
     LogConfig.init_log('Execution time of script {}: {}'
                        .format(__file__, timedelta(seconds=time.monotonic() -
-- 
GitLab