From 0a790e97cb00cff78cb6249ba92e374da4a739b3 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Mon, 27 Sep 2021 14:39:49 +0200
Subject: [PATCH] Added config option to toggle stratified sampling for precipitation.

---
 climax/main/config.py          | 3 +++
 climax/main/downscale_train.py | 6 +++---
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/climax/main/config.py b/climax/main/config.py
index f0ad782..4aa39d0 100644
--- a/climax/main/config.py
+++ b/climax/main/config.py
@@ -51,6 +51,9 @@ if DEM:
 # whether to use DEM slope and aspect as predictors
 DEM_FEATURES = False
 
+# stratify train/validation split by number of wet days (PREDICTAND='pr' only)
+STRATIFY = True
+
 # -----------------------------------------------------------------------------
 # Observations ----------------------------------------------------------------
 # -----------------------------------------------------------------------------
diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index 18076cb..1ac0cb7 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -21,11 +21,11 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig
 from pysegcnn.core.models import Network
 from pysegcnn.core.logging import log_conf
 from climax.core.dataset import ERA5Dataset, NetCDFDataset
-from climax.core.config import WET_DAY_THRESHOLD
 from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
                                 CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
                                 LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
-                                OVERWRITE, DEM, DEM_FEATURES)
+                                OVERWRITE, DEM, DEM_FEATURES, STRATIFY,
+                                WET_DAY_THRESHOLD)
 from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
 
 # module level logger
@@ -102,7 +102,7 @@ if __name__ == '__main__':
         LogConfig.init_log('Initializing training data.')
 
         # split calibration period into training and validation period
-        if PREDICTAND == 'pr':
+        if PREDICTAND == 'pr' and STRATIFY:
             # stratify training and validation dataset by number of observed
             # wet days for precipitation
             wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x')) >=
-- 
GitLab