Skip to content
Snippets Groups Projects
Commit cd2d0cc7 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Check if target file already exists.

parent ec6d1af3
No related branches found
No related tags found
No related merge requests found
......@@ -24,7 +24,7 @@ from climax.core.utils import split_date_range
from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET,
VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS,
DEM, DEM_FEATURES, LOSS, ANOMALIES, OPTIM,
OPTIM_PARAMS, CHUNKS, LR_SCHEDULER)
OPTIM_PARAMS, CHUNKS, LR_SCHEDULER, OVERWRITE)
from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH
# module level logger
......@@ -46,6 +46,22 @@ if __name__ == '__main__':
# path to model state
state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)
# load pretrained model
if state_file.exists():
# load pretrained network
net, _ = Network.load_pretrained_model(state_file, NET)
else:
# initialize OBS predictand dataset
LOGGER.info('{} does not exist.'.format(state_file))
sys.exit()
# check if target dataset already exists
target = TARGET_PATH.joinpath(PREDICTAND, net.state_file.name.replace(
net.state_file.suffix, '.nc'))
if target.exists() and not OVERWRITE:
LogConfig.init_log('{} already exists.'.format(target))
sys.exit()
# initialize logging
log_file = MODEL_PATH.joinpath(PREDICTAND,
state_file.name.replace('.pt', '_log.txt'))
......@@ -78,15 +94,6 @@ if __name__ == '__main__':
# add dem to set of predictor variables
Era5_ds = xr.merge([Era5_ds, dem]).chunk(Era5_ds.chunks)
# load pretrained model
if state_file.exists():
# load pretrained network
net, _ = Network.load_pretrained_model(state_file, NET)
else:
# initialize OBS predictand dataset
LOGGER.info('{} does not exist.'.format(state_file))
sys.exit()
# subset to reference period and predict in NYEAR intervals
trg_ds = []
for dates in split_date_range(VALID_PERIOD[0], VALID_PERIOD[-1],
......@@ -103,8 +110,6 @@ if __name__ == '__main__':
trg_ds = xr.concat(trg_ds, dim='time')
# save model predictions as NetCDF file
target = TARGET_PATH.joinpath(PREDICTAND, net.state_file.name.replace(
net.state_file.suffix, '.nc'))
if not target.parent.exists():
target.parent.mkdir(parents=True, exist_ok=True)
LOGGER.info('Saving network predictions: {}.'.format(target))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment