#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Provenance (recovered from the patch header this file was shipped in):
#   Commit:  801495821d62307f6466566588bad1b4ba35f219
#   Author:  "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
#   Date:    Fri, 1 Oct 2021 10:26:41 +0200
#   Subject: [PATCH] Module to train CNN using cross-validation.
#   File:    climax/main/downscale_train_cv.py (new file)
"""Dynamical climate downscaling using deep convolutional neural networks."""

# builtins
import time
import logging
from datetime import timedelta
from logging.config import dictConfig

# externals
import torch
import xarray as xr
from sklearn.model_selection import TimeSeriesSplit
from torch.utils.data import DataLoader

# locals
from pysegcnn.core.utils import search_files
from pysegcnn.core.trainer import NetworkTrainer, LogConfig
from pysegcnn.core.models import Network
from pysegcnn.core.logging import log_conf
from climax.core.dataset import ERA5Dataset, NetCDFDataset
from climax.core.loss import BernoulliGammaLoss
from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
                                LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
                                OVERWRITE, DEM, DEM_FEATURES)
from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH

# module level logger
LOGGER = logging.getLogger(__name__)


def _last_match(directory, pattern):
    """Return the last file found in ``directory`` for ``pattern``.

    Thin wrapper around ``search_files(directory, pattern).pop()`` that
    fails with an explicit message instead of a bare ``IndexError`` when
    nothing is found.

    Parameters
    ----------
    directory
        The directory handed to ``search_files``.
    pattern : `str`
        The pattern handed to ``search_files``.
        NOTE(review): presumably a regular expression matched against the
        file names (the callers use ``^``/``$`` anchors) -- confirm against
        ``pysegcnn.core.utils.search_files``.

    Returns
    -------
    The last element of the sequence returned by ``search_files``.

    Raises
    ------
    FileNotFoundError
        If ``search_files`` returns no match.

    """
    matches = search_files(directory, pattern)
    if not matches:
        raise FileNotFoundError('No file matching "{}" found in: {}'.format(
            pattern, directory))
    return matches.pop()


if __name__ == '__main__':

    # initialize timing
    start_time = time.monotonic()

    # whether precipitation is modelled with the Bernoulli-Gamma loss: this
    # drives both the state file suffix and the number of output fields
    bernoulli_gamma = (PREDICTAND == 'pr' and
                       isinstance(LOSS, BernoulliGammaLoss))

    # initialize network filename
    state_file = ERA5Dataset.state_file(
        NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
        dem_features=DEM_FEATURES, doy=DOY)

    # adjust statefile name for precipitation: encode the loss function and,
    # for the Bernoulli-Gamma loss, its minimum amount (decimal point removed)
    if PREDICTAND == 'pr':
        if bernoulli_gamma:
            state_file = state_file.replace('.pt', '_{}mm_{}.pt'.format(
                str(LOSS.min_amount).replace('.', ''),
                repr(LOSS).strip('()')))
        else:
            state_file = state_file.replace('.pt', '_{}.pt'.format(
                repr(LOSS).strip('()')))

    # add suffix for training with cross-validation
    state_file = state_file.replace('.pt', '_cv.pt')

    # path to model state
    state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)

    # initialize logging: start from a fresh log file next to the model state
    log_file = MODEL_PATH.joinpath(PREDICTAND,
                                   state_file.name.replace('.pt', '_log.txt'))
    if log_file.exists():
        log_file.unlink()
    dictConfig(log_conf(log_file))

    # initialize downscaling
    LogConfig.init_log('Initializing downscaling for period: {}'.format(
        ' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))

    # check if model exists
    if state_file.exists() and not OVERWRITE:
        # load pretrained network: nothing is trained in this case
        net, _ = Network.load_pretrained_model(state_file, NET)
    else:
        # initialize ERA5 predictor dataset
        LogConfig.init_log('Initializing ERA5 predictors.')
        Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
                           plevels=ERA5_PLEVELS)
        Era5_ds = Era5.merge(chunks=-1)

        # initialize OBS predictand dataset
        LogConfig.init_log('Initializing observations for predictand: {}'
                           .format(PREDICTAND))

        # check whether to jointly train tasmin and tasmax
        if PREDICTAND == 'tas':
            # read both tasmax and tasmin
            tasmax = xr.open_dataset(
                _last_match(OBS_PATH.joinpath('tasmax'), r'\.nc$'))
            tasmin = xr.open_dataset(
                _last_match(OBS_PATH.joinpath('tasmin'), r'\.nc$'))
            Obs_ds = xr.merge([tasmax, tasmin])
        else:
            # read in-situ gridded observations
            Obs_ds = xr.open_dataset(
                _last_match(OBS_PATH.joinpath(PREDICTAND), r'\.nc$'))

        # whether to use digital elevation model
        if DEM:
            # digital elevation model: Copernicus EU-Dem v1.1
            dem = _last_match(DEM_PATH, r'^eu_dem_v11_stt\.nc$')

            # read elevation and compute slope and aspect
            dem = ERA5Dataset.dem_features(
                dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
                add_coord={'time': Era5_ds.time})

            # check whether to use slope and aspect
            if not DEM_FEATURES:
                dem = dem.drop_vars(['slope', 'aspect'])

            # add dem to set of predictor variables
            Era5_ds = xr.merge([Era5_ds, dem])

        # initialize network and optimizer
        LogConfig.init_log('Initializing network and optimizer.')

        # define number of output fields: one per predictand variable, except
        # when modelling pr with the probabilistic approach, which uses three
        # NOTE(review): presumably the three Bernoulli-Gamma distribution
        # parameters -- confirm against BernoulliGammaLoss
        outputs = 3 if bernoulli_gamma else len(Obs_ds.data_vars)

        # instantiate network
        net = NET(state_file, len(Era5_ds.data_vars), outputs, filters=FILTERS)

        # initialize optimizer
        optimizer = torch.optim.Adam(net.parameters(), lr=LR,
                                     weight_decay=LAMBDA)

        # initialize training data
        LogConfig.init_log('Initializing training data.')

        # split calibration period using cross-validation TimeSeriesSplit
        # NOTE: network and optimizer are created once, before the loop, so
        # every fold continues with the same instances and every trainer is
        # handed the same ``net.state_file``
        cv = TimeSeriesSplit()
        for i, (train_idx, valid_idx) in enumerate(cv.split(CALIB_PERIOD)):

            # time steps for training and validation set
            train = CALIB_PERIOD[train_idx]
            valid = CALIB_PERIOD[valid_idx]
            LogConfig.init_log('Fold {}/{}: {} - {}'.format(
                i + 1, cv.n_splits, str(train[0]), str(train[-1])))

            # training and validation dataset
            Era5_train, Obs_train = (Era5_ds.sel(time=train),
                                     Obs_ds.sel(time=train))
            Era5_valid, Obs_valid = (Era5_ds.sel(time=valid),
                                     Obs_ds.sel(time=valid))

            # create PyTorch compliant dataset and dataloader instances for
            # model training
            train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
                                     doy=DOY)
            valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
                                     doy=DOY)
            train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE,
                                  shuffle=SHUFFLE, drop_last=False)
            valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE,
                                  shuffle=SHUFFLE, drop_last=False)

            # initialize network trainer
            trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
                                     valid_dl, loss_function=LOSS,
                                     **TRAIN_CONFIG)

            # train model on the current fold (return value is not used)
            trainer.train()

    # log execution time of script
    LogConfig.init_log('Execution time of script {}: {}'
                       .format(__file__, timedelta(seconds=time.monotonic() -
                                                   start_time)))