Commit 27aa5863 authored by Frisinghelli Daniel

Implementation of bootstrapped model training.

parent b81925ec
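
This change trains BOOTSTRAP independent models per predictand instead of a single one. Each member is a fresh network fitted on the same calibration split, so the spread across members reflects random weight initialization and (with SHUFFLE enabled) batch ordering; every member writes its own checkpoint and NetCDF prediction file. A minimal sketch of the naming scheme the script derives, using a hypothetical base checkpoint name:

from pathlib import Path

BOOTSTRAP = 10
# hypothetical base checkpoint name; one member name is derived per bootstrap
state_file = Path('models/bootstrap/pr/USegNet_pr.pt')
for i in range(BOOTSTRAP):
    member = state_file.with_name(state_file.name.replace(
        state_file.suffix, '_{}.pt'.format(i + 1)))
    print(member)  # models/bootstrap/pr/USegNet_pr_1.pt ... USegNet_pr_10.pt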
@@ -92,6 +92,9 @@ VALID_SIZE = 0.2
# number of folds for training with KFold cross-validation
CV = 5
# number of bootstrapped model trainings
BOOTSTRAP = 10
# -----------------------------------------------------------------------------
# Observations ----------------------------------------------------------------
# -----------------------------------------------------------------------------
@@ -60,6 +60,10 @@ if __name__ == '__main__':
state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
target = TARGET_PATH.joinpath(PREDICTAND)
# check if output path exists
if not target.exists():
target.mkdir(parents=True, exist_ok=True)
# initialize logging
log_file = state_file.parent.joinpath(
state_file.name.replace(state_file.suffix, '_log.txt'))
@@ -82,12 +86,6 @@ if __name__ == '__main__':
LogConfig.init_log('Initializing downscaling for period: {}'.format(
' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))
# check if model exists
if state_file.exists() and not OVERWRITE:
# load pretrained network
LogConfig.init_log('{} already exists.'.format(state_file))
sys.exit()
# initialize ERA5 predictor dataset
LogConfig.init_log('Initializing ERA5 predictors.')
Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
@@ -131,15 +129,17 @@ if __name__ == '__main__':
# initialize network and optimizer
LogConfig.init_log('Initializing network and optimizer.')
# define number of output fields
# define number of output fields:
# check whether pr is modelled with a probabilistic approach
outputs = len(Obs_ds.data_vars)
if PREDICTAND == 'pr':
outputs = (1 if (isinstance(LOSS, MSELoss) or
isinstance(LOSS, L1Loss)) else 3)
# instantiate network
# define number of input fields
inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
# instantiate network
net = NET(state_file, inputs, outputs, filters=FILTERS)
# initialize optimizer
@@ -208,10 +208,9 @@ if __name__ == '__main__':
# merge predictions for entire validation period
LOGGER.info('Merging reference periods ...')
trg_ds = xr.concat(trg_ds, dim='time')
trg_ds = trg_ds.sortby(trg_ds.time) # sort predictions chronologically
# save model predictions as NetCDF file
if not target.parent.exists():
target.parent.mkdir(parents=True, exist_ok=True)
LOGGER.info('Saving network predictions: {}.'.format(target))
trg_ds.to_netcdf(target, engine='h5netcdf')
"""Dynamical climate downscaling using deep convolutional neural networks."""
# !/usr/bin/env python
# -*- coding: utf-8 -*-
# builtins
import sys
import time
import logging
from datetime import timedelta
from logging.config import dictConfig
# externals
import xarray as xr
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
# locals
from pysegcnn.core.utils import search_files
from pysegcnn.core.models import Network
from pysegcnn.core.trainer import NetworkTrainer, LogConfig
from pysegcnn.core.logging import log_conf
from climax.core.dataset import ERA5Dataset, NetCDFDataset
from climax.core.loss import MSELoss, L1Loss
from climax.core.predict import predict_ERA5
from climax.core.utils import split_date_range
from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, OPTIM,
NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
OVERWRITE, DEM, DEM_FEATURES, STRATIFY,
WET_DAY_THRESHOLD, VALID_SIZE, ANOMALIES,
OPTIM_PARAMS, LR_SCHEDULER, SENSITIVITY,
LR_SCHEDULER_PARAMS, CHUNKS, VALID_PERIOD,
NYEARS, BOOTSTRAP)
from climax.main.io import (ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH,
TARGET_PATH)
# module level logger
LOGGER = logging.getLogger(__name__)
if __name__ == '__main__':
# initialize timing
start_time = time.monotonic()
# initialize network filename
state_file = ERA5Dataset.state_file(
NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'],
lr_scheduler=LR_SCHEDULER)
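    # the checkpoint filename encodes the full model configuration
    # (predictors, pressure levels, loss, optimizer, learning rate, ...)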
# models trained with bootstrapping
state_file = MODEL_PATH.joinpath('bootstrap', PREDICTAND, state_file)
target = TARGET_PATH.joinpath('bootstrap', PREDICTAND, state_file.stem)
# check if output path exists
if not target.exists():
target.mkdir(parents=True, exist_ok=True)
# initialize logging
log_file = state_file.parent.joinpath(
state_file.name.replace(state_file.suffix, '_log.txt'))
if log_file.exists():
log_file.unlink()
dictConfig(log_conf(log_file))
# initialize downscaling
LogConfig.init_log('Initializing downscaling for period: {}'.format(
' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))
# initialize ERA5 predictor dataset
LogConfig.init_log('Initializing ERA5 predictors.')
Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
plevels=ERA5_PLEVELS)
Era5_ds = Era5.merge(chunks=CHUNKS)
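    # chunked merging presumably yields a lazily evaluated (dask-backed) dataset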
# initialize OBS predictand dataset
LogConfig.init_log('Initializing observations for predictand: {}'
.format(PREDICTAND))
    # check whether to jointly train tasmin and tasmax
if PREDICTAND == 'tas':
# read both tasmax and tasmin
tasmax = xr.open_dataset(
search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
tasmin = xr.open_dataset(
search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
Obs_ds = xr.merge([tasmax, tasmin])
else:
# read in-situ gridded observations
Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND),
'.nc$').pop()
Obs_ds = xr.open_dataset(Obs_ds)
# whether to use digital elevation model
if DEM:
# digital elevation model: Copernicus EU-Dem v1.1
dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
# read elevation and compute slope and aspect
dem = ERA5Dataset.dem_features(
dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
add_coord={'time': Era5_ds.time})
# check whether to use slope and aspect
if not DEM_FEATURES:
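            # keep only the raw elevation; drop the derived slope and aspect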
dem = dem.drop_vars(['slope', 'aspect']).chunk(
Era5_ds.chunks)
# add dem to set of predictor variables
Era5_ds = xr.merge([Era5_ds, dem])
# define number of output fields:
    # check whether pr is modelled with a probabilistic approach
outputs = len(Obs_ds.data_vars)
if PREDICTAND == 'pr':
outputs = (1 if (isinstance(LOSS, MSELoss) or
isinstance(LOSS, L1Loss)) else 3)
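        # a probabilistic pr model predicts three fields (presumably wet-day
        # probability plus two distribution parameters, e.g. Bernoulli-Gamma)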
# define number of input fields
inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
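    # DOY adds two additional input channels for the day of the year (the +2)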
# initialize training data
LogConfig.init_log('Initializing training data.')
# split calibration period into training and validation period
if PREDICTAND == 'pr' and STRATIFY:
# stratify training and validation dataset by number of
# observed wet days for precipitation
wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
>= WET_DAY_THRESHOLD).to_array().values.squeeze()
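        # stratifying on wet_days preserves the wet-day ratio in both subsets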
train, valid = train_test_split(
CALIB_PERIOD, stratify=wet_days, test_size=VALID_SIZE)
# sort chronologically
train, valid = sorted(train), sorted(valid)
else:
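        # without stratification, split chronologically: the most recent
        # VALID_SIZE fraction of the calibration period is held out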
train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
test_size=VALID_SIZE)
# training and validation dataset
Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
# create PyTorch compliant dataset and dataloader instances for model
# training
train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, doy=DOY,
anomalies=ANOMALIES)
valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, doy=DOY,
anomalies=ANOMALIES)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
drop_last=False)
valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
drop_last=False)
    # initialize bootstrapped model training
    for i in range(BOOTSTRAP):
        # add suffix for the i-th bootstrap member: derive the name from the
        # base state file, so that suffixes do not accumulate across iterations
        bootstrap_file = state_file.parent.joinpath(state_file.name.replace(
            state_file.suffix, '_{}.pt'.format(i + 1)))
        # check if target dataset already exists
        target_ds = target.joinpath(bootstrap_file.name.replace(
            bootstrap_file.suffix, '.nc'))
        if target_ds.exists() and not OVERWRITE:
            LogConfig.init_log('{} already exists.'.format(target_ds))
            continue
        # load pretrained model
        if bootstrap_file.exists() and not OVERWRITE:
            # load pretrained network
            net, _ = Network.load_pretrained_model(bootstrap_file, NET)
        else:
            # initialize network and optimizer
            LogConfig.init_log('Initializing network and optimizer.')
            # instantiate network
            net = NET(bootstrap_file, inputs, outputs, filters=FILTERS)
        # initialize optimizer
        optimizer = OPTIM(net.parameters(), **OPTIM_PARAMS)
        # initialize learning rate scheduler: instantiate a fresh scheduler
        # per bootstrap instead of rebinding the LR_SCHEDULER class itself,
        # which would break on the second iteration
        lr_scheduler = (LR_SCHEDULER(optimizer, **LR_SCHEDULER_PARAMS)
                        if LR_SCHEDULER is not None else None)
        # initialize network trainer
        trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
                                 valid_dl, loss_function=LOSS,
                                 lr_scheduler=lr_scheduler, **TRAIN_CONFIG)
# train model
state = trainer.train()
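        # the training state returned by the trainer is not used further here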
# predict reference period
LogConfig.init_log('Predicting reference period: {}'.format(
' - '.join([str(VALID_PERIOD[0]), str(VALID_PERIOD[-1])])))
        # subset to reference period and predict in intervals of NYEARS years
trg_ds = []
for dates in split_date_range(VALID_PERIOD[0], VALID_PERIOD[-1],
years=NYEARS):
LogConfig.init_log('Predicting period: {}'.format(
' - '.join([str(dates[0]), str(dates[-1])])))
ref_ds = Era5_ds.sel(time=dates)
trg_ds.append(predict_ERA5(net, ref_ds, PREDICTAND, LOSS,
normalize=NORM, batch_size=BATCH_SIZE,
doy=DOY, anomalies=ANOMALIES))
# merge predictions for entire validation period
LOGGER.info('Merging reference periods ...')
trg_ds = xr.concat(trg_ds, dim='time')
trg_ds = trg_ds.sortby(trg_ds.time) # sort predictions chronologically
# save model predictions as NetCDF file
LOGGER.info('Saving network predictions: {}.'.format(target_ds))
trg_ds.to_netcdf(target_ds, engine='h5netcdf')
# log execution time of script
LogConfig.init_log('Execution time of script {}: {}'
.format(__file__, timedelta(seconds=time.monotonic() -
start_time)))
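
Downstream, the per-member predictions written by this script can be stacked into an ensemble to quantify the spread induced by the repeated trainings. A sketch under the assumption that the member files sit in the target directory created above (paths hypothetical):

import xarray as xr
from pathlib import Path

# hypothetical target directory populated by the bootstrap runs
target = Path('targets/bootstrap/pr/USegNet_pr')
members = sorted(target.glob('*.nc'))
# stack the members along a new 'member' dimension
ens = xr.open_mfdataset(members, combine='nested', concat_dim='member',
                        engine='h5netcdf')
ens_mean, ens_std = ens.mean(dim='member'), ens.std(dim='member')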