From 0a790e97cb00cff78cb6249ba92e374da4a739b3 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Mon, 27 Sep 2021 14:39:49 +0200 Subject: [PATCH] Added stratified sampling for precipitation. --- climax/main/config.py | 3 +++ climax/main/downscale_train.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/climax/main/config.py b/climax/main/config.py index f0ad782..4aa39d0 100644 --- a/climax/main/config.py +++ b/climax/main/config.py @@ -51,6 +51,9 @@ if DEM: # whether to use DEM slope and aspect as predictors DEM_FEATURES = False +# stratify training/validation set for precipitation by number of wet days +STRATIFY = True + # ----------------------------------------------------------------------------- # Observations ---------------------------------------------------------------- # ----------------------------------------------------------------------------- diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py index 18076cb..1ac0cb7 100644 --- a/climax/main/downscale_train.py +++ b/climax/main/downscale_train.py @@ -21,11 +21,11 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig from pysegcnn.core.models import Network from pysegcnn.core.logging import log_conf from climax.core.dataset import ERA5Dataset, NetCDFDataset -from climax.core.config import WET_DAY_THRESHOLD from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND, CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR, LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS, - OVERWRITE, DEM, DEM_FEATURES) + OVERWRITE, DEM, DEM_FEATURES, STRATIFY, + WET_DAY_THRESHOLD) from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH # module level logger @@ -102,7 +102,7 @@ if __name__ == '__main__': LogConfig.init_log('Initializing training data.') # split calibration period into training and validation period - if PREDICTAND == 'pr': + if PREDICTAND == 'pr' and STRATIFY: # stratify training and validation dataset by number of observed # wet days for precipitation wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x')) >= -- GitLab