From 53d9422394e2a1c6713e47dcae8d433ee053acd3 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 1 Jul 2021 12:42:33 +0200
Subject: [PATCH] Initial implementation of downscaling script.

---
 climax/core/constants.py | 12 +++----
 climax/main/config.py    | 69 +++++++++++++++++++++++++++++++++++++---
 climax/main/downscale.py | 53 ++++++++++++++++++++++++++++++
 3 files changed, 123 insertions(+), 11 deletions(-)
 create mode 100644 climax/main/downscale.py

diff --git a/climax/core/constants.py b/climax/core/constants.py
index a02265c..c1b435d 100644
--- a/climax/core/constants.py
+++ b/climax/core/constants.py
@@ -31,18 +31,18 @@ EUROCORDEX_RCMS = ['SMHI-RCA4', 'CLMcom-CCLM4-8-17',
 # climate data operator (cdo) resampling modes
 CDO_RESAMPLING_MODES = ['bilinear', 'conservative']
 
-# ERA5 variables on pressure levels
+# ERA5 predictor variables on pressure levels
 ERA5_P_VARIABLES = ['geopotential', 'temperature', 'u_component_of_wind',
                     'v_component_of_wind', 'specific_humidity']
 
-# ERA5 variables on single levels
+# ERA5 predictor variables on single levels
 ERA5_S_VARIABLES = ['mean_sea_level_pressure']
 
-# ERA5 variables
+# ERA5 predictor variables
 ERA5_VARIABLES = ERA5_P_VARIABLES + ERA5_S_VARIABLES
 
-# ERA5 pressure levels
-ERA5_PLEVELS = [500, 850]
-
 # name of target projection
 PROJECTION = 'lambert_azimuthal_equal_area'
+
+# predictand variables: covered by observations
+PREDICTANDS = ['tasmin', 'tasmax', 'pr']
diff --git a/climax/main/config.py b/climax/main/config.py
index 312c9a9..62908b4 100644
--- a/climax/main/config.py
+++ b/climax/main/config.py
@@ -7,12 +7,71 @@
 import pathlib
 import datetime
 
+# externals
+import numpy as np
+
+# locals
+from climax.core.constants import PREDICTANDS
+
+# -----------------------------------------------------------------------------
+# Paths to input data ---------------------------------------------------------
+# -----------------------------------------------------------------------------
+
+# project root path
+ROOT = pathlib.Path('/mnt/CEPH_PROJECTS/FACT_CLIMAX/')
+
 # path to this file
 HERE = pathlib.Path(__file__).parent
 
-# calibration period
-P_CAL = (datetime.datetime.strptime('1981-01-01', '%Y-%m-%d').date(),
-         datetime.datetime.strptime('2011-01-01', '%Y-%m-%d').date())
-
 # path to ERA5 reanalysis data
-ERA5_PATH = pathlib.Path('/mnt/CEPH_PROJECTS/FACT_CLIMAX/REANALYSIS/')
+ERA5_PATH = ROOT.joinpath('REANALYSIS')
+
+# path to OBServation data
+OBS_PATH = ROOT.joinpath('OBSERVATION')
+
+# path to save trained models
+MODEL_PATH = ROOT.joinpath('Models')
+
+# -----------------------------------------------------------------------------
+# ERA5 downscaling configuration ----------------------------------------------
+# -----------------------------------------------------------------------------
+
+# ERA5 predictor variables on pressure levels (alias currently disabled)
+# ERA5_P_PREDICTORS = ERA5_P_VARIABLES
+
+# ERA5 predictor variables on single levels (alias currently disabled)
+# ERA5_S_PREDICTORS = ERA5_S_VARIABLES
+
+# ERA5 predictor variables (alias currently disabled)
+# ERA5_PREDICTORS = ERA5_VARIABLES
+
+# ERA5 pressure levels
+ERA5_PLEVELS = [500, 850]
+
+# -----------------------------------------------------------------------------
+# Observations ----------------------------------------------------------------
+# -----------------------------------------------------------------------------
+
+# target variable: must be one of the predictands listed in PREDICTANDS
+PREDICTAND = 'tasmin'
+assert PREDICTAND in PREDICTANDS
+
+# -----------------------------------------------------------------------------
+# Calibration period  ---------------------------------------------------------
+# -----------------------------------------------------------------------------
+
+# calibration period (training and validation); the end date is excluded
+CALIB_PERIOD = np.arange(
+    datetime.datetime.strptime('1981-01-01', '%Y-%m-%d').date(),
+    datetime.datetime.strptime('2011-01-01', '%Y-%m-%d').date())
+
+# -----------------------------------------------------------------------------
+# Model training configuration ------------------------------------------------
+# -----------------------------------------------------------------------------
+
+# whether to randomly shuffle time steps or to conserve time series for model
+# training
+SHUFFLE = False
+
+# batch size: number of time steps processed by the net in each iteration
+BATCH_SIZE = 64
diff --git a/climax/main/downscale.py b/climax/main/downscale.py
new file mode 100644
index 0000000..3f3bc7f
--- /dev/null
+++ b/climax/main/downscale.py
@@ -0,0 +1,53 @@
+"""Dynamical climate downscaling using deep convolutional neural networks."""
+
+# !/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# externals
+import xarray as xr
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader
+
+# locals
+from pysegcnn.core.utils import search_files
+from pysegcnn.core.models import SegNet
+from climax.core.dataset import ERA5Dataset, NetCDFDataset
+from climax.core.constants import (ERA5_P_VARIABLES, ERA5_S_VARIABLES,
+                                   ERA5_VARIABLES)
+from climax.main.config import (ERA5_PATH, ERA5_PLEVELS, OBS_PATH, PREDICTAND,
+                                CALIB_PERIOD, MODEL_PATH, SHUFFLE, BATCH_SIZE)
+
+
+if __name__ == '__main__':
+
+    # initialize ERA5 predictor dataset
+    Era5 = ERA5Dataset(ERA5_PATH, ERA5_VARIABLES, plevels=ERA5_PLEVELS)
+    Era5_ds = Era5.merge()
+
+    # initialize OBS predictand dataset: open a single NetCDF file
+    Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
+    Obs_ds = xr.open_dataset(Obs_ds)
+
+    # split calibration period into training (75%) and validation (25%) sets
+    train, valid = train_test_split(CALIB_PERIOD, shuffle=SHUFFLE)
+
+    # training and validation dataset
+    Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
+    Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
+
+    # create PyTorch compliant dataset and dataloader instances for model
+    # training
+    train_ds = NetCDFDataset(Era5_train, Obs_train, dim='time')
+    valid_ds = NetCDFDataset(Era5_valid, Obs_valid, dim='time')
+    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
+                          drop_last=False)
+    valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
+                          drop_last=False)
+
+    # initialize network: one input channel per variable and (pressure) level
+    in_channels = int(len(ERA5_P_VARIABLES) * len(ERA5_PLEVELS) +
+                      len(ERA5_S_VARIABLES))
+    net = SegNet(MODEL_PATH.joinpath(PREDICTAND + '.pt'), in_channels, 1)
+
+    # initialize network training
+    # TODO: Extend ClassificationNetworkTrainer -> RegressionNetworkTrainer
-- 
GitLab