From bd79fde19018349404c5d813457dd448c9184132 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Fri, 1 Oct 2021 10:40:06 +0200 Subject: [PATCH] Implemented training with cross-validation. --- climax/main/config.py | 3 + climax/main/downscale_infer.py | 5 +- climax/main/downscale_train.py | 123 ++++++++++++++------- climax/main/downscale_train_cv.py | 174 ------------------------------ 4 files changed, 92 insertions(+), 213 deletions(-) delete mode 100644 climax/main/downscale_train_cv.py diff --git a/climax/main/config.py b/climax/main/config.py index 42baad4..7e86890 100644 --- a/climax/main/config.py +++ b/climax/main/config.py @@ -55,6 +55,9 @@ DEM_FEATURES = False # stratify training/validation set for precipitation by number of wet days STRATIFY = True +# whether to train using cross-validation +CV = False + # ----------------------------------------------------------------------------- # Observations ---------------------------------------------------------------- # ----------------------------------------------------------------------------- diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py index e61269c..9f57fd8 100644 --- a/climax/main/downscale_infer.py +++ b/climax/main/downscale_infer.py @@ -24,7 +24,7 @@ from climax.core.utils import split_date_range from climax.core.loss import BernoulliGammaLoss from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET, VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS, - DEM, DEM_FEATURES, LOSS) + DEM, DEM_FEATURES, LOSS, CV) from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH # module level logger @@ -51,6 +51,9 @@ if __name__ == '__main__': state_file = state_file.replace('.pt', '_{}.pt'.format( repr(LOSS).strip('()'))) + # add suffix for training with cross-validation + state_file = state_file.replace('.pt', '_cv.pt') if CV else state_file + # path to model state state_file = MODEL_PATH.joinpath(PREDICTAND, state_file) diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py index 79758be..b985de3 100644 --- a/climax/main/downscale_train.py +++ b/climax/main/downscale_train.py @@ -12,7 +12,7 @@ from logging.config import dictConfig # externals import torch import xarray as xr -from sklearn.model_selection import train_test_split +from sklearn.model_selection import train_test_split, TimeSeriesSplit from torch.utils.data import DataLoader # locals @@ -26,7 +26,7 @@ from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND, CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR, LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS, OVERWRITE, DEM, DEM_FEATURES, STRATIFY, - WET_DAY_THRESHOLD) + WET_DAY_THRESHOLD, CV) from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH # module level logger @@ -53,6 +53,9 @@ if __name__ == '__main__': state_file = state_file.replace('.pt', '_{}.pt'.format( repr(LOSS).strip('()'))) + # add suffix for training with cross-validation + state_file = state_file.replace('.pt', '_cv.pt') if CV else state_file + # path to model state state_file = MODEL_PATH.joinpath(PREDICTAND, state_file) @@ -112,36 +115,8 @@ if __name__ == '__main__': # add dem to set of predictor variables Era5_ds = xr.merge([Era5_ds, dem]) - # initialize training data - LogConfig.init_log('Initializing training data.') - - # split calibration period into training and validation period - if PREDICTAND == 'pr' and STRATIFY: - # stratify training and validation dataset by number of observed - # wet days for precipitation - wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x')) >= - WET_DAY_THRESHOLD).to_array().values.squeeze() - train, valid = train_test_split( - CALIB_PERIOD, stratify=wet_days, test_size=0.1) - train, valid = sorted(train), sorted(valid) # sort chronologically - else: - train, valid = train_test_split(CALIB_PERIOD, shuffle=False, - test_size=0.1) - - # training and validation dataset - Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train) - Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid) - - # create PyTorch compliant dataset and dataloader instances for model - # training - train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, - doy=DOY) - valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, - doy=DOY) - train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE, - drop_last=False) - valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE, - drop_last=False) + # initialize network and optimizer + LogConfig.init_log('Initializing network and optimizer.') # define number of output fields # check whether modelling pr with probabilistic approach @@ -150,18 +125,90 @@ if __name__ == '__main__': outputs = 3 # instanciate network - net = NET(state_file, train_ds.X.shape[1], outputs, filters=FILTERS) + net = NET(state_file, len(Era5_ds.data_vars), outputs, filters=FILTERS) # initialize optimizer optimizer = torch.optim.Adam(net.parameters(), lr=LR, weight_decay=LAMBDA) - # initialize network trainer - trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl, - valid_dl, loss_function=LOSS, **TRAIN_CONFIG) + # initialize training data + LogConfig.init_log('Initializing training data.') + if CV: + # split calibration period using cross-validation TimeSeriesSplit + cv = TimeSeriesSplit() + for i, (train_idx, valid_idx) in enumerate(cv.split(CALIB_PERIOD)): + + # time steps for training and validation set + train = CALIB_PERIOD[train_idx] + valid = CALIB_PERIOD[valid_idx] + LogConfig.init_log('Fold {}/{}: {} - {}'.format( + i + 1, cv.n_splits, str(train[0]), str(train[-1]))) + + # training and validation dataset + Era5_train, Obs_train = (Era5_ds.sel(time=train), + Obs_ds.sel(time=train)) + Era5_valid, Obs_valid = (Era5_ds.sel(time=valid), + Obs_ds.sel(time=valid)) + + # create PyTorch compliant dataset and dataloader instances for + # model training + train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, + doy=DOY) + valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, + doy=DOY) + train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, + shuffle=SHUFFLE, drop_last=False) + valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, + shuffle=SHUFFLE, drop_last=False) + + # initialize network trainer + trainer = NetworkTrainer( + net, optimizer, net.state_file, train_dl, valid_dl, + loss_function=LOSS, **TRAIN_CONFIG) + + # train model + state = trainer.train() - # train model - state = trainer.train() + else: + # split calibration period into training and validation period + if PREDICTAND == 'pr' and STRATIFY: + # stratify training and validation dataset by number of + # observed wet days for precipitation + wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x')) + >= WET_DAY_THRESHOLD).to_array().values.squeeze() + train, valid = train_test_split( + CALIB_PERIOD, stratify=wet_days, test_size=0.1) + + # sort chronologically + train, valid = sorted(train), sorted(valid) + else: + train, valid = train_test_split(CALIB_PERIOD, shuffle=False, + test_size=0.1) + + # training and validation dataset + Era5_train, Obs_train = (Era5_ds.sel(time=train), + Obs_ds.sel(time=train)) + Era5_valid, Obs_valid = (Era5_ds.sel(time=valid), + Obs_ds.sel(time=valid)) + + # create PyTorch compliant dataset and dataloader instances for model + # training + train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, + doy=DOY) + valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, + doy=DOY) + train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, + shuffle=SHUFFLE, drop_last=False) + valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, + shuffle=SHUFFLE, drop_last=False) + + # initialize network trainer + trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl, + valid_dl, loss_function=LOSS, + **TRAIN_CONFIG) + + # train model + state = trainer.train() # log execution time of script LogConfig.init_log('Execution time of script {}: {}' diff --git a/climax/main/downscale_train_cv.py b/climax/main/downscale_train_cv.py deleted file mode 100644 index 5aaec33..0000000 --- a/climax/main/downscale_train_cv.py +++ /dev/null @@ -1,174 +0,0 @@ -"""Dynamical climate downscaling using deep convolutional neural networks.""" - -# !/usr/bin/env python -# -*- coding: utf-8 -*- - -# builtins -import time -import logging -from datetime import timedelta -from logging.config import dictConfig - -# externals -import torch -import xarray as xr -from sklearn.model_selection import TimeSeriesSplit -from torch.utils.data import DataLoader - -# locals -from pysegcnn.core.utils import search_files -from pysegcnn.core.trainer import NetworkTrainer, LogConfig -from pysegcnn.core.models import Network -from pysegcnn.core.logging import log_conf -from climax.core.dataset import ERA5Dataset, NetCDFDataset -from climax.core.loss import BernoulliGammaLoss -from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND, - CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR, - LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS, - OVERWRITE, DEM, DEM_FEATURES) -from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH - -# module level logger -LOGGER = logging.getLogger(__name__) - - -if __name__ == '__main__': - - # initialize timing - start_time = time.monotonic() - - # initialize network filename - state_file = ERA5Dataset.state_file( - NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM, - dem_features=DEM_FEATURES, doy=DOY) - - # adjust statefile name for precipitation - if PREDICTAND == 'pr': - if isinstance(LOSS, BernoulliGammaLoss): - state_file = state_file.replace('.pt', '_{}mm_{}.pt'.format( - str(LOSS.min_amount).replace('.', ''), - repr(LOSS).strip('()'))) - else: - state_file = state_file.replace('.pt', '_{}.pt'.format( - repr(LOSS).strip('()'))) - - # add suffix for training with cross-validation - state_file = state_file.replace('.pt', '_cv.pt') - - # path to model state - state_file = MODEL_PATH.joinpath(PREDICTAND, state_file) - - # initialize logging - log_file = MODEL_PATH.joinpath(PREDICTAND, - state_file.name.replace('.pt', '_log.txt')) - if log_file.exists(): - log_file.unlink() - dictConfig(log_conf(log_file)) - - # initialize downscaling - LogConfig.init_log('Initializing downscaling for period: {}'.format( - ' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])]))) - - # check if model exists - if state_file.exists() and not OVERWRITE: - # load pretrained network - net, _ = Network.load_pretrained_model(state_file, NET) - else: - # initialize ERA5 predictor dataset - LogConfig.init_log('Initializing ERA5 predictors.') - Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS, - plevels=ERA5_PLEVELS) - Era5_ds = Era5.merge(chunks=-1) - - # initialize OBS predictand dataset - LogConfig.init_log('Initializing observations for predictand: {}' - .format(PREDICTAND)) - - # check whether to joinlty train tasmin and tasmax - if PREDICTAND == 'tas': - # read both tasmax and tasmin - tasmax = xr.open_dataset( - search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop()) - tasmin = xr.open_dataset( - search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop()) - Obs_ds = xr.merge([tasmax, tasmin]) - else: - # read in-situ gridded observations - Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop() - Obs_ds = xr.open_dataset(Obs_ds) - - # whether to use digital elevation model - if DEM: - # digital elevation model: Copernicus EU-Dem v1.1 - dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop() - - # read elevation and compute slope and aspect - dem = ERA5Dataset.dem_features( - dem, {'y': Era5_ds.y, 'x': Era5_ds.x}, - add_coord={'time': Era5_ds.time}) - - # check whether to use slope and aspect - if not DEM_FEATURES: - dem = dem.drop_vars(['slope', 'aspect']) - - # add dem to set of predictor variables - Era5_ds = xr.merge([Era5_ds, dem]) - - # initialize network and optimizer - LogConfig.init_log('Initializing network and optimizer.') - - # define number of output fields - # check whether modelling pr with probabilistic approach - outputs = len(Obs_ds.data_vars) - if PREDICTAND == 'pr' and isinstance(LOSS, BernoulliGammaLoss): - outputs = 3 - - # instanciate network - net = NET(state_file, len(Era5_ds.data_vars), outputs, filters=FILTERS) - - # initialize optimizer - optimizer = torch.optim.Adam(net.parameters(), lr=LR, - weight_decay=LAMBDA) - - # initialize training data - LogConfig.init_log('Initializing training data.') - - # split calibration period using cross-validation TimeSeriesSplit - cv = TimeSeriesSplit() - for i, (train_idx, valid_idx) in enumerate(cv.split(CALIB_PERIOD)): - - # time steps for training and validation set - train = CALIB_PERIOD[train_idx] - valid = CALIB_PERIOD[valid_idx] - LogConfig.init_log('Fold {}/{}: {} - {}'.format( - i + 1, cv.n_splits, str(train[0]), str(train[-1]))) - - # training and validation dataset - Era5_train, Obs_train = (Era5_ds.sel(time=train), - Obs_ds.sel(time=train)) - Era5_valid, Obs_valid = (Era5_ds.sel(time=valid), - Obs_ds.sel(time=valid)) - - # create PyTorch compliant dataset and dataloader instances for - # model training - train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, - doy=DOY) - valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, - doy=DOY) - train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, - shuffle=SHUFFLE, drop_last=False) - valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, - shuffle=SHUFFLE, drop_last=False) - - # initialize network trainer - trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl, - valid_dl, loss_function=LOSS, - **TRAIN_CONFIG) - - # train model - state = trainer.train() - - # log execution time of script - LogConfig.init_log('Execution time of script {}: {}' - .format(__file__, timedelta(seconds=time.monotonic() - - start_time))) -- GitLab