Skip to content
Snippets Groups Projects
Commit 80149582 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Module to train CNN using cross-validation.

parent ae8300ef
No related branches found
No related tags found
No related merge requests found
"""Dynamical climate downscaling using deep convolutional neural networks."""
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# builtins
import time
import logging
from datetime import timedelta
from logging.config import dictConfig
# externals
import torch
import xarray as xr
from sklearn.model_selection import TimeSeriesSplit
from torch.utils.data import DataLoader
# locals
from pysegcnn.core.utils import search_files
from pysegcnn.core.trainer import NetworkTrainer, LogConfig
from pysegcnn.core.models import Network
from pysegcnn.core.logging import log_conf
from climax.core.dataset import ERA5Dataset, NetCDFDataset
from climax.core.loss import BernoulliGammaLoss
from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
OVERWRITE, DEM, DEM_FEATURES)
from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
# module level logger
LOGGER = logging.getLogger(__name__)
if __name__ == '__main__':
    # initialize timing of the whole script
    start_time = time.monotonic()

    # initialize network filename from the model/predictor configuration
    state_file = ERA5Dataset.state_file(
        NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
        dem_features=DEM_FEATURES, doy=DOY)

    # adjust state file name for precipitation: encode the loss function
    # (and, for the Bernoulli-Gamma loss, the minimum wet-day amount)
    if PREDICTAND == 'pr':
        if isinstance(LOSS, BernoulliGammaLoss):
            state_file = state_file.replace('.pt', '_{}mm_{}.pt'.format(
                str(LOSS.min_amount).replace('.', ''),
                repr(LOSS).strip('()')))
        else:
            state_file = state_file.replace('.pt', '_{}.pt'.format(
                repr(LOSS).strip('()')))

    # add suffix for training with cross-validation
    state_file = state_file.replace('.pt', '_cv.pt')

    # path to model state
    state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)

    # initialize logging: remove a stale log file before (re)configuring
    log_file = MODEL_PATH.joinpath(PREDICTAND,
                                   state_file.name.replace('.pt', '_log.txt'))
    if log_file.exists():
        log_file.unlink()
    dictConfig(log_conf(log_file))

    # initialize downscaling
    LogConfig.init_log('Initializing downscaling for period: {}'.format(
        ' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))

    # check if a trained model already exists
    if state_file.exists() and not OVERWRITE:
        # load pretrained network instead of training from scratch
        net, _ = Network.load_pretrained_model(state_file, NET)
    else:
        # initialize ERA5 predictor dataset
        LogConfig.init_log('Initializing ERA5 predictors.')
        Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
                           plevels=ERA5_PLEVELS)
        Era5_ds = Era5.merge(chunks=-1)

        # initialize OBS predictand dataset
        LogConfig.init_log('Initializing observations for predictand: {}'
                           .format(PREDICTAND))

        # check whether to jointly train tasmin and tasmax
        if PREDICTAND == 'tas':
            # read both tasmax and tasmin
            tasmax = xr.open_dataset(
                search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
            tasmin = xr.open_dataset(
                search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
            Obs_ds = xr.merge([tasmax, tasmin])
        else:
            # read in-situ gridded observations
            Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
            Obs_ds = xr.open_dataset(Obs_ds)

        # whether to use a digital elevation model as additional predictor
        if DEM:
            # digital elevation model: Copernicus EU-Dem v1.1
            dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()

            # read elevation and compute slope and aspect on the ERA5 grid
            dem = ERA5Dataset.dem_features(
                dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
                add_coord={'time': Era5_ds.time})

            # check whether to use slope and aspect
            if not DEM_FEATURES:
                dem = dem.drop_vars(['slope', 'aspect'])

            # add dem to set of predictor variables
            Era5_ds = xr.merge([Era5_ds, dem])

        # initialize network and optimizer
        LogConfig.init_log('Initializing network and optimizer.')

        # define number of output fields; modelling pr with the
        # Bernoulli-Gamma loss needs three outputs (p, shape, scale)
        outputs = len(Obs_ds.data_vars)
        if PREDICTAND == 'pr' and isinstance(LOSS, BernoulliGammaLoss):
            outputs = 3

        # instantiate network
        net = NET(state_file, len(Era5_ds.data_vars), outputs,
                  filters=FILTERS)

        # initialize optimizer with L2 weight decay
        optimizer = torch.optim.Adam(net.parameters(), lr=LR,
                                     weight_decay=LAMBDA)

        # initialize training data
        LogConfig.init_log('Initializing training data.')

        # split calibration period using cross-validation TimeSeriesSplit
        # NOTE(review): net, optimizer and state_file are created once and
        # reused across all folds, so each fold continues training the same
        # model and overwrites the same checkpoint — confirm this sequential
        # fine-tuning over expanding windows is intended, rather than an
        # independent model per fold.
        cv = TimeSeriesSplit()
        for i, (train_idx, valid_idx) in enumerate(cv.split(CALIB_PERIOD)):
            # time steps for training and validation set
            train = CALIB_PERIOD[train_idx]
            valid = CALIB_PERIOD[valid_idx]
            LogConfig.init_log('Fold {}/{}: {} - {}'.format(
                i + 1, cv.n_splits, str(train[0]), str(train[-1])))

            # training and validation dataset
            Era5_train, Obs_train = (Era5_ds.sel(time=train),
                                     Obs_ds.sel(time=train))
            Era5_valid, Obs_valid = (Era5_ds.sel(time=valid),
                                     Obs_ds.sel(time=valid))

            # create PyTorch compliant dataset and dataloader instances for
            # model training
            train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
                                     doy=DOY)
            valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
                                     doy=DOY)
            train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE,
                                  shuffle=SHUFFLE, drop_last=False)
            valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE,
                                  shuffle=SHUFFLE, drop_last=False)

            # initialize network trainer
            trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
                                     valid_dl, loss_function=LOSS,
                                     **TRAIN_CONFIG)

            # train model (checkpointing is handled by the trainer; the
            # returned state was previously bound to an unused local)
            trainer.train()

    # log execution time of script
    LogConfig.init_log('Execution time of script {}: {}'
                       .format(__file__, timedelta(seconds=time.monotonic() -
                                                   start_time)))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment