diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index d1bcb5b75f0f021422158d4bc80abeb1d0f7c912..18076cba5b001f76699d48253bd64418b253a15b 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -21,6 +21,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig
 from pysegcnn.core.models import Network
 from pysegcnn.core.logging import log_conf
 from climax.core.dataset import ERA5Dataset, NetCDFDataset
+from climax.core.config import WET_DAY_THRESHOLD
 from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
                                 CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
                                 LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
@@ -101,7 +102,20 @@ if __name__ == '__main__':
     LogConfig.init_log('Initializing training data.')
 
     # split calibration period into training and validation period
-    train, valid = train_test_split(CALIB_PERIOD, shuffle=False)
+    if PREDICTAND == 'pr':
+        # stratify training and validation dataset by wet and dry days for
+        # precipitation: a day counts as wet if the spatial mean of the
+        # observed precipitation is at least WET_DAY_THRESHOLD
+        # NOTE: to_array() stacks the data variables along a new leading
+        # axis, squeeze() drops it to get the one label per time step that
+        # stratify requires (assumes Obs_ds holds a single variable)
+        wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x')) >=
+                    WET_DAY_THRESHOLD).to_array().values.squeeze()
+        train, valid = train_test_split(
+            CALIB_PERIOD, stratify=wet_days, test_size=0.5)
+        train, valid = sorted(train), sorted(valid)  # sort chronologically
+    else:
+        train, valid = train_test_split(CALIB_PERIOD, shuffle=False)
 
     # training and validation dataset
     Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
@@ -133,7 +147,6 @@ if __name__ == '__main__':
     # train model
     state = trainer.train()
 
-
     # log execution time of script
     LogConfig.init_log('Execution time of script {}: {}'
                        .format(__file__, timedelta(seconds=time.monotonic() -