From 21fcaebb55728bcd460500ab2c85bbb00dd117e8 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Tue, 19 Oct 2021 14:54:25 +0200
Subject: [PATCH] Learning rate range test.

---
 climax/main/downscale_infer.py |   4 +-
 climax/main/downscale_train.py |   4 +-
 climax/main/lr_range_test.py   | 204 +++++++++++++++++++++++++++++++++
 3 files changed, 208 insertions(+), 4 deletions(-)
 create mode 100644 climax/main/lr_range_test.py

diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py
index 2918217..47f104d 100644
--- a/climax/main/downscale_infer.py
+++ b/climax/main/downscale_infer.py
@@ -24,7 +24,7 @@ from climax.core.utils import split_date_range
 from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET,
                                 VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS,
                                 DEM, DEM_FEATURES, LOSS, ANOMALIES, OPTIM,
-                                OPTIM_PARAMS)
+                                OPTIM_PARAMS, CHUNKS)
 from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH
 
 # module level logger
@@ -58,7 +58,7 @@ if __name__ == '__main__':
     LogConfig.init_log('Initializing ERA5 predictors.')
     Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
                        plevels=ERA5_PLEVELS)
-    Era5_ds = Era5.merge(chunks=-1)
+    Era5_ds = Era5.merge(chunks=CHUNKS)
 
     # whether to use digital elevation model
     if DEM:
diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index 1535759..770750a 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -28,7 +28,7 @@ from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
                                 OVERWRITE, DEM, DEM_FEATURES, STRATIFY,
                                 WET_DAY_THRESHOLD, VALID_SIZE, ANOMALIES,
                                 OPTIM_PARAMS, LR_SCHEDULER,
-                                LR_SCHEDULER_PARAMS)
+                                LR_SCHEDULER_PARAMS, CHUNKS)
 from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
 
 # module level logger
@@ -70,7 +70,7 @@ if __name__ == '__main__':
     LogConfig.init_log('Initializing ERA5 predictors.')
     Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
                        plevels=ERA5_PLEVELS)
-    Era5_ds = Era5.merge(chunks={'time': 365})
+    Era5_ds = Era5.merge(chunks=CHUNKS)
 
     # initialize OBS predictand dataset
     LogConfig.init_log('Initializing observations for predictand: {}'
diff --git a/climax/main/lr_range_test.py b/climax/main/lr_range_test.py
new file mode 100644
index 0000000..99e716e
--- /dev/null
+++ b/climax/main/lr_range_test.py
@@ -0,0 +1,204 @@
+"""Learning rate range test."""
+
+# !/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# builtins
+import sys
+import time
+import logging
+from datetime import timedelta
+from logging.config import dictConfig
+
+# externals
+import xarray as xr
+from sklearn.model_selection import train_test_split
+import torch
+from torch.utils.data import DataLoader
+
+# locals
+from pysegcnn.core.utils import search_files
+from pysegcnn.core.trainer import NetworkTrainer, LogConfig
+from pysegcnn.core.models import Network
+from pysegcnn.core.logging import log_conf
+from climax.core.dataset import ERA5Dataset, NetCDFDataset
+from climax.core.loss import MSELoss, L1Loss
+from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
+                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, OPTIM,
+                                NORM, NET, LOSS, FILTERS, DEM, DEM_FEATURES,
+                                STRATIFY, WET_DAY_THRESHOLD, VALID_SIZE,
+                                ANOMALIES, CHUNKS)
+from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
+
+# module level logger
+LOGGER = logging.getLogger(__name__)
+
+# network training configuration
+TRAIN_CONFIG = {
+    'checkpoint_state': {},
+    'epochs': 50,
+    'save': True,
+    'save_loaders': False,
+    'early_stop': True,
+    'patience': 100,
+    'multi_gpu': True,
+    'classification': False,
+    'clip_gradients': False
+    }
+
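+# NOTE: with patience (100) larger than the number of epochs (50), early
+# stopping should never trigger, so the full learning rate range is swept
+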
+# minimum (initial) learning rate for the range test
+MIN_LR = 1e-4
+
+# learning rate scheduler: increase lr each epoch
+LR_SCHEDULER = torch.optim.lr_scheduler.ExponentialLR
+LR_SCHEDULER_PARAMS = {'gamma': 1.15}
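+# NOTE: with gamma=1.15 over 50 epochs, the learning rate grows exponentially
+# from MIN_LR=1e-4 to roughly 1e-4 * 1.15 ** 50 ~ 1e-1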
+
+
+if __name__ == '__main__':
+
+    # initialize timing
+    start_time = time.monotonic()
+
+    # initialize network filename
+    state_file = ERA5Dataset.state_file(
+        NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
+        dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
+        optim=OPTIM)
+
+    # indicate lr range test
+    state_file = state_file.replace('.pt', '_lr_test.pt')
+
+    # path to model state
+    state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
+
+    # initialize logging
+    log_file = MODEL_PATH.joinpath(PREDICTAND,
+                                   state_file.name.replace('.pt', '_log.txt'))
+    if log_file.exists():
+        log_file.unlink()
+    dictConfig(log_conf(log_file))
+
+    # initialize downscaling
+    LogConfig.init_log('Initializing learning rate test for period: {}'.format(
+        ' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))
+
+    # initialize ERA5 predictor dataset
+    LogConfig.init_log('Initializing ERA5 predictors.')
+    Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
+                       plevels=ERA5_PLEVELS)
+    Era5_ds = Era5.merge(chunks=CHUNKS)
+
+    # initialize OBS predictand dataset
+    LogConfig.init_log('Initializing observations for predictand: {}'
+                       .format(PREDICTAND))
+
+    # check whether to jointly train tasmin and tasmax
+    if PREDICTAND == 'tas':
+        # read both tasmax and tasmin
+        tasmax = xr.open_dataset(
+            search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
+        tasmin = xr.open_dataset(
+            search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
+        Obs_ds = xr.merge([tasmax, tasmin])
+    else:
+        # read in-situ gridded observations
+        Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
+        Obs_ds = xr.open_dataset(Obs_ds)
+
+    # whether to use digital elevation model
+    if DEM:
+        # digital elevation model: Copernicus EU-Dem v1.1
+        dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
+
+        # read elevation and compute slope and aspect
+        dem = ERA5Dataset.dem_features(
+            dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
+            add_coord={'time': Era5_ds.time})
+
+        # check whether to use slope and aspect
+        if not DEM_FEATURES:
+            dem = dem.drop_vars(['slope', 'aspect']).chunk(Era5_ds.chunks)
+
+        # add dem to set of predictor variables
+        Era5_ds = xr.merge([Era5_ds, dem])
+
+    # initialize network and optimizer
+    LogConfig.init_log('Initializing network and optimizer.')
+
+    # define number of output fields
+    # check whether modelling pr with probabilistic approach
+    outputs = len(Obs_ds.data_vars)
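+    # deterministic losses (MSE, L1) regress a single precipitation field;
+    # any other loss is assumed probabilistic and requires three output fields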
+    if PREDICTAND == 'pr':
+        outputs = (1 if (isinstance(LOSS, MSELoss) or isinstance(LOSS, L1Loss))
+                   else 3)
+
+    # instantiate the network
+    inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
+    net = NET(state_file, inputs, outputs, filters=FILTERS)
+
+    # initialize optimizer
+    if OPTIM == torch.optim.SGD:
+        optimizer = OPTIM(net.parameters(), lr=MIN_LR, weight_decay=0,
+                          momentum=0.99)
+    else:
+        optimizer = OPTIM(net.parameters(), lr=MIN_LR, weight_decay=0)
+
+    # initialize learning rate scheduler
+    if LR_SCHEDULER is not None:
+        LR_SCHEDULER = LR_SCHEDULER(optimizer, **LR_SCHEDULER_PARAMS)
+
+    # initialize training data
+    LogConfig.init_log('Initializing training data.')
+
+    # split calibration period into training and validation period
+    if PREDICTAND == 'pr' and STRATIFY:
+        # stratify training and validation dataset by number of
+        # observed wet days for precipitation
+        wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
+                    >= WET_DAY_THRESHOLD).to_array().values.squeeze()
+        train, valid = train_test_split(
+            CALIB_PERIOD, stratify=wet_days, test_size=VALID_SIZE)
+
+        # sort chronologically
+        train, valid = sorted(train), sorted(valid)
+    else:
+        train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
+                                        test_size=VALID_SIZE)
+
+    # training and validation dataset
+    Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
+    Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
+
+    # create PyTorch compliant dataset and dataloader instances for model
+    # training
+    train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, doy=DOY,
+                             anomalies=ANOMALIES)
+    valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, doy=DOY,
+                             anomalies=ANOMALIES)
+    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
+                          drop_last=False)
+    valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
+                          drop_last=False)
+
+    # initialize network trainer
+    trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
+                             valid_dl, loss_function=LOSS,
+                             lr_scheduler=LR_SCHEDULER, **TRAIN_CONFIG)
+
+    # train model
+    state = trainer.train()
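+
+    # NOTE: to choose a suitable maximum learning rate, inspect the training
+    # and validation loss recorded for each epoch against the lr schedule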
+
+    # log execution time of script
+    LogConfig.init_log('Execution time of script {}: {}'
+                       .format(__file__, timedelta(seconds=time.monotonic() -
+                                                   start_time)))
-- 
GitLab