diff --git a/climax/main/downscale.py b/climax/main/downscale.py
new file mode 100644
index 0000000000000000000000000000000000000000..528062a33938872d9085848de5c2e2c970d45496
--- /dev/null
+++ b/climax/main/downscale.py
@@ -0,0 +1,221 @@
+"""Dynamical climate downscaling using deep convolutional neural networks."""
+
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# builtins
+import sys
+import time
+import logging
+from datetime import timedelta
+from logging.config import dictConfig
+
+# externals
+import xarray as xr
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader
+
+# locals
+from pysegcnn.core.utils import search_files
+from pysegcnn.core.models import Network
+from pysegcnn.core.trainer import NetworkTrainer, LogConfig
+from pysegcnn.core.logging import log_conf
+from climax.core.dataset import ERA5Dataset, NetCDFDataset
+from climax.core.loss import MSELoss, L1Loss
+from climax.core.predict import predict_ERA5
+from climax.core.utils import split_date_range
+from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
+                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, OPTIM,
+                                NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
+                                OVERWRITE, DEM, DEM_FEATURES, STRATIFY,
+                                WET_DAY_THRESHOLD, VALID_SIZE, ANOMALIES,
+                                OPTIM_PARAMS, LR_SCHEDULER, SENSITIVITY,
+                                LR_SCHEDULER_PARAMS, CHUNKS, VALID_PERIOD,
+                                NYEARS)
+from climax.main.io import (ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH,
+                            TARGET_PATH)
+
+# module level logger
+LOGGER = logging.getLogger(__name__)
+
+
+if __name__ == '__main__':
+
+    # initialize timing
+    start_time = time.monotonic()
+
+    # initialize network filename
+    state_file = ERA5Dataset.state_file(
+        NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
+        dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
+        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'],
+        lr_scheduler=LR_SCHEDULER)
+
+    # path to model state
+    if SENSITIVITY:
+        # models trained for hyperparameter optimization
+        state_file = MODEL_PATH.joinpath('sensitivity', PREDICTAND, state_file)
+        target = TARGET_PATH.joinpath('sensitivity', PREDICTAND)
+    else:
+        state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
+        target = TARGET_PATH.joinpath(PREDICTAND)
+
+    # initialize logging
+    log_file = state_file.parent.joinpath(
+        state_file.name.replace(state_file.suffix, '_log.txt'))
+    if log_file.exists():
+        log_file.unlink()
+    dictConfig(log_conf(log_file))
+
+    # check if target dataset already exists
+    target = target.joinpath(state_file.name.replace(state_file.suffix, '.nc'))
+    if target.exists() and not OVERWRITE:
+        LogConfig.init_log('{} already exists.'.format(target))
+        sys.exit()
+
+    # initialize ERA5 predictor dataset (needed for training and inference)
+    LogConfig.init_log('Initializing ERA5 predictors.')
+    Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
+                       plevels=ERA5_PLEVELS)
+    Era5_ds = Era5.merge(chunks=CHUNKS)
+
+    # whether to use digital elevation model
+    if DEM:
+        # digital elevation model: Copernicus EU-Dem v1.1
+        dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
+
+        # read elevation and compute slope and aspect
+        dem = ERA5Dataset.dem_features(
+            dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
+            add_coord={'time': Era5_ds.time})
+
+        # check whether to use slope and aspect
+        if not DEM_FEATURES:
+            dem = dem.drop_vars(['slope', 'aspect']).chunk(Era5_ds.chunks)
+
+        # add dem to set of predictor variables
+        Era5_ds = xr.merge([Era5_ds, dem])
+
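+    # NOTE: if a model checkpoint for the current configuration already
+    # exists and OVERWRITE is False, the pretrained network is loaded and
+    # training is skipped; otherwise a new model is trained from scratch.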
+    # load pretrained model
+    if state_file.exists() and not OVERWRITE:
+        # load pretrained network
+        net, _ = Network.load_pretrained_model(state_file, NET)
+    else:
+        # initialize downscaling
+        LogConfig.init_log('Initializing downscaling for period: {}'.format(
+            ' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))
+
+        # initialize OBS predictand dataset
+        LogConfig.init_log('Initializing observations for predictand: {}'
+                           .format(PREDICTAND))
+
+        # check whether to jointly train tasmin and tasmax
+        if PREDICTAND == 'tas':
+            # read both tasmax and tasmin
+            tasmax = xr.open_dataset(
+                search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
+            tasmin = xr.open_dataset(
+                search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
+            Obs_ds = xr.merge([tasmax, tasmin])
+        else:
+            # read in-situ gridded observations
+            Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
+            Obs_ds = xr.open_dataset(Obs_ds)
+
+        # initialize network and optimizer
+        LogConfig.init_log('Initializing network and optimizer.')
+
+        # define number of output fields
+        # check whether modelling pr with probabilistic approach
+        outputs = len(Obs_ds.data_vars)
+        if PREDICTAND == 'pr':
+            outputs = (1 if (isinstance(LOSS, MSELoss) or
+                             isinstance(LOSS, L1Loss)) else 3)
+
+        # instantiate network
+        inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
+        net = NET(state_file, inputs, outputs, filters=FILTERS)
+
+        # initialize optimizer
+        optimizer = OPTIM(net.parameters(), **OPTIM_PARAMS)
+
+        # initialize learning rate scheduler
+        if LR_SCHEDULER is not None:
+            LR_SCHEDULER = LR_SCHEDULER(optimizer, **LR_SCHEDULER_PARAMS)
+
+        # initialize training data
+        LogConfig.init_log('Initializing training data.')
+
+        # split calibration period into training and validation period
+        if PREDICTAND == 'pr' and STRATIFY:
+            # stratify training and validation dataset by number of
+            # observed wet days for precipitation
+            wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
+                        >= WET_DAY_THRESHOLD).to_array().values.squeeze()
+            train, valid = train_test_split(
+                CALIB_PERIOD, stratify=wet_days, test_size=VALID_SIZE)
+
+            # sort chronologically
+            train, valid = sorted(train), sorted(valid)
+        else:
+            train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
+                                            test_size=VALID_SIZE)
+
+        # training and validation dataset
+        Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
+        Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
+
+        # create PyTorch compliant dataset and dataloader instances for model
+        # training
+        train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
+                                 doy=DOY, anomalies=ANOMALIES)
+        valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
+                                 doy=DOY, anomalies=ANOMALIES)
+        train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
+                              drop_last=False)
+        valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
+                              drop_last=False)
+
+        # initialize network trainer
+        trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
+                                 valid_dl, loss_function=LOSS,
+                                 lr_scheduler=LR_SCHEDULER, **TRAIN_CONFIG)
+
+        # train model
+        state = trainer.train()
+
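+    # NOTE: from here on, both branches share the same inference path:
+    # the pretrained or freshly trained model predicts the validation
+    # period in chunks of NYEARS years, keeping memory usage bounded.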
+    # predict reference period
+    LogConfig.init_log('Predicting reference period: {}'.format(
+        ' - '.join([str(VALID_PERIOD[0]), str(VALID_PERIOD[-1])])))
+
+    # subset to reference period and predict in intervals of NYEARS years
+    trg_ds = []
+    for dates in split_date_range(VALID_PERIOD[0], VALID_PERIOD[-1],
+                                  years=NYEARS):
+        LogConfig.init_log('Predicting period: {}'.format(
+            ' - '.join([str(dates[0]), str(dates[-1])])))
+        ref_ds = Era5_ds.sel(time=dates)
+        trg_ds.append(predict_ERA5(net, ref_ds, PREDICTAND, LOSS,
+                                   normalize=NORM, batch_size=BATCH_SIZE,
+                                   doy=DOY, anomalies=ANOMALIES))
+
+    # merge predictions for entire validation period
+    LOGGER.info('Merging reference periods ...')
+    trg_ds = xr.concat(trg_ds, dim='time')
+
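+    # NOTE: writing with engine='h5netcdf' assumes the optional h5netcdf
+    # package is installed in the environment.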
+    # save model predictions as NetCDF file
+    if not target.parent.exists():
+        target.parent.mkdir(parents=True, exist_ok=True)
+    LOGGER.info('Saving network predictions: {}.'.format(target))
+    trg_ds.to_netcdf(target, engine='h5netcdf')
+
+    # log execution time of script
+    LogConfig.init_log('Execution time of script {}: {}'
+                       .format(__file__, timedelta(seconds=time.monotonic() -
+                                                   start_time)))
diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py
deleted file mode 100644
index 0a7945fb1204c6770660e6f6393c383d88e63616..0000000000000000000000000000000000000000
--- a/climax/main/downscale_infer.py
+++ /dev/null
@@ -1,127 +0,0 @@
-"""Dynamical climate downscaling using deep convolutional neural networks."""
-
-# !/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-# builtins
-import sys
-import time
-import logging
-from datetime import timedelta
-from logging.config import dictConfig
-
-# externals
-import xarray as xr
-
-# locals
-from pysegcnn.core.trainer import LogConfig
-from pysegcnn.core.models import Network
-from pysegcnn.core.logging import log_conf
-from pysegcnn.core.utils import search_files
-from climax.core.dataset import ERA5Dataset
-from climax.core.predict import predict_ERA5
-from climax.core.utils import split_date_range
-from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET,
-                                VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS,
-                                DEM, DEM_FEATURES, LOSS, ANOMALIES, OPTIM,
-                                OPTIM_PARAMS, CHUNKS, LR_SCHEDULER, OVERWRITE,
-                                SENSITIVITY)
-from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH
-
-# module level logger
-LOGGER = logging.getLogger(__name__)
-
-
-if __name__ == '__main__':
-
-    # initialize timing
-    start_time = time.monotonic()
-
-    # initialize network filename
-    state_file = ERA5Dataset.state_file(
-        NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
-        dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'],
-        lr_scheduler=LR_SCHEDULER)
-
-    # path to model state
-    if SENSITIVITY:
-        # models trained for hyperparameter optimization
-        state_file = MODEL_PATH.joinpath('sensitivity', PREDICTAND, state_file)
-        target = TARGET_PATH.joinpath('sensitivity', PREDICTAND)
-    else:
-        state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
-        target = TARGET_PATH.joinpath(PREDICTAND)
-
-    # check if target dataset already exists
-    target = target.joinpath(state_file.name.replace(state_file.suffix, '.nc'))
-    if target.exists() and not OVERWRITE:
-        LogConfig.init_log('{} already exists.'.format(target))
-        sys.exit()
-    else:
-        # load pretrained model
-        if state_file.exists():
-            # load pretrained network
-            net, _ = Network.load_pretrained_model(state_file, NET)
-        else:
-            # initialize OBS predictand dataset
-            LOGGER.info('{} does not exist.'.format(state_file))
-            sys.exit()
-
-    # initialize logging
-    log_file = state_file.parent.joinpath(
-        state_file.name.replace(state_file.suffix, '_log.txt'))
-    dictConfig(log_conf(log_file))
-
-    # predict reference period
-    LogConfig.init_log('Predicting reference period: {}'.format(
-        ' - '.join([str(VALID_PERIOD[0]), str(VALID_PERIOD[-1])])))
-
-    # initialize ERA5 predictor dataset
-    LogConfig.init_log('Initializing ERA5 predictors.')
-    Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
-                       plevels=ERA5_PLEVELS)
-    Era5_ds = Era5.merge(chunks=CHUNKS)
-
-    # whether to use digital elevation model
-    if DEM:
-        # digital elevation model: Copernicus EU-Dem v1.1
-        dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
-
-        # read elevation and compute slope and aspect
-        dem = ERA5Dataset.dem_features(
-            dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
-            add_coord={'time': Era5_ds.time})
-
-        # check whether to use slope and aspect
-        if not DEM_FEATURES:
-            dem = dem.drop_vars(['slope', 'aspect'])
-
-        # add dem to set of predictor variables
-        Era5_ds = xr.merge([Era5_ds, dem]).chunk(Era5_ds.chunks)
-
-    # subset to reference period and predict in NYEAR intervals
-    trg_ds = []
-    for dates in split_date_range(VALID_PERIOD[0], VALID_PERIOD[-1],
-                                  years=NYEARS):
-        LogConfig.init_log('Predicting period: {}'.format(
-            ' - '.join([str(dates[0]), str(dates[-1])])))
-        ref_ds = Era5_ds.sel(time=dates)
-        trg_ds.append(predict_ERA5(net, ref_ds, PREDICTAND, LOSS,
-                                   normalize=NORM, batch_size=BATCH_SIZE,
-                                   doy=DOY, anomalies=ANOMALIES))
-
-    # merge predictions for entire validation period
-    LOGGER.info('Merging reference periods ...')
-    trg_ds = xr.concat(trg_ds, dim='time')
-
-    # save model predictions as NetCDF file
-    if not target.parent.exists():
-        target.parent.mkdir(parents=True, exist_ok=True)
-    LOGGER.info('Saving network predictions: {}.'.format(target))
-    trg_ds.to_netcdf(target, engine='h5netcdf')
-
-    # log execution time of script
-    LogConfig.init_log('Execution time of script {}: {}'
-                       .format(__file__, timedelta(seconds=time.monotonic() -
-                                                   start_time)))
diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
deleted file mode 100644
index 87e994bd54c3628e0293307ca0d8dc6088c3e095..0000000000000000000000000000000000000000
--- a/climax/main/downscale_train.py
+++ /dev/null
@@ -1,178 +0,0 @@
-"""Dynamical climate downscaling using deep convolutional neural networks."""
-
-# !/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-# builtins
-import sys
-import time
-import logging
-from datetime import timedelta
-from logging.config import dictConfig
-
-# externals
-import xarray as xr
-from sklearn.model_selection import train_test_split
-from torch.utils.data import DataLoader
-
-# locals
-from pysegcnn.core.utils import search_files
-from pysegcnn.core.trainer import NetworkTrainer, LogConfig
-from pysegcnn.core.logging import log_conf
-from climax.core.dataset import ERA5Dataset, NetCDFDataset
-from climax.core.loss import MSELoss, L1Loss
-from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
-                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, OPTIM,
-                                NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
-                                OVERWRITE, DEM, DEM_FEATURES, STRATIFY,
-                                WET_DAY_THRESHOLD, VALID_SIZE, ANOMALIES,
-                                OPTIM_PARAMS, LR_SCHEDULER, SENSITIVITY,
-                                LR_SCHEDULER_PARAMS, CHUNKS)
-from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
-
-# module level logger
-LOGGER = logging.getLogger(__name__)
-
-
-if __name__ == '__main__':
-
-    # initialize timing
-    start_time = time.monotonic()
-
-    # initialize network filename
-    state_file = ERA5Dataset.state_file(
-        NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
-        dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
-        decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'],
-        lr_scheduler=LR_SCHEDULER)
-
-    # path to model state
-    if SENSITIVITY:
-        # models trained for hyperparameter optimization
-        state_file = MODEL_PATH.joinpath('sensitivity', PREDICTAND, state_file)
-    else:
-        state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
-
-    # initialize logging
-    log_file = state_file.parent.joinpath(
-        state_file.name.replace(state_file.suffix, '_log.txt'))
-    if log_file.exists():
-        log_file.unlink()
-    dictConfig(log_conf(log_file))
-
-    # initialize downscaling
-    LogConfig.init_log('Initializing downscaling for period: {}'.format(
-        ' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))
-
-    # check if model exists
-    if state_file.exists() and not OVERWRITE:
-        # load pretrained network
-        LogConfig.init_log('{} already exists.'.format(state_file))
-        sys.exit()
-
-    # initialize ERA5 predictor dataset
-    LogConfig.init_log('Initializing ERA5 predictors.')
-    Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
-                       plevels=ERA5_PLEVELS)
-    Era5_ds = Era5.merge(chunks=CHUNKS)
-
-    # initialize OBS predictand dataset
-    LogConfig.init_log('Initializing observations for predictand: {}'
-                       .format(PREDICTAND))
-
-    # check whether to joinlty train tasmin and tasmax
-    if PREDICTAND == 'tas':
-        # read both tasmax and tasmin
-        tasmax = xr.open_dataset(
-            search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
-        tasmin = xr.open_dataset(
-            search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
-        Obs_ds = xr.merge([tasmax, tasmin])
-    else:
-        # read in-situ gridded observations
-        Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
-        Obs_ds = xr.open_dataset(Obs_ds)
-
-    # whether to use digital elevation model
-    if DEM:
-        # digital elevation model: Copernicus EU-Dem v1.1
-        dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
-
-        # read elevation and compute slope and aspect
-        dem = ERA5Dataset.dem_features(
-            dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
-            add_coord={'time': Era5_ds.time})
-
-        # check whether to use slope and aspect
-        if not DEM_FEATURES:
-            dem = dem.drop_vars(['slope', 'aspect']).chunk(Era5_ds.chunks)
-
-        # add dem to set of predictor variables
-        Era5_ds = xr.merge([Era5_ds, dem])
-
-    # initialize network and optimizer
-    LogConfig.init_log('Initializing network and optimizer.')
-
-    # define number of output fields
-    # check whether modelling pr with probabilistic approach
-    outputs = len(Obs_ds.data_vars)
-    if PREDICTAND == 'pr':
-        outputs = (1 if (isinstance(LOSS, MSELoss) or isinstance(LOSS, L1Loss))
-                   else 3)
-
-    # instanciate network
-    inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
-    net = NET(state_file, inputs, outputs, filters=FILTERS)
-
-    # initialize optimizer
-    optimizer = OPTIM(net.parameters(), **OPTIM_PARAMS)
-
-    # initialize learning rate scheduler
-    if LR_SCHEDULER is not None:
-        LR_SCHEDULER = LR_SCHEDULER(optimizer, **LR_SCHEDULER_PARAMS)
-
-    # initialize training data
-    LogConfig.init_log('Initializing training data.')
-
-    # split calibration period into training and validation period
-    if PREDICTAND == 'pr' and STRATIFY:
-        # stratify training and validation dataset by number of
-        # observed wet days for precipitation
-        wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
-                    >= WET_DAY_THRESHOLD).to_array().values.squeeze()
-        train, valid = train_test_split(
-            CALIB_PERIOD, stratify=wet_days, test_size=VALID_SIZE)
-
-        # sort chronologically
-        train, valid = sorted(train), sorted(valid)
-    else:
-        train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
-                                        test_size=VALID_SIZE)
-
-    # training and validation dataset
-    Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
-    Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
-
-    # create PyTorch compliant dataset and dataloader instances for model
-    # training
-    train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, doy=DOY,
-                             anomalies=ANOMALIES)
-    valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, doy=DOY,
-                             anomalies=ANOMALIES)
-    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
-                          drop_last=False)
-    valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
-                          drop_last=False)
-
-    # initialize network trainer
-    trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
-                             valid_dl, loss_function=LOSS,
-                             lr_scheduler=LR_SCHEDULER, **TRAIN_CONFIG)
-
-    # train model
-    state = trainer.train()
-
-    # log execution time of script
-    LogConfig.init_log('Execution time of script {}: {}'
-                       .format(__file__, timedelta(seconds=time.monotonic() -
-                                                   start_time)))
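For reference, split_date_range from climax.core.utils drives the chunked prediction loop in downscale.py above: it splits the validation period into consecutive daily date ranges of at most NYEARS years each. A minimal sketch of a helper with these semantics, inferred from how it is called above and not the repository's actual implementation, could look like:

    import pandas as pd

    def split_date_range(start_date, end_date, years=5):
        """Yield daily date ranges spanning at most `years` years each."""
        start, end = pd.Timestamp(start_date), pd.Timestamp(end_date)
        while start <= end:
            # last day of the current chunk, clipped to the overall end date
            stop = min(start + pd.DateOffset(years=years) -
                       pd.Timedelta(days=1), end)
            yield pd.date_range(start, stop, freq='D')
            start = stop + pd.Timedelta(days=1)

Each yielded chunk is a pandas DatetimeIndex, which is consistent with the loop body above: it supports indexing (dates[0], dates[-1]) and can be passed directly to Era5_ds.sel(time=dates).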