From cd2d0cc79a810356a47d66a7c808d04f22d35ea6 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Fri, 22 Oct 2021 16:20:43 +0200 Subject: [PATCH] Check if target file already exists. --- climax/main/downscale_infer.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/climax/main/downscale_infer.py b/climax/main/downscale_infer.py index 5d26894..a023e00 100644 --- a/climax/main/downscale_infer.py +++ b/climax/main/downscale_infer.py @@ -24,7 +24,7 @@ from climax.core.utils import split_date_range from climax.main.config import (ERA5_PREDICTORS, ERA5_PLEVELS, PREDICTAND, NET, VALID_PERIOD, BATCH_SIZE, NORM, DOY, NYEARS, DEM, DEM_FEATURES, LOSS, ANOMALIES, OPTIM, - OPTIM_PARAMS, CHUNKS, LR_SCHEDULER) + OPTIM_PARAMS, CHUNKS, LR_SCHEDULER, OVERWRITE) from climax.main.io import ERA5_PATH, DEM_PATH, MODEL_PATH, TARGET_PATH # module level logger @@ -46,6 +46,22 @@ if __name__ == '__main__': # path to model state state_file = MODEL_PATH.joinpath(PREDICTAND, state_file) + # load pretrained model + if state_file.exists(): + # load pretrained network + net, _ = Network.load_pretrained_model(state_file, NET) + else: + # initialize OBS predictand dataset + LOGGER.info('{} does not exist.'.format(state_file)) + sys.exit() + + # check if target dataset already exists + target = TARGET_PATH.joinpath(PREDICTAND, net.state_file.name.replace( + net.state_file.suffix, '.nc')) + if target.exists() and not OVERWRITE: + LogConfig.init_log('{} already exists.'.format(target)) + sys.exit() + # initialize logging log_file = MODEL_PATH.joinpath(PREDICTAND, state_file.name.replace('.pt', '_log.txt')) @@ -78,15 +94,6 @@ if __name__ == '__main__': # add dem to set of predictor variables Era5_ds = xr.merge([Era5_ds, dem]).chunk(Era5_ds.chunks) - # load pretrained model - if state_file.exists(): - # load pretrained network - net, _ = Network.load_pretrained_model(state_file, NET) - else: - # initialize OBS predictand dataset - LOGGER.info('{} does not exist.'.format(state_file)) - sys.exit() - # subset to reference period and predict in NYEAR intervals trg_ds = [] for dates in split_date_range(VALID_PERIOD[0], VALID_PERIOD[-1], @@ -103,8 +110,6 @@ if __name__ == '__main__': trg_ds = xr.concat(trg_ds, dim='time') # save model predictions as NetCDF file - target = TARGET_PATH.joinpath(PREDICTAND, net.state_file.name.replace( - net.state_file.suffix, '.nc')) if not target.parent.exists(): target.parent.mkdir(parents=True, exist_ok=True) LOGGER.info('Saving network predictions: {}.'.format(target)) -- GitLab