diff --git a/pysegcnn/main/train_source.py b/pysegcnn/main/train_source.py index 55cbaa8dbc4f772533b93533de50361deabc19ed..b6b1a4d39db4bf1ebe63abd0b3c9202a9aec74a4 100644 --- a/pysegcnn/main/train_source.py +++ b/pysegcnn/main/train_source.py @@ -30,6 +30,7 @@ License # -*- coding: utf-8 -*- # builtins +import logging from logging.config import dictConfig # locals @@ -39,6 +40,9 @@ from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig, from pysegcnn.main.train_config import ds_config, ds_split_config, model_config from pysegcnn.core.logging import log_conf +# module level logger +LOGGER = logging.getLogger(__name__) + if __name__ == '__main__': @@ -64,6 +68,12 @@ if __name__ == '__main__': # (vii) instanciate the model state file for the current fold state_file = net_sc.init_state(src_dc, src_sc, net_mc, fold=fold) + # check if the state file already exists + if state_file.exists() and not net_mc.checkpoint: + LOGGER.info('Fold already exists: {}'.format(state_file)) + LOGGER.info('Moving to next fold ...') + continue + # (viii) instanciate logging configuration net_lc = LogConfig(state_file) dictConfig(log_conf(net_lc.log_file))