From 56e8d37c580c2a2d92b2b9db2f2c039909e6591a Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 6 Oct 2021 08:49:39 +0200
Subject: [PATCH] downscale_train: exit early when a pretrained model exists.

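Refactor climax/main/downscale_train.py: replace the else-branch that
wrapped the entire training pipeline with an early exit. If a trained
model already exists and OVERWRITE is not set, the pretrained network
is loaded and the script exits immediately:

    if state_file.exists() and not OVERWRITE:
        # load pretrained network
        net, _ = Network.load_pretrained_model(state_file, NET)
        sys.exit()

The dataset, network, optimizer, and trainer setup previously nested
in the else-branch is unchanged apart from being de-indented by one
level.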
---
 climax/main/downscale_train.py | 200 ++++++++++++++++-----------------
 1 file changed, 100 insertions(+), 100 deletions(-)

diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index 3868528..aa78bb1 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -4,6 +4,7 @@
 # -*- coding: utf-8 -*-
 
 # builtins
+import sys
 import time
 import logging
 from datetime import timedelta
@@ -61,107 +62,106 @@ if __name__ == '__main__':
     if state_file.exists() and not OVERWRITE:
         # load pretrained network
         net, _ = Network.load_pretrained_model(state_file, NET)
+        sys.exit()
+
+    # initialize ERA5 predictor dataset
+    LogConfig.init_log('Initializing ERA5 predictors.')
+    Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
+                       plevels=ERA5_PLEVELS)
+    Era5_ds = Era5.merge(chunks=-1)
+
+    # initialize OBS predictand dataset
+    LogConfig.init_log('Initializing observations for predictand: {}'
+                       .format(PREDICTAND))
+
+    # check whether to jointly train tasmin and tasmax
+    if PREDICTAND == 'tas':
+        # read both tasmax and tasmin
+        tasmax = xr.open_dataset(
+            search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
+        tasmin = xr.open_dataset(
+            search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
+        Obs_ds = xr.merge([tasmax, tasmin])
     else:
-        # initialize ERA5 predictor dataset
-        LogConfig.init_log('Initializing ERA5 predictors.')
-        Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
-                           plevels=ERA5_PLEVELS)
-        Era5_ds = Era5.merge(chunks=-1)
-
-        # initialize OBS predictand dataset
-        LogConfig.init_log('Initializing observations for predictand: {}'
-                           .format(PREDICTAND))
-
-        # check whether to joinlty train tasmin and tasmax
-        if PREDICTAND == 'tas':
-            # read both tasmax and tasmin
-            tasmax = xr.open_dataset(
-                search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
-            tasmin = xr.open_dataset(
-                search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
-            Obs_ds = xr.merge([tasmax, tasmin])
-        else:
-            # read in-situ gridded observations
-            Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
-            Obs_ds = xr.open_dataset(Obs_ds)
-
-        # whether to use digital elevation model
-        if DEM:
-            # digital elevation model: Copernicus EU-Dem v1.1
-            dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
-
-            # read elevation and compute slope and aspect
-            dem = ERA5Dataset.dem_features(
-                dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
-                add_coord={'time': Era5_ds.time})
-
-            # check whether to use slope and aspect
-            if not DEM_FEATURES:
-                dem = dem.drop_vars(['slope', 'aspect'])
-
-            # add dem to set of predictor variables
-            Era5_ds = xr.merge([Era5_ds, dem])
-
-        # initialize network and optimizer
-        LogConfig.init_log('Initializing network and optimizer.')
-
-        # define number of output fields
-        # check whether modelling pr with probabilistic approach
-        outputs = len(Obs_ds.data_vars)
-        if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or
-                                   isinstance(LOSS, BernoulliGenParetoLoss)):
-            outputs = 3
-
-        # instanciate network
-        inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
-        net = NET(state_file, inputs, outputs, filters=FILTERS)
-
-        	# initialize optimizer
-        # optimizer = torch.optim.Adam(net.parameters(), lr=LR,
-        #                              weight_decay=LAMBDA)
-        optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9,
-                                    weight_decay=LAMBDA)
-
-        # initialize training data
-        LogConfig.init_log('Initializing training data.')
-
-        # split calibration period into training and validation period
-        if PREDICTAND == 'pr' and STRATIFY:
-            # stratify training and validation dataset by number of
-            # observed wet days for precipitation
-            wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
-                        >= WET_DAY_THRESHOLD).to_array().values.squeeze()
-            train, valid = train_test_split(
-                CALIB_PERIOD, stratify=wet_days, test_size=0.1)
-
-            # sort chronologically
-            train, valid = sorted(train), sorted(valid)
-        else:
-            train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
-                                            test_size=0.1)
-
-        # training and validation dataset
-        Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
-        Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
-
-        # create PyTorch compliant dataset and dataloader instances for model
-        # training
-        train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
-                                 doy=DOY)
-        valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
-                                 doy=DOY)
-        train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
-                              drop_last=False)
-        valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
-                              drop_last=False)
-
-        # initialize network trainer
-        trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
-                                 valid_dl, loss_function=LOSS,
-                                 **TRAIN_CONFIG)
-
-        # train model
-        state = trainer.train()
+        # read in-situ gridded observations
+        Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
+        Obs_ds = xr.open_dataset(Obs_ds)
+
+    # whether to use digital elevation model
+    if DEM:
+        # digital elevation model: Copernicus EU-Dem v1.1
+        dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
+
+        # read elevation and compute slope and aspect
+        dem = ERA5Dataset.dem_features(
+            dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
+            add_coord={'time': Era5_ds.time})
+
+        # check whether to use slope and aspect
+        if not DEM_FEATURES:
+            dem = dem.drop_vars(['slope', 'aspect'])
+
+        # add dem to set of predictor variables
+        Era5_ds = xr.merge([Era5_ds, dem])
+
+    # initialize network and optimizer
+    LogConfig.init_log('Initializing network and optimizer.')
+
+    # define the number of output fields
+    # check whether pr is modelled with a probabilistic approach
+    outputs = len(Obs_ds.data_vars)
+    if PREDICTAND == 'pr' and isinstance(
+            LOSS, (BernoulliGammaLoss, BernoulliGenParetoLoss)):
+        outputs = 3
+
+    # instantiate network
+    inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
+    net = NET(state_file, inputs, outputs, filters=FILTERS)
+
+    # initialize optimizer
+    # optimizer = torch.optim.Adam(net.parameters(), lr=LR,
+    #                              weight_decay=LAMBDA)
+    optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9,
+                                weight_decay=LAMBDA)
+
+    # initialize training data
+    LogConfig.init_log('Initializing training data.')
+
+    # split the calibration period into training and validation periods
+    if PREDICTAND == 'pr' and STRATIFY:
+        # stratify the training and validation datasets by the
+        # occurrence of observed wet days for precipitation
+        wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
+                    >= WET_DAY_THRESHOLD).to_array().values.squeeze()
+        train, valid = train_test_split(
+            CALIB_PERIOD, stratify=wet_days, test_size=0.1)
+
+        # sort chronologically
+        train, valid = sorted(train), sorted(valid)
+    else:
+        train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
+                                        test_size=0.1)
+
+    # training and validation dataset
+    Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
+    Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
+
+    # create PyTorch-compliant dataset and dataloader instances for model
+    # training
+    train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, doy=DOY)
+    valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, doy=DOY)
+    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
+                          drop_last=False)
+    valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
+                          drop_last=False)
+
+    # initialize network trainer
+    trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
+                             valid_dl, loss_function=LOSS,
+                             **TRAIN_CONFIG)
+
+    # train model
+    state = trainer.train()
 
     # log execution time of script
     LogConfig.init_log('Execution time of script {}: {}'
-- 
GitLab