From bd79fde19018349404c5d813457dd448c9184132 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Fri, 1 Oct 2021 10:40:06 +0200
Subject: [PATCH] Implemented training with cross-validation.

---
 climax/main/config.py             |   3 +
 climax/main/downscale_infer.py    |   5 +-
 climax/main/downscale_train.py    | 123 ++++++++++++++-------
 climax/main/downscale_train_cv.py | 174 ------------------------------
 4 files changed, 92 insertions(+), 213 deletions(-)
 delete mode 100644 climax/main/downscale_train_cv.py

diff --git a/climax/main/config.py b/climax/main/config.py
index 42baad4..7e86890 100644
--- a/climax/main/config.py
+++ b/climax/main/config.py
@@ -55,6 +55,9 @@ DEM_FEATURES = False
 # stratify training/validation set for precipitation by number of wet days
 STRATIFY = True
 
+# whether to train using cross-validation
+CV = False
+
 # -----------------------------------------------------------------------------
 # Observations ----------------------------------------------------------------
 # -----------------------------------------------------------------------------
diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py
index e61269c..9f57fd8 100644
--- a/climax/main/downscale_infer.py
+++ b/climax/main/downscale_infer.py
@@ -24,7 +24,7 @@ from climax.core.utils import split_date_range
 from climax.core.loss import BernoulliGammaLoss
 from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET,
                                 VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS,
-                                DEM, DEM_FEATURES, LOSS)
+                                DEM, DEM_FEATURES, LOSS, CV)
 from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH
 
 # module level logger
@@ -51,6 +51,9 @@ if __name__ == '__main__':
             state_file = state_file.replace('.pt', '_{}.pt'.format(
                 repr(LOSS).strip('()')))
 
+    # add suffix for training with cross-validation
+    state_file = state_file.replace('.pt', '_cv.pt') if CV else state_file
+
     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
 
diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index 79758be..b985de3 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -12,7 +12,7 @@ from logging.config import dictConfig
 # externals
 import torch
 import xarray as xr
-from sklearn.model_selection import train_test_split
+from sklearn.model_selection import train_test_split, TimeSeriesSplit
 from torch.utils.data import DataLoader
 
 # locals
@@ -26,7 +26,7 @@ from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
                                 CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
                                 LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
                                 OVERWRITE, DEM, DEM_FEATURES, STRATIFY,
-                                WET_DAY_THRESHOLD)
+                                WET_DAY_THRESHOLD, CV)
 from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
 
 # module level logger
@@ -53,6 +53,9 @@ if __name__ == '__main__':
             state_file = state_file.replace('.pt', '_{}.pt'.format(
                 repr(LOSS).strip('()')))
 
+    # add suffix for training with cross-validation
+    state_file = state_file.replace('.pt', '_cv.pt') if CV else state_file
+
     # path to model state
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
 
@@ -112,36 +115,8 @@ if __name__ == '__main__':
             # add dem to set of predictor variables
             Era5_ds = xr.merge([Era5_ds, dem])
 
-        # initialize training data
-        LogConfig.init_log('Initializing training data.')
-
-        # split calibration period into training and validation period
-        if PREDICTAND == 'pr' and STRATIFY:
-            # stratify training and validation dataset by number of observed
-            # wet days for precipitation
-            wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x')) >=
-                        WET_DAY_THRESHOLD).to_array().values.squeeze()
-            train, valid = train_test_split(
-                CALIB_PERIOD, stratify=wet_days, test_size=0.1)
-            train, valid = sorted(train), sorted(valid)  # sort chronologically
-        else:
-            train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
-                                            test_size=0.1)
-
-        # training and validation dataset
-        Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
-        Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
-
-        # create PyTorch compliant dataset and dataloader instances for model
-        # training
-        train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
-                                 doy=DOY)
-        valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
-                                 doy=DOY)
-        train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
-                              drop_last=False)
-        valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
-                              drop_last=False)
+        # initialize network and optimizer
+        LogConfig.init_log('Initializing network and optimizer.')
 
         # define number of output fields
         # check whether modelling pr with probabilistic approach
@@ -150,18 +125,90 @@ if __name__ == '__main__':
             outputs = 3
 
         # instanciate network
-        net = NET(state_file, train_ds.X.shape[1], outputs, filters=FILTERS)
+        net = NET(state_file, len(Era5_ds.data_vars), outputs, filters=FILTERS)
 
         	# initialize optimizer
         optimizer = torch.optim.Adam(net.parameters(), lr=LR,
                                      weight_decay=LAMBDA)
 
-        # initialize network trainer
-        trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
-                                 valid_dl, loss_function=LOSS, **TRAIN_CONFIG)
+        # initialize training data
+        LogConfig.init_log('Initializing training data.')
+        if CV:
+            # split calibration period using cross-validation TimeSeriesSplit
+            cv = TimeSeriesSplit()
+            for i, (train_idx, valid_idx) in enumerate(cv.split(CALIB_PERIOD)):
+
+                # time steps for training and validation set
+                train = CALIB_PERIOD[train_idx]
+                valid = CALIB_PERIOD[valid_idx]
+                LogConfig.init_log('Fold {}/{}: {} - {}'.format(
+                    i + 1, cv.n_splits, str(train[0]), str(train[-1])))
+
+                # training and validation dataset
+                Era5_train, Obs_train = (Era5_ds.sel(time=train),
+                                         Obs_ds.sel(time=train))
+                Era5_valid, Obs_valid = (Era5_ds.sel(time=valid),
+                                         Obs_ds.sel(time=valid))
+
+                # create PyTorch compliant dataset and dataloader instances for
+                # model training
+                train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
+                                         doy=DOY)
+                valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
+                                         doy=DOY)
+                train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE,
+                                      shuffle=SHUFFLE, drop_last=False)
+                valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE,
+                                      shuffle=SHUFFLE, drop_last=False)
+
+                # initialize network trainer
+                trainer = NetworkTrainer(
+                    net, optimizer, net.state_file, train_dl, valid_dl,
+                    loss_function=LOSS, **TRAIN_CONFIG)
+
+                # train model
+                state = trainer.train()
 
-        # train model
-        state = trainer.train()
+        else:
+            # split calibration period into training and validation period
+            if PREDICTAND == 'pr' and STRATIFY:
+                # stratify training and validation dataset by number of
+                # observed wet days for precipitation
+                wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
+                            >= WET_DAY_THRESHOLD).to_array().values.squeeze()
+                train, valid = train_test_split(
+                    CALIB_PERIOD, stratify=wet_days, test_size=0.1)
+
+                # sort chronologically
+                train, valid = sorted(train), sorted(valid)
+            else:
+                train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
+                                                test_size=0.1)
+
+            # training and validation dataset
+            Era5_train, Obs_train = (Era5_ds.sel(time=train),
+                                     Obs_ds.sel(time=train))
+            Era5_valid, Obs_valid = (Era5_ds.sel(time=valid),
+                                     Obs_ds.sel(time=valid))
+
+            # create PyTorch compliant dataset and dataloader instances for model
+            # training
+            train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
+                                     doy=DOY)
+            valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
+                                     doy=DOY)
+            train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE,
+                                  shuffle=SHUFFLE, drop_last=False)
+            valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE,
+                                  shuffle=SHUFFLE, drop_last=False)
+
+            # initialize network trainer
+            trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
+                                     valid_dl, loss_function=LOSS,
+                                     **TRAIN_CONFIG)
+
+            # train model
+            state = trainer.train()
 
     # log execution time of script
     LogConfig.init_log('Execution time of script {}: {}'
diff --git a/climax/main/downscale_train_cv.py b/climax/main/downscale_train_cv.py
deleted file mode 100644
index 5aaec33..0000000
--- a/climax/main/downscale_train_cv.py
+++ /dev/null
@@ -1,174 +0,0 @@
-"""Dynamical climate downscaling using deep convolutional neural networks."""
-
-# !/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-# builtins
-import time
-import logging
-from datetime import timedelta
-from logging.config import dictConfig
-
-# externals
-import torch
-import xarray as xr
-from sklearn.model_selection import TimeSeriesSplit
-from torch.utils.data import DataLoader
-
-# locals
-from pysegcnn.core.utils import search_files
-from pysegcnn.core.trainer import NetworkTrainer, LogConfig
-from pysegcnn.core.models import Network
-from pysegcnn.core.logging import log_conf
-from climax.core.dataset import ERA5Dataset, NetCDFDataset
-from climax.core.loss import BernoulliGammaLoss
-from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
-                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
-                                LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
-                                OVERWRITE, DEM, DEM_FEATURES)
-from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
-
-# module level logger
-LOGGER = logging.getLogger(__name__)
-
-
-if __name__ == '__main__':
-
-    # initialize timing
-    start_time = time.monotonic()
-
-    # initialize network filename
-    state_file = ERA5Dataset.state_file(
-        NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
-        dem_features=DEM_FEATURES, doy=DOY)
-
-    # adjust statefile name for precipitation
-    if PREDICTAND == 'pr':
-        if isinstance(LOSS, BernoulliGammaLoss):
-            state_file = state_file.replace('.pt', '_{}mm_{}.pt'.format(
-                str(LOSS.min_amount).replace('.', ''),
-                repr(LOSS).strip('()')))
-        else:
-            state_file = state_file.replace('.pt', '_{}.pt'.format(
-                repr(LOSS).strip('()')))
-
-    # add suffix for training with cross-validation
-    state_file = state_file.replace('.pt', '_cv.pt')
-
-    # path to model state
-    state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
-
-    # initialize logging
-    log_file = MODEL_PATH.joinpath(PREDICTAND,
-                                   state_file.name.replace('.pt', '_log.txt'))
-    if log_file.exists():
-        log_file.unlink()
-    dictConfig(log_conf(log_file))
-
-    # initialize downscaling
-    LogConfig.init_log('Initializing downscaling for period: {}'.format(
-        ' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))
-
-    # check if model exists
-    if state_file.exists() and not OVERWRITE:
-        # load pretrained network
-        net, _ = Network.load_pretrained_model(state_file, NET)
-    else:
-        # initialize ERA5 predictor dataset
-        LogConfig.init_log('Initializing ERA5 predictors.')
-        Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
-                           plevels=ERA5_PLEVELS)
-        Era5_ds = Era5.merge(chunks=-1)
-
-        # initialize OBS predictand dataset
-        LogConfig.init_log('Initializing observations for predictand: {}'
-                           .format(PREDICTAND))
-
-        # check whether to joinlty train tasmin and tasmax
-        if PREDICTAND == 'tas':
-            # read both tasmax and tasmin
-            tasmax = xr.open_dataset(
-                search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
-            tasmin = xr.open_dataset(
-                search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
-            Obs_ds = xr.merge([tasmax, tasmin])
-        else:
-            # read in-situ gridded observations
-            Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
-            Obs_ds = xr.open_dataset(Obs_ds)
-
-        # whether to use digital elevation model
-        if DEM:
-            # digital elevation model: Copernicus EU-Dem v1.1
-            dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
-
-            # read elevation and compute slope and aspect
-            dem = ERA5Dataset.dem_features(
-                dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
-                add_coord={'time': Era5_ds.time})
-
-            # check whether to use slope and aspect
-            if not DEM_FEATURES:
-                dem = dem.drop_vars(['slope', 'aspect'])
-
-            # add dem to set of predictor variables
-            Era5_ds = xr.merge([Era5_ds, dem])
-
-        # initialize network and optimizer
-        LogConfig.init_log('Initializing network and optimizer.')
-
-        # define number of output fields
-        # check whether modelling pr with probabilistic approach
-        outputs = len(Obs_ds.data_vars)
-        if PREDICTAND == 'pr' and isinstance(LOSS, BernoulliGammaLoss):
-            outputs = 3
-
-        # instanciate network
-        net = NET(state_file, len(Era5_ds.data_vars), outputs, filters=FILTERS)
-
-        	# initialize optimizer
-        optimizer = torch.optim.Adam(net.parameters(), lr=LR,
-                                     weight_decay=LAMBDA)
-
-        # initialize training data
-        LogConfig.init_log('Initializing training data.')
-
-        # split calibration period using cross-validation TimeSeriesSplit
-        cv = TimeSeriesSplit()
-        for i, (train_idx, valid_idx) in enumerate(cv.split(CALIB_PERIOD)):
-
-            # time steps for training and validation set
-            train = CALIB_PERIOD[train_idx]
-            valid = CALIB_PERIOD[valid_idx]
-            LogConfig.init_log('Fold {}/{}: {} - {}'.format(
-                i + 1, cv.n_splits, str(train[0]), str(train[-1])))
-
-            # training and validation dataset
-            Era5_train, Obs_train = (Era5_ds.sel(time=train),
-                                     Obs_ds.sel(time=train))
-            Era5_valid, Obs_valid = (Era5_ds.sel(time=valid),
-                                     Obs_ds.sel(time=valid))
-
-            # create PyTorch compliant dataset and dataloader instances for
-            # model training
-            train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
-                                     doy=DOY)
-            valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
-                                     doy=DOY)
-            train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE,
-                                  shuffle=SHUFFLE, drop_last=False)
-            valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE,
-                                  shuffle=SHUFFLE, drop_last=False)
-
-            # initialize network trainer
-            trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
-                                     valid_dl, loss_function=LOSS,
-                                     **TRAIN_CONFIG)
-
-            # train model
-            state = trainer.train()
-
-    # log execution time of script
-    LogConfig.init_log('Execution time of script {}: {}'
-                       .format(__file__, timedelta(seconds=time.monotonic() -
-                                                   start_time)))
-- 
GitLab