From 801495821d62307f6466566588bad1b4ba35f219 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Fri, 1 Oct 2021 10:26:41 +0200
Subject: [PATCH] Module to train CNN using cross-validation.

---
 climax/main/downscale_train_cv.py | 174 ++++++++++++++++++++++++++++++
 1 file changed, 174 insertions(+)
 create mode 100644 climax/main/downscale_train_cv.py

diff --git a/climax/main/downscale_train_cv.py b/climax/main/downscale_train_cv.py
new file mode 100644
index 0000000..5aaec33
--- /dev/null
+++ b/climax/main/downscale_train_cv.py
@@ -0,0 +1,174 @@
+"""Dynamical climate downscaling using deep convolutional neural networks."""
+
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# builtins
+import time
+import logging
+from datetime import timedelta
+from logging.config import dictConfig
+
+# externals
+import torch
+import xarray as xr
+from sklearn.model_selection import TimeSeriesSplit
+from torch.utils.data import DataLoader
+
+# locals
+from pysegcnn.core.utils import search_files
+from pysegcnn.core.trainer import NetworkTrainer, LogConfig
+from pysegcnn.core.models import Network
+from pysegcnn.core.logging import log_conf
+from climax.core.dataset import ERA5Dataset, NetCDFDataset
+from climax.core.loss import BernoulliGammaLoss
+from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
+                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
+                                LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
+                                OVERWRITE, DEM, DEM_FEATURES)
+from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
+
+# module level logger
+LOGGER = logging.getLogger(__name__)
+
+
+if __name__ == '__main__':
+
+    # initialize timing
+    start_time = time.monotonic()
+
+    # initialize network filename
+    state_file = ERA5Dataset.state_file(
+        NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
+        dem_features=DEM_FEATURES, doy=DOY)
+
+    # adjust statefile name for precipitation
+    if PREDICTAND == 'pr':
+        if isinstance(LOSS, BernoulliGammaLoss):
+            state_file = state_file.replace('.pt', '_{}mm_{}.pt'.format(
+                str(LOSS.min_amount).replace('.', ''),
+                repr(LOSS).strip('()')))
+        else:
+            state_file = state_file.replace('.pt', '_{}.pt'.format(
+                repr(LOSS).strip('()')))
+
+    # add suffix for training with cross-validation
+    state_file = state_file.replace('.pt', '_cv.pt')
+
+    # path to model state
+    state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
+
+    # initialize logging
+    log_file = MODEL_PATH.joinpath(PREDICTAND,
+                                   state_file.name.replace('.pt', '_log.txt'))
+    if log_file.exists():
+        log_file.unlink()
+    dictConfig(log_conf(log_file))
+
+    # initialize downscaling
+    LogConfig.init_log('Initializing downscaling for period: {}'.format(
+        ' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))
+
+    # check if model exists
+    if state_file.exists() and not OVERWRITE:
+        # load pretrained network
+        net, _ = Network.load_pretrained_model(state_file, NET)
+    else:
+        # initialize ERA5 predictor dataset
+        LogConfig.init_log('Initializing ERA5 predictors.')
+        Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
+                           plevels=ERA5_PLEVELS)
+        Era5_ds = Era5.merge(chunks=-1)
+
+        # initialize OBS predictand dataset
+        LogConfig.init_log('Initializing observations for predictand: {}'
+                           .format(PREDICTAND))
+
+        # check whether to jointly train tasmin and tasmax
+        if PREDICTAND == 'tas':
+            # read both tasmax and tasmin
+            tasmax = xr.open_dataset(
+                search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
+            tasmin = xr.open_dataset(
+                search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
+            Obs_ds = xr.merge([tasmax, tasmin])
+        else:
+            # read in-situ gridded observations
+            Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
+            Obs_ds = xr.open_dataset(Obs_ds)
+
+        # whether to use digital elevation model
+        if DEM:
+            # digital elevation model: Copernicus EU-Dem v1.1
+            dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
+
+            # read elevation and compute slope and aspect
+            dem = ERA5Dataset.dem_features(
+                dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
+                add_coord={'time': Era5_ds.time})
+
+            # check whether to use slope and aspect
+            if not DEM_FEATURES:
+                dem = dem.drop_vars(['slope', 'aspect'])
+
+            # add dem to set of predictor variables
+            Era5_ds = xr.merge([Era5_ds, dem])
+
+        # initialize network and optimizer
+        LogConfig.init_log('Initializing network and optimizer.')
+
+        # define number of output fields
+        # check whether modelling pr with probabilistic approach
+        outputs = len(Obs_ds.data_vars)
+        if PREDICTAND == 'pr' and isinstance(LOSS, BernoulliGammaLoss):
+            outputs = 3
+
+        # instantiate network
+        net = NET(state_file, len(Era5_ds.data_vars), outputs, filters=FILTERS)
+
+        # initialize optimizer
+        optimizer = torch.optim.Adam(net.parameters(), lr=LR,
+                                     weight_decay=LAMBDA)
+
+        # initialize training data
+        LogConfig.init_log('Initializing training data.')
+
+        # split calibration period using cross-validation TimeSeriesSplit
+        cv = TimeSeriesSplit()
+        for i, (train_idx, valid_idx) in enumerate(cv.split(CALIB_PERIOD)):
+
+            # time steps for training and validation set
+            train = CALIB_PERIOD[train_idx]
+            valid = CALIB_PERIOD[valid_idx]
+            LogConfig.init_log('Fold {}/{}: {} - {}'.format(
+                i + 1, cv.n_splits, str(train[0]), str(train[-1])))
+
+            # training and validation dataset
+            Era5_train, Obs_train = (Era5_ds.sel(time=train),
+                                     Obs_ds.sel(time=train))
+            Era5_valid, Obs_valid = (Era5_ds.sel(time=valid),
+                                     Obs_ds.sel(time=valid))
+
+            # create PyTorch compliant dataset and dataloader instances for
+            # model training
+            train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
+                                     doy=DOY)
+            valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
+                                     doy=DOY)
+            train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE,
+                                  shuffle=SHUFFLE, drop_last=False)
+            valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE,
+                                  shuffle=SHUFFLE, drop_last=False)
+
+            # initialize network trainer
+            trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
+                                     valid_dl, loss_function=LOSS,
+                                     **TRAIN_CONFIG)
+
+            # train model
+            state = trainer.train()
+
+    # log execution time of script
+    LogConfig.init_log('Execution time of script {}: {}'
+                       .format(__file__, timedelta(seconds=time.monotonic() -
+                                                   start_time)))
-- 
GitLab