diff --git a/climax/main/config.py b/climax/main/config.py
index 67a1406c9704c26e44908f3687ec29d28573fd83..a8e78c2b91eb98be9e7c28216b8be44f46514802 100644
--- a/climax/main/config.py
+++ b/climax/main/config.py
@@ -92,6 +92,9 @@ VALID_SIZE = 0.2
 # number of folds for training with KFold cross-validation
 CV = 5
 
+# number of bootstrapped model trainings
+BOOTSTRAP = 10
+
 # -----------------------------------------------------------------------------
 # Observations ----------------------------------------------------------------
 # -----------------------------------------------------------------------------
diff --git a/climax/main/downscale.py b/climax/main/downscale.py
index 528062a33938872d9085848de5c2e2c970d45496..d8f37d8e036d1bb710d2fa601ffb74bdbc08321d 100644
--- a/climax/main/downscale.py
+++ b/climax/main/downscale.py
@@ -60,6 +60,10 @@ if __name__ == '__main__':
     state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
     target = TARGET_PATH.joinpath(PREDICTAND)
 
+    # create output path if it does not exist
+    if not target.exists():
+        target.mkdir(parents=True, exist_ok=True)
+
     # initialize logging
     log_file = state_file.parent.joinpath(
         state_file.name.replace(state_file.suffix, '_log.txt'))
@@ -82,12 +86,6 @@ if __name__ == '__main__':
     LogConfig.init_log('Initializing downscaling for period: {}'.format(
         ' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))
 
-    # check if model exists
-    if state_file.exists() and not OVERWRITE:
-        # load pretrained network
-        LogConfig.init_log('{} already exists.'.format(state_file))
-        sys.exit()
-
     # initialize ERA5 predictor dataset
     LogConfig.init_log('Initializing ERA5 predictors.')
     Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
@@ -131,15 +129,17 @@ if __name__ == '__main__':
     # initialize network and optimizer
     LogConfig.init_log('Initializing network and optimizer.')
 
-    # define number of output fields
+    # define number of output fields:
     # check whether modelling pr with probabilistic approach
     outputs = len(Obs_ds.data_vars)
     if PREDICTAND == 'pr':
         outputs = (1 if (isinstance(LOSS, MSELoss) or
                          isinstance(LOSS, L1Loss)) else 3)
 
-    # instanciate network
+    # define number of input fields
     inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
+
+    # instantiate network
     net = NET(state_file, inputs, outputs, filters=FILTERS)
 
     # initialize optimizer
@@ -208,10 +208,9 @@ if __name__ == '__main__':
     # merge predictions for entire validation period
     LOGGER.info('Merging reference periods ...')
     trg_ds = xr.concat(trg_ds, dim='time')
+    trg_ds = trg_ds.sortby(trg_ds.time)  # sort predictions chronologically
 
     # save model predictions as NetCDF file
-    if not target.parent.exists():
-        target.parent.mkdir(parents=True, exist_ok=True)
     LOGGER.info('Saving network predictions: {}.'.format(target))
     trg_ds.to_netcdf(target, engine='h5netcdf')
 
diff --git a/climax/main/downscale_bootstrap.py b/climax/main/downscale_bootstrap.py
new file mode 100644
index 0000000000000000000000000000000000000000..53116b4baa459dab88a7a4f59ca9a1666efdf767
--- /dev/null
+++ b/climax/main/downscale_bootstrap.py
@@ -0,0 +1,226 @@
+"""Dynamical climate downscaling using deep convolutional neural networks."""
+
+# !/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# builtins
+import sys
+import time
+import logging
+from datetime import timedelta
+from logging.config import dictConfig
+
+# externals
+import xarray as xr
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader
+
+# locals
+from pysegcnn.core.utils import search_files
+from pysegcnn.core.models import Network
+from pysegcnn.core.trainer import NetworkTrainer, LogConfig
+from pysegcnn.core.logging import log_conf
+from climax.core.dataset import ERA5Dataset, NetCDFDataset
+from climax.core.loss import MSELoss, L1Loss
+from climax.core.predict import predict_ERA5
+from climax.core.utils import split_date_range
+from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
+                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, OPTIM,
+                                NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
+                                OVERWRITE, DEM, DEM_FEATURES, STRATIFY,
+                                WET_DAY_THRESHOLD, VALID_SIZE, ANOMALIES,
+                                OPTIM_PARAMS, LR_SCHEDULER, SENSITIVITY,
+                                LR_SCHEDULER_PARAMS, CHUNKS, VALID_PERIOD,
+                                NYEARS, BOOTSTRAP)
+from climax.main.io import (ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH,
+                            TARGET_PATH)
+
+# module level logger
+LOGGER = logging.getLogger(__name__)
+
+
+if __name__ == '__main__':
+
+    # initialize timing
+    start_time = time.monotonic()
+
+    # initialize network filename
+    state_file = ERA5Dataset.state_file(
+        NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
+        dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
+        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'],
+        lr_scheduler=LR_SCHEDULER)
+
+    # models trained with bootstrapping
+    state_file = MODEL_PATH.joinpath('bootstrap', PREDICTAND, state_file)
+    target = TARGET_PATH.joinpath('bootstrap', PREDICTAND, state_file.stem)
+
+    # create output path if it does not exist
+    if not target.exists():
+        target.mkdir(parents=True, exist_ok=True)
+
+    # initialize logging
+    log_file = state_file.parent.joinpath(
+        state_file.name.replace(state_file.suffix, '_log.txt'))
+    if log_file.exists():
+        log_file.unlink()
+    dictConfig(log_conf(log_file))
+
+    # initialize downscaling
+    LogConfig.init_log('Initializing downscaling for period: {}'.format(
+        ' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))
+
+    # initialize ERA5 predictor dataset
+    LogConfig.init_log('Initializing ERA5 predictors.')
+    Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
+                       plevels=ERA5_PLEVELS)
+    Era5_ds = Era5.merge(chunks=CHUNKS)
+
+    # initialize OBS predictand dataset
+    LogConfig.init_log('Initializing observations for predictand: {}'
+                       .format(PREDICTAND))
+
+    # check whether to jointly train tasmin and tasmax
+    if PREDICTAND == 'tas':
+        # read both tasmax and tasmin
+        tasmax = xr.open_dataset(
+            search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
+        tasmin = xr.open_dataset(
+            search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
+        Obs_ds = xr.merge([tasmax, tasmin])
+    else:
+        # read in-situ gridded observations
+        Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND),
+                              '.nc$').pop()
+        Obs_ds = xr.open_dataset(Obs_ds)
+
+    # whether to use digital elevation model
+    if DEM:
+        # digital elevation model: Copernicus EU-Dem v1.1
+        dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
+
+        # read elevation and compute slope and aspect
+        dem = ERA5Dataset.dem_features(
+            dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
+            add_coord={'time': Era5_ds.time})
+
+        # check whether to use slope and aspect
+        if not DEM_FEATURES:
+            dem = dem.drop_vars(['slope', 'aspect']).chunk(
+                Era5_ds.chunks)
+
+        # add dem to set of predictor variables
+        Era5_ds = xr.merge([Era5_ds, dem])
+
+    # define number of output fields:
+    # check whether modelling pr with probabilistic approach
+    outputs = len(Obs_ds.data_vars)
+    if PREDICTAND == 'pr':
+        outputs = (1 if (isinstance(LOSS, MSELoss) or
+                         isinstance(LOSS, L1Loss)) else 3)
+
+    # define number of input fields
+    inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
+
+    # initialize training data
+    LogConfig.init_log('Initializing training data.')
+
+    # split calibration period into training and validation period
+    if PREDICTAND == 'pr' and STRATIFY:
+        # stratify training and validation dataset by number of
+        # observed wet days for precipitation
+        wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
+                    >= WET_DAY_THRESHOLD).to_array().values.squeeze()
+        train, valid = train_test_split(
+            CALIB_PERIOD, stratify=wet_days, test_size=VALID_SIZE)
+
+        # sort chronologically
+        train, valid = sorted(train), sorted(valid)
+    else:
+        train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
+                                        test_size=VALID_SIZE)
+
+    # training and validation dataset
+    Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
+    Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
+
+    # create PyTorch compliant dataset and dataloader instances for model
+    # training
+    train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, doy=DOY,
+                             anomalies=ANOMALIES)
+    valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, doy=DOY,
+                             anomalies=ANOMALIES)
+    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
+                          drop_last=False)
+    valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
+                          drop_last=False)
+
+    # initialize bootstrapped model training
+    for i in range(BOOTSTRAP):
+
+        # add suffix for the n-th bootstrap: derive the name from the base
+        # state file so the suffixes do not accumulate across iterations
+        bootstrap_file = state_file.parent.joinpath(state_file.name.replace(
+            state_file.suffix, '_{}.pt'.format(i + 1)))
+
+        # check if target dataset already exists
+        target_ds = target.joinpath(bootstrap_file.name.replace(
+            bootstrap_file.suffix, '.nc'))
+        if target_ds.exists() and not OVERWRITE:
+            LogConfig.init_log('{} already exists.'.format(target_ds))
+            continue
+
+        # load pretrained model
+        if bootstrap_file.exists() and not OVERWRITE:
+            # load pretrained network
+            net, _ = Network.load_pretrained_model(bootstrap_file, NET)
+        else:
+            # initialize network and optimizer
+            LogConfig.init_log('Initializing network and optimizer.')
+
+            # instantiate network
+            net = NET(bootstrap_file, inputs, outputs, filters=FILTERS)
+
+            # initialize optimizer
+            optimizer = OPTIM(net.parameters(), **OPTIM_PARAMS)
+
+            # initialize learning rate scheduler: instantiate a scheduler for
+            # each bootstrap instead of overwriting the LR_SCHEDULER class
+            lr_scheduler = (LR_SCHEDULER(optimizer, **LR_SCHEDULER_PARAMS)
+                            if LR_SCHEDULER is not None else None)
+
+            # initialize network trainer
+            trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
+                                     valid_dl, loss_function=LOSS,
+                                     lr_scheduler=lr_scheduler, **TRAIN_CONFIG)
+
+            # train model
+            state = trainer.train()
+
+        # predict reference period
+        LogConfig.init_log('Predicting reference period: {}'.format(
+            ' - '.join([str(VALID_PERIOD[0]), str(VALID_PERIOD[-1])])))
+
+        # subset to reference period and predict in NYEAR intervals
+        trg_ds = []
+        for dates in split_date_range(VALID_PERIOD[0], VALID_PERIOD[-1],
+                                      years=NYEARS):
+            LogConfig.init_log('Predicting period: {}'.format(
+                ' - '.join([str(dates[0]), str(dates[-1])])))
+            ref_ds = Era5_ds.sel(time=dates)
+            trg_ds.append(predict_ERA5(net, ref_ds, PREDICTAND, LOSS,
+                                       normalize=NORM, batch_size=BATCH_SIZE,
+                                       doy=DOY, anomalies=ANOMALIES))
+
+        # merge predictions for entire validation period
+        LOGGER.info('Merging reference periods ...')
+        trg_ds = xr.concat(trg_ds, dim='time')
+        trg_ds = trg_ds.sortby(trg_ds.time)  # sort predictions chronologically
+
+        # save model predictions as NetCDF file
+        LOGGER.info('Saving network predictions: {}.'.format(target_ds))
+        trg_ds.to_netcdf(target_ds, engine='h5netcdf')
+
+    # log execution time of script
+    LogConfig.init_log('Execution time of script {}: {}'
+                       .format(__file__, timedelta(seconds=time.monotonic() -
+                                                   start_time)))
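
Note (not part of the patch): a minimal sketch of how the per-bootstrap predictions written by downscale_bootstrap.py could be combined into an ensemble afterwards. It assumes that search_files matches the per-bootstrap NetCDF files below the bootstrap target directory; the "bootstrap" dimension name is purely illustrative.

import xarray as xr

from pysegcnn.core.utils import search_files
from climax.main.config import PREDICTAND
from climax.main.io import TARGET_PATH

# one NetCDF file per bootstrap: <state_file.stem>_<n>.nc, n = 1 ... BOOTSTRAP
files = sorted(search_files(TARGET_PATH.joinpath('bootstrap', PREDICTAND),
                            '.nc$'))

# stack the bootstrapped predictions along a new "bootstrap" dimension
ens = xr.concat([xr.open_dataset(f) for f in files], dim='bootstrap')

# ensemble mean and spread across the bootstrapped models
ens_mean, ens_std = ens.mean(dim='bootstrap'), ens.std(dim='bootstrap')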