Commit bd79fde1 authored by Frisinghelli Daniel

Implemented training with cross-validation.

parent 80149582
@@ -55,6 +55,9 @@ DEM_FEATURES = False
# stratify training/validation set for precipitation by number of wet days
STRATIFY = True
# whether to train using cross-validation
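# (if True, the training script iterates over sklearn's TimeSeriesSplit folds)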
CV = False
# -----------------------------------------------------------------------------
# Observations ----------------------------------------------------------------
# -----------------------------------------------------------------------------
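Apart from selecting the training loop, the new flag only toggles a `_cv` suffix on the model state file, as in this minimal sketch (the filename is a hypothetical example):

CV = True
state_file = 'USegNet_pr_BernoulliGammaLoss.pt'  # hypothetical example name
state_file = state_file.replace('.pt', '_cv.pt') if CV else state_file
# -> 'USegNet_pr_BernoulliGammaLoss_cv.pt'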
@@ -24,7 +24,7 @@ from climax.core.utils import split_date_range
from climax.core.loss import BernoulliGammaLoss
from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET,
VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS,
DEM, DEM_FEATURES, LOSS)
DEM, DEM_FEATURES, LOSS, CV)
from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH
# module level logger
@@ -51,6 +51,9 @@ if __name__ == '__main__':
state_file = state_file.replace('.pt', '_{}.pt'.format(
repr(LOSS).strip('()')))
# add suffix for training with cross-validation
state_file = state_file.replace('.pt', '_cv.pt') if CV else state_file
# path to model state
state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
@@ -12,7 +12,7 @@ from logging.config import dictConfig
# externals
import torch
import xarray as xr
from sklearn.model_selection import train_test_split
from sklearn.model_selection import train_test_split, TimeSeriesSplit
from torch.utils.data import DataLoader
# locals
@@ -26,7 +26,7 @@ from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
OVERWRITE, DEM, DEM_FEATURES, STRATIFY,
WET_DAY_THRESHOLD)
WET_DAY_THRESHOLD, CV)
from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
# module level logger
@@ -53,6 +53,9 @@ if __name__ == '__main__':
state_file = state_file.replace('.pt', '_{}.pt'.format(
repr(LOSS).strip('()')))
# add suffix for training with cross-validation
state_file = state_file.replace('.pt', '_cv.pt') if CV else state_file
# path to model state
state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
@@ -112,36 +115,8 @@ if __name__ == '__main__':
# add dem to set of predictor variables
Era5_ds = xr.merge([Era5_ds, dem])
# initialize training data
LogConfig.init_log('Initializing training data.')
# split calibration period into training and validation period
if PREDICTAND == 'pr' and STRATIFY:
# stratify training and validation dataset by number of observed
# wet days for precipitation
wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x')) >=
WET_DAY_THRESHOLD).to_array().values.squeeze()
train, valid = train_test_split(
CALIB_PERIOD, stratify=wet_days, test_size=0.1)
train, valid = sorted(train), sorted(valid) # sort chronologically
else:
train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
test_size=0.1)
# training and validation dataset
Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
# create PyTorch compliant dataset and dataloader instances for model
# training
train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
doy=DOY)
valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
doy=DOY)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
drop_last=False)
valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
drop_last=False)
# initialize network and optimizer
LogConfig.init_log('Initializing network and optimizer.')
# define number of output fields
# check whether modelling pr with probabilistic approach
@@ -150,18 +125,90 @@ if __name__ == '__main__':
outputs = 3
# instantiate network
net = NET(state_file, train_ds.X.shape[1], outputs, filters=FILTERS)
net = NET(state_file, len(Era5_ds.data_vars), outputs, filters=FILTERS)
# initialize optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=LR,
weight_decay=LAMBDA)
# initialize network trainer
trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
valid_dl, loss_function=LOSS, **TRAIN_CONFIG)
# initialize training data
LogConfig.init_log('Initializing training data.')
if CV:
# split calibration period using cross-validation TimeSeriesSplit
cv = TimeSeriesSplit()
for i, (train_idx, valid_idx) in enumerate(cv.split(CALIB_PERIOD)):
# time steps for training and validation set
train = CALIB_PERIOD[train_idx]
valid = CALIB_PERIOD[valid_idx]
LogConfig.init_log('Fold {}/{}: {} - {}'.format(
i + 1, cv.n_splits, str(train[0]), str(train[-1])))
# training and validation dataset
Era5_train, Obs_train = (Era5_ds.sel(time=train),
Obs_ds.sel(time=train))
Era5_valid, Obs_valid = (Era5_ds.sel(time=valid),
Obs_ds.sel(time=valid))
# create PyTorch compliant dataset and dataloader instances for
# model training
train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
doy=DOY)
valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
doy=DOY)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE,
shuffle=SHUFFLE, drop_last=False)
valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE,
shuffle=SHUFFLE, drop_last=False)
# initialize network trainer
trainer = NetworkTrainer(
net, optimizer, net.state_file, train_dl, valid_dl,
loss_function=LOSS, **TRAIN_CONFIG)
# train model
state = trainer.train()
else:
# split calibration period into training and validation period
if PREDICTAND == 'pr' and STRATIFY:
# stratify training and validation dataset by number of
# observed wet days for precipitation
wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
>= WET_DAY_THRESHOLD).to_array().values.squeeze()
train, valid = train_test_split(
CALIB_PERIOD, stratify=wet_days, test_size=0.1)
# sort chronologically
train, valid = sorted(train), sorted(valid)
else:
train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
test_size=0.1)
# training and validation dataset
Era5_train, Obs_train = (Era5_ds.sel(time=train),
Obs_ds.sel(time=train))
Era5_valid, Obs_valid = (Era5_ds.sel(time=valid),
Obs_ds.sel(time=valid))
# create PyTorch compliant dataset and dataloader instances for model
# training
train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
doy=DOY)
valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
doy=DOY)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE,
shuffle=SHUFFLE, drop_last=False)
valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE,
shuffle=SHUFFLE, drop_last=False)
# initialize network trainer
trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
valid_dl, loss_function=LOSS,
**TRAIN_CONFIG)
# train model
state = trainer.train()
# log execution time of script
LogConfig.init_log('Execution time of script {}: {}'
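For reference, a self-contained sketch of the stratified split used in the non-CV branch above; the observations are synthetic and the calibration range and threshold are assumed values:

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

CALIB_PERIOD = pd.date_range('1981-01-01', '1990-12-31', freq='D')  # assumed
rng = np.random.default_rng(0)
mean_pr = rng.gamma(0.5, 2.0, size=len(CALIB_PERIOD))  # synthetic area-mean precipitation
WET_DAY_THRESHOLD = 1.0  # assumed wet-day threshold in mm / day
wet_days = mean_pr >= WET_DAY_THRESHOLD
# hold out 10% of the days while preserving the wet/dry-day ratio
train, valid = train_test_split(CALIB_PERIOD, stratify=wet_days, test_size=0.1)
train, valid = sorted(train), sorted(valid)  # restore chronological order

Stratification keeps the share of wet days comparable between the training and validation sets, which matters for the skewed precipitation distribution.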
"""Dynamical climate downscaling using deep convolutional neural networks."""
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# builtins
import time
import logging
from datetime import timedelta
from logging.config import dictConfig
# externals
import torch
import xarray as xr
from sklearn.model_selection import TimeSeriesSplit
from torch.utils.data import DataLoader
# locals
from pysegcnn.core.utils import search_files
from pysegcnn.core.trainer import NetworkTrainer, LogConfig
from pysegcnn.core.models import Network
from pysegcnn.core.logging import log_conf
from climax.core.dataset import ERA5Dataset, NetCDFDataset
from climax.core.loss import BernoulliGammaLoss
from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
OVERWRITE, DEM, DEM_FEATURES)
from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
# module level logger
LOGGER = logging.getLogger(__name__)
if __name__ == '__main__':
# initialize timing
start_time = time.monotonic()
# initialize network filename
state_file = ERA5Dataset.state_file(
NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
dem_features=DEM_FEATURES, doy=DOY)
# adjust statefile name for precipitation
if PREDICTAND == 'pr':
if isinstance(LOSS, BernoulliGammaLoss):
state_file = state_file.replace('.pt', '_{}mm_{}.pt'.format(
str(LOSS.min_amount).replace('.', ''),
repr(LOSS).strip('()')))
else:
state_file = state_file.replace('.pt', '_{}.pt'.format(
repr(LOSS).strip('()')))
# add suffix for training with cross-validation
state_file = state_file.replace('.pt', '_cv.pt')
# path to model state
state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
# initialize logging
log_file = MODEL_PATH.joinpath(PREDICTAND,
state_file.name.replace('.pt', '_log.txt'))
if log_file.exists():
log_file.unlink()
dictConfig(log_conf(log_file))
# initialize downscaling
LogConfig.init_log('Initializing downscaling for period: {}'.format(
' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))
# check if model exists
if state_file.exists() and not OVERWRITE:
# load pretrained network
net, _ = Network.load_pretrained_model(state_file, NET)
else:
# initialize ERA5 predictor dataset
LogConfig.init_log('Initializing ERA5 predictors.')
Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
plevels=ERA5_PLEVELS)
Era5_ds = Era5.merge(chunks=-1)
# initialize OBS predictand dataset
LogConfig.init_log('Initializing observations for predictand: {}'
.format(PREDICTAND))
# check whether to jointly train tasmin and tasmax
if PREDICTAND == 'tas':
# read both tasmax and tasmin
tasmax = xr.open_dataset(
search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
tasmin = xr.open_dataset(
search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
Obs_ds = xr.merge([tasmax, tasmin])
else:
# read in-situ gridded observations
Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
Obs_ds = xr.open_dataset(Obs_ds)
# whether to use digital elevation model
if DEM:
# digital elevation model: Copernicus EU-Dem v1.1
dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
# read elevation and compute slope and aspect
dem = ERA5Dataset.dem_features(
dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
add_coord={'time': Era5_ds.time})
# check whether to use slope and aspect
if not DEM_FEATURES:
dem = dem.drop_vars(['slope', 'aspect'])
# add dem to set of predictor variables
Era5_ds = xr.merge([Era5_ds, dem])
# initialize network and optimizer
LogConfig.init_log('Initializing network and optimizer.')
# define number of output fields
# check whether modelling pr with probabilistic approach
outputs = len(Obs_ds.data_vars)
if PREDICTAND == 'pr' and isinstance(LOSS, BernoulliGammaLoss):
outputs = 3
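# NOTE: the three Bernoulli-Gamma outputs presumably correspond to the
# wet-day probability plus the gamma shape and scale parameters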
# instantiate network
net = NET(state_file, len(Era5_ds.data_vars), outputs, filters=FILTERS)
# initialize optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=LR,
weight_decay=LAMBDA)
# initialize training data
LogConfig.init_log('Initializing training data.')
# split calibration period using cross-validation TimeSeriesSplit
cv = TimeSeriesSplit()
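# NOTE: TimeSeriesSplit defaults to n_splits=5 in recent scikit-learn: the
# training window expands fold by fold, followed by a fixed-size
# validation block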
for i, (train_idx, valid_idx) in enumerate(cv.split(CALIB_PERIOD)):
# time steps for training and validation set
train = CALIB_PERIOD[train_idx]
valid = CALIB_PERIOD[valid_idx]
LogConfig.init_log('Fold {}/{}: {} - {}'.format(
i + 1, cv.n_splits, str(train[0]), str(train[-1])))
# training and validation dataset
Era5_train, Obs_train = (Era5_ds.sel(time=train),
Obs_ds.sel(time=train))
Era5_valid, Obs_valid = (Era5_ds.sel(time=valid),
Obs_ds.sel(time=valid))
# create PyTorch compliant dataset and dataloader instances for
# model training
train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
doy=DOY)
valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
doy=DOY)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE,
shuffle=SHUFFLE, drop_last=False)
valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE,
shuffle=SHUFFLE, drop_last=False)
# initialize network trainer
trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
valid_dl, loss_function=LOSS,
**TRAIN_CONFIG)
# train model
state = trainer.train()
# log execution time of script
LogConfig.init_log('Execution time of script {}: {}'
.format(__file__, timedelta(seconds=time.monotonic() -
start_time)))
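A runnable sketch of the cross-validation loop above, assuming CALIB_PERIOD is a pandas DatetimeIndex (the date range is hypothetical):

import pandas as pd
from sklearn.model_selection import TimeSeriesSplit

CALIB_PERIOD = pd.date_range('1981-01-01', '1990-12-31', freq='D')  # hypothetical
cv = TimeSeriesSplit()  # n_splits=5 by default
for i, (train_idx, valid_idx) in enumerate(cv.split(CALIB_PERIOD)):
    train, valid = CALIB_PERIOD[train_idx], CALIB_PERIOD[valid_idx]
    print('Fold {}/{}: train {} - {}, valid {} - {}'.format(
        i + 1, cv.n_splits, train[0].date(), train[-1].date(),
        valid[0].date(), valid[-1].date()))

Each validation block strictly follows its training window, so later folds train on more data and no future time steps leak into training.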