Commit b81925ec authored by Frisinghelli Daniel

Merged training and inference.
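In effect, this collapses the separate training and inference entry points into a single script: train only when no saved model state exists, then predict. A minimal sketch of the merged control flow, with hypothetical helpers standing in for the real Network.load_pretrained_model, NetworkTrainer.train and predict_ERA5 calls:

from pathlib import Path

def downscale(state_file: Path, target: Path, overwrite: bool = False) -> None:
    # hypothetical stand-ins for the real pysegcnn/climax calls
    def load_pretrained(f): return 'pretrained model'
    def train_network(f): return 'trained model'
    def predict_and_save(net, out): pass

    # skip entirely if the predictions already exist on disk
    if target.exists() and not overwrite:
        return
    # reuse a pretrained model if available, otherwise train from scratch
    net = (load_pretrained(state_file)
           if state_file.exists() and not overwrite
           else train_network(state_file))
    predict_and_save(net, target)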

parent c6444ae9
@@ -12,21 +12,28 @@ from logging.config import dictConfig

 # externals
 import xarray as xr
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader

 # locals
-from pysegcnn.core.trainer import LogConfig
-from pysegcnn.core.utils import search_files
 from pysegcnn.core.models import Network
+from pysegcnn.core.trainer import NetworkTrainer, LogConfig
 from pysegcnn.core.logging import log_conf
+from pysegcnn.core.utils import search_files
-from climax.core.dataset import ERA5Dataset
+from climax.core.dataset import ERA5Dataset, NetCDFDataset
 from climax.core.loss import MSELoss, L1Loss
 from climax.core.predict import predict_ERA5
 from climax.core.utils import split_date_range
-from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET,
-                                VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS,
-                                DEM, DEM_FEATURES, LOSS, ANOMALIES, OPTIM,
-                                OPTIM_PARAMS, CHUNKS, LR_SCHEDULER, OVERWRITE,
-                                SENSITIVITY)
-from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH
+from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
+                                CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, OPTIM,
+                                NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
+                                OVERWRITE, DEM, DEM_FEATURES, STRATIFY,
+                                WET_DAY_THRESHOLD, VALID_SIZE, ANOMALIES,
+                                OPTIM_PARAMS, LR_SCHEDULER, SENSITIVITY,
+                                LR_SCHEDULER_PARAMS, CHUNKS, VALID_PERIOD,
+                                NYEARS)
+from climax.main.io import (ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH,
+                            TARGET_PATH)

 # module level logger
 LOGGER = logging.getLogger(__name__)
@@ -53,53 +60,140 @@ if __name__ == '__main__':
         state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
     target = TARGET_PATH.joinpath(PREDICTAND)

     # initialize logging
     log_file = state_file.parent.joinpath(
         state_file.name.replace(state_file.suffix, '_log.txt'))
+    if log_file.exists():
+        log_file.unlink()
     dictConfig(log_conf(log_file))

     # check if target dataset already exists
     target = target.joinpath(state_file.name.replace(state_file.suffix, '.nc'))
     if target.exists() and not OVERWRITE:
         LogConfig.init_log('{} already exists.'.format(target))
         sys.exit()

-    # load pretrained model
-    if state_file.exists():
-        # load pretrained network
-        net, _ = Network.load_pretrained_model(state_file, NET)
-    else:
-        LOGGER.info('{} does not exist.'.format(state_file))
-        sys.exit()
+    # initialize downscaling
+    LogConfig.init_log('Initializing downscaling for period: {}'.format(
+        ' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))
+
+    # check if model exists
+    if state_file.exists() and not OVERWRITE:
+        # load pretrained network
+        net, _ = Network.load_pretrained_model(state_file, NET)
+    else:
+        # initialize ERA5 predictor dataset
+        LogConfig.init_log('Initializing ERA5 predictors.')
+        Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
+                           plevels=ERA5_PLEVELS)
+        Era5_ds = Era5.merge(chunks=CHUNKS)
+
+        # initialize OBS predictand dataset
+        LogConfig.init_log('Initializing observations for predictand: {}'
+                           .format(PREDICTAND))
+
+        # check whether to jointly train tasmin and tasmax
+        if PREDICTAND == 'tas':
+            # read both tasmax and tasmin
+            tasmax = xr.open_dataset(
+                search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
+            tasmin = xr.open_dataset(
+                search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
+            Obs_ds = xr.merge([tasmax, tasmin])
+        else:
+            # read in-situ gridded observations
+            Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
+            Obs_ds = xr.open_dataset(Obs_ds)
+
+        # whether to use digital elevation model
+        if DEM:
+            # digital elevation model: Copernicus EU-Dem v1.1
+            dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
+
+            # read elevation and compute slope and aspect
+            dem = ERA5Dataset.dem_features(
+                dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
+                add_coord={'time': Era5_ds.time})
+
+            # check whether to use slope and aspect
+            if not DEM_FEATURES:
+                dem = dem.drop_vars(['slope', 'aspect']).chunk(Era5_ds.chunks)
+
+            # add dem to set of predictor variables
+            Era5_ds = xr.merge([Era5_ds, dem])
+
+        # initialize network and optimizer
+        LogConfig.init_log('Initializing network and optimizer.')
+
+        # define number of output fields:
+        # check whether pr is modelled with a probabilistic approach
+        outputs = len(Obs_ds.data_vars)
+        if PREDICTAND == 'pr':
+            outputs = (1 if (isinstance(LOSS, MSELoss) or
+                             isinstance(LOSS, L1Loss)) else 3)
+
+        # instantiate network
+        inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
+        net = NET(state_file, inputs, outputs, filters=FILTERS)
+
+        # initialize optimizer
+        optimizer = OPTIM(net.parameters(), **OPTIM_PARAMS)
+
+        # initialize learning rate scheduler
+        if LR_SCHEDULER is not None:
+            LR_SCHEDULER = LR_SCHEDULER(optimizer, **LR_SCHEDULER_PARAMS)
+
+        # initialize training data
+        LogConfig.init_log('Initializing training data.')
+
+        # split calibration period into training and validation period
+        if PREDICTAND == 'pr' and STRATIFY:
+            # stratify training and validation dataset by the wet day
+            # indicator (spatial mean precipitation >= WET_DAY_THRESHOLD)
+            wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
+                        >= WET_DAY_THRESHOLD).to_array().values.squeeze()
+            train, valid = train_test_split(
+                CALIB_PERIOD, stratify=wet_days, test_size=VALID_SIZE)
+
+            # sort chronologically
+            train, valid = sorted(train), sorted(valid)
+        else:
+            train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
+                                            test_size=VALID_SIZE)
+
+        # training and validation dataset
+        Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
+        Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
+
+        # create PyTorch compliant dataset and dataloader instances for model
+        # training
+        train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM,
+                                 doy=DOY, anomalies=ANOMALIES)
+        valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM,
+                                 doy=DOY, anomalies=ANOMALIES)
+        train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
+                              drop_last=False)
+        valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
+                              drop_last=False)
+
+        # initialize network trainer
+        trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
+                                 valid_dl, loss_function=LOSS,
+                                 lr_scheduler=LR_SCHEDULER, **TRAIN_CONFIG)
+
+        # train model
+        state = trainer.train()

     # predict reference period
     LogConfig.init_log('Predicting reference period: {}'.format(
         ' - '.join([str(VALID_PERIOD[0]), str(VALID_PERIOD[-1])])))

     # initialize ERA5 predictor dataset
     LogConfig.init_log('Initializing ERA5 predictors.')
     Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
                        plevels=ERA5_PLEVELS)
     Era5_ds = Era5.merge(chunks=CHUNKS)

     # whether to use digital elevation model
     if DEM:
         # digital elevation model: Copernicus EU-Dem v1.1
         dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()

         # read elevation and compute slope and aspect
         dem = ERA5Dataset.dem_features(
             dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
             add_coord={'time': Era5_ds.time})

         # check whether to use slope and aspect
         if not DEM_FEATURES:
             dem = dem.drop_vars(['slope', 'aspect'])

         # add dem to set of predictor variables
         Era5_ds = xr.merge([Era5_ds, dem]).chunk(Era5_ds.chunks)

     # subset to reference period and predict in NYEARS intervals
     trg_ds = []
     for dates in split_date_range(VALID_PERIOD[0], VALID_PERIOD[-1],
......
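The rest of the diff is collapsed above. For orientation only, a hedged sketch of how such a chunked prediction loop typically continues, assuming split_date_range yields sub-period date slices and predict_ERA5 returns an xarray.Dataset; the third argument of split_date_range and the exact predict_ERA5 signature are assumptions inferred from the imports, not the repository's verified API:

# predict each sub-period separately to bound memory usage (assumed API)
trg_ds = []
for dates in split_date_range(VALID_PERIOD[0], VALID_PERIOD[-1], NYEARS):
    trg_ds.append(predict_ERA5(net, Era5_ds.sel(time=dates), PREDICTAND,
                               batch_size=BATCH_SIZE))

# concatenate the sub-periods along time and save to the target NetCDF file
trg_ds = xr.concat(trg_ds, dim='time')
trg_ds.to_netcdf(target)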
"""Dynamical climate downscaling using deep convolutional neural networks."""
# !/usr/bin/env python
# -*- coding: utf-8 -*-
# builtins
import sys
import time
import logging
from datetime import timedelta
from logging.config import dictConfig
# externals
import xarray as xr
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
# locals
from pysegcnn.core.utils import search_files
from pysegcnn.core.trainer import NetworkTrainer, LogConfig
from pysegcnn.core.logging import log_conf
from climax.core.dataset import ERA5Dataset, NetCDFDataset
from climax.core.loss import MSELoss, L1Loss
from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, OPTIM,
NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
OVERWRITE, DEM, DEM_FEATURES, STRATIFY,
WET_DAY_THRESHOLD, VALID_SIZE, ANOMALIES,
OPTIM_PARAMS, LR_SCHEDULER, SENSITIVITY,
LR_SCHEDULER_PARAMS, CHUNKS)
from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH
# module level logger
LOGGER = logging.getLogger(__name__)
if __name__ == '__main__':
# initialize timing
start_time = time.monotonic()
# initialize network filename
state_file = ERA5Dataset.state_file(
NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,
dem_features=DEM_FEATURES, doy=DOY, loss=LOSS, anomalies=ANOMALIES,
decay=OPTIM_PARAMS['weight_decay'], optim=OPTIM, lr=OPTIM_PARAMS['lr'],
lr_scheduler=LR_SCHEDULER)
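# NOTE: the state file name encodes the model configuration (predictand,
# predictors, pressure levels, loss, optimizer, learning rate, weight decay,
# ...), so each hyperparameter combination maps to a unique file on disk.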
# path to model state
if SENSITIVITY:
# models trained for hyperparameter optimization
state_file = MODEL_PATH.joinpath('sensitivity', PREDICTAND, state_file)
else:
state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
# initialize logging
log_file = state_file.parent.joinpath(
state_file.name.replace(state_file.suffix, '_log.txt'))
if log_file.exists():
log_file.unlink()
dictConfig(log_conf(log_file))
# initialize downscaling
LogConfig.init_log('Initializing downscaling for period: {}'.format(
' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))
# check if model exists
if state_file.exists() and not OVERWRITE:
# load pretrained network
LogConfig.init_log('{} already exists.'.format(state_file))
sys.exit()
# initialize ERA5 predictor dataset
LogConfig.init_log('Initializing ERA5 predictors.')
Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,
plevels=ERA5_PLEVELS)
Era5_ds = Era5.merge(chunks=CHUNKS)
# initialize OBS predictand dataset
LogConfig.init_log('Initializing observations for predictand: {}'
.format(PREDICTAND))
# check whether to jointly train tasmin and tasmax
if PREDICTAND == 'tas':
# read both tasmax and tasmin
tasmax = xr.open_dataset(
search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())
tasmin = xr.open_dataset(
search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())
Obs_ds = xr.merge([tasmax, tasmin])
else:
# read in-situ gridded observations
Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop()
Obs_ds = xr.open_dataset(Obs_ds)
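# NOTE: for PREDICTAND='tas' the merged observations contain two variables
# (tasmax, tasmin), so the network is trained to predict both fields jointly;
# for any other predictand a single observed variable is used.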
# whether to use digital elevation model
if DEM:
# digital elevation model: Copernicus EU-Dem v1.1
dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()
# read elevation and compute slope and aspect
dem = ERA5Dataset.dem_features(
dem, {'y': Era5_ds.y, 'x': Era5_ds.x},
add_coord={'time': Era5_ds.time})
# check whether to use slope and aspect
if not DEM_FEATURES:
dem = dem.drop_vars(['slope', 'aspect']).chunk(Era5_ds.chunks)
# add dem to set of predictor variables
Era5_ds = xr.merge([Era5_ds, dem])
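# NOTE: dem_features is passed the ERA5 time axis (add_coord), presumably to
# replicate the static terrain fields along time so they can be merged and
# batched together with the dynamic ERA5 predictors.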
# initialize network and optimizer
LogConfig.init_log('Initializing network and optimizer.')
# define number of output fields:
# check whether pr is modelled with a probabilistic approach
outputs = len(Obs_ds.data_vars)
if PREDICTAND == 'pr':
outputs = (1 if (isinstance(LOSS, MSELoss) or isinstance(LOSS, L1Loss))
else 3)
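# NOTE: with MSELoss or L1Loss precipitation is regressed directly (a single
# output field); otherwise three output fields are predicted, presumably the
# parameters of a probabilistic precipitation model (e.g. Bernoulli-Gamma).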
# instantiate network
inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
net = NET(state_file, inputs, outputs, filters=FILTERS)
# initialize optimizer
optimizer = OPTIM(net.parameters(), **OPTIM_PARAMS)
# initialize learning rate scheduler
if LR_SCHEDULER is not None:
LR_SCHEDULER = LR_SCHEDULER(optimizer, **LR_SCHEDULER_PARAMS)
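# NOTE: LR_SCHEDULER is rebound here from the scheduler class imported from
# the config to a scheduler instance attached to the optimizer.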
# initialize training data
LogConfig.init_log('Initializing training data.')
# split calibration period into training and validation period
if PREDICTAND == 'pr' and STRATIFY:
    # stratify training and validation dataset by the wet day
    # indicator (spatial mean precipitation >= WET_DAY_THRESHOLD)
wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))
>= WET_DAY_THRESHOLD).to_array().values.squeeze()
train, valid = train_test_split(
CALIB_PERIOD, stratify=wet_days, test_size=VALID_SIZE)
# sort chronologically
train, valid = sorted(train), sorted(valid)
else:
train, valid = train_test_split(CALIB_PERIOD, shuffle=False,
test_size=VALID_SIZE)
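# NOTE: stratifying on the wet-day indicator keeps the share of wet and dry
# days roughly equal between training and validation; the plain split instead
# cuts the calibration period chronologically (shuffle=False).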
# training and validation dataset
Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)
Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)
# create PyTorch compliant dataset and dataloader instances for model
# training
train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, doy=DOY,
anomalies=ANOMALIES)
valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, doy=DOY,
anomalies=ANOMALIES)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
drop_last=False)
valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
drop_last=False)
# initialize network trainer
trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
valid_dl, loss_function=LOSS,
lr_scheduler=LR_SCHEDULER, **TRAIN_CONFIG)
# train model
state = trainer.train()
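# NOTE: NetworkTrainer.train() runs the optimization loop and presumably
# checkpoints the network to net.state_file; the returned state summarizes
# the training run.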
# log execution time of script
LogConfig.init_log('Execution time of script {}: {}'
.format(__file__, timedelta(seconds=time.monotonic() -
start_time)))