From f1d002467f1caa35c26872595529ff191a4bac9d Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 6 Oct 2021 09:45:26 +0200
Subject: [PATCH] Refactor.

---
 climax/main/downscale_infer.py | 2 +-
 climax/main/downscale_train.py | 9 ++++-----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py
index 1aa01f1..dac64d1 100644
--- a/climax/main/downscale_infer.py
+++ b/climax/main/downscale_infer.py
@@ -95,7 +95,7 @@ if __name__ == '__main__':
                                    batch_size=BATCH_SIZE, doy=DOY))
 
     # merge predictions for entire validation period
-    LOGGER.info('Merging refernce periods ...')
+    LOGGER.info('Merging reference periods ...')
     trg_ds = xr.concat(trg_ds, dim='time')
 
     # save model predictions as NetCDF file
diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index aa78bb1..f5de32b 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -27,7 +27,7 @@ from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
                                 CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
                                 LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
                                 OVERWRITE, DEM, DEM_FEATURES, STRATIFY,
-                                WET_DAY_THRESHOLD)
+                                WET_DAY_THRESHOLD, VALID_SIZE)
 from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
 
 # module level logger
@@ -134,13 +134,13 @@ if __name__ == '__main__':
         wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
                     >= WET_DAY_THRESHOLD).to_array().values.squeeze()
         train, valid = train_test_split(
-            CALIB_PERIOD, stratify=wet_days, test_size=0.1)
+            CALIB_PERIOD, stratify=wet_days, test_size=VALID_SIZE)
 
         # sort chronologically
         train, valid = sorted(train), sorted(valid)
     else:
         train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
-                                        test_size=0.1)
+                                        test_size=VALID_SIZE)
 
     # training and validation dataset
     Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
@@ -157,8 +157,7 @@ if __name__ == '__main__':
 
     # initialize network trainer
     trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
-                             valid_dl, loss_function=LOSS,
-                             **TRAIN_CONFIG)
+                             valid_dl, loss_function=LOSS, **TRAIN_CONFIG)
 
     # train model
     state = trainer.train()
-- 
GitLab