Skip to content
Snippets Groups Projects
Commit b06b70dc authored by Frisinghelli Daniel
Browse files

Added stratified sampling for precipitation.

parent 1681c051
No related branches found
No related tags found
No related merge requests found
......@@ -21,6 +21,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig
from pysegcnn.core.models import Network
from pysegcnn.core.logging import log_conf
from climax.core.dataset import ERA5Dataset, NetCDFDataset
from climax.core.config import WET_DAY_THRESHOLD
from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
......@@ -101,7 +102,16 @@ if __name__ == '__main__':
LogConfig.init_log('Initializing training data.')
# split calibration period into training and validation period
train, valid = train_test_split(CALIB_PERIOD, shuffle=False)
if PREDICTAND == 'pr':
# stratify training and validation dataset by number of observed
# wet days for precipitation
wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x')) >=
WET_DAY_THRESHOLD).to_array().values
train, valid = train_test_split(
CALIB_PERIOD, stratify=wet_days, test_size=0.5)
train, valid = sorted(train), sorted(valid) # sort chronologically
else:
train, valid = train_test_split(CALIB_PERIOD, shuffle=False)
# training and validation dataset
Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
......@@ -133,7 +143,6 @@ if __name__ == '__main__':
# train model
state = trainer.train()
# log execution time of script
LogConfig.init_log('Execution time of script {}: {}'
.format(__file__, timedelta(seconds=time.monotonic() -
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment