From f1d002467f1caa35c26872595529ff191a4bac9d Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Wed, 6 Oct 2021 09:45:26 +0200 Subject: [PATCH] Refactor. --- climax/main/downscale_infer.py | 2 +- climax/main/downscale_train.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py index 1aa01f1..dac64d1 100644 --- a/climax/main/downscale_infer.py +++ b/climax/main/downscale_infer.py @@ -95,7 +95,7 @@ if __name__ == '__main__': batch_size=BATCH_SIZE, doy=DOY)) # merge predictions for entire validation period - LOGGER.info('Merging refernce periods ...') + LOGGER.info('Merging reference periods ...') trg_ds = xr.concat(trg_ds, dim='time') # save model predictions as NetCDF file diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py index aa78bb1..f5de32b 100644 --- a/climax/main/downscale_train.py +++ b/climax/main/downscale_train.py @@ -27,7 +27,7 @@ from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND, CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR, LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS, OVERWRITE, DEM, DEM_FEATURES, STRATIFY, - WET_DAY_THRESHOLD) + WET_DAY_THRESHOLD, VALID_SIZE) from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH # module level logger @@ -134,13 +134,13 @@ if __name__ == '__main__': wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x')) >= WET_DAY_THRESHOLD).to_array().values.squeeze() train, valid = train_test_split( - CALIB_PERIOD, stratify=wet_days, test_size=0.1) + CALIB_PERIOD, stratify=wet_days, test_size=VALID_SIZE) # sort chronologically train, valid = sorted(train), sorted(valid) else: train, valid = train_test_split(CALIB_PERIOD, shuffle=False, - test_size=0.1) + test_size=VALID_SIZE) # training and validation dataset Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train) @@ -157,8 +157,7 @@ if __name__ == '__main__': # initialize network trainer trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl, - valid_dl, loss_function=LOSS, - **TRAIN_CONFIG) + valid_dl, loss_function=LOSS, **TRAIN_CONFIG) # train model state = trainer.train() -- GitLab