diff --git a/pysegcnn/core/cli.py b/pysegcnn/core/cli.py index b870fcc79d6b1be6185b5f4c07248b7a48381e93..855694fefc8e165d4550a109ebf568109ee6b69e 100644 --- a/pysegcnn/core/cli.py +++ b/pysegcnn/core/cli.py @@ -162,6 +162,12 @@ def evaluation_parser(): .format(default)), default=False, nargs='?', const=True, metavar='') + # optional argument: whether to save model evaluations + parser.add_argument('-sv', '--save', type=bool, + help=('Save model evaluations {}.' + .format(default)), + default=False, nargs='?', const=True, metavar='') + # optional argument: whether to overwrite existing files parser.add_argument('-o', '--overwrite', type=bool, help=('Overwrite existing model evaluations {}.' diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 7b13d67a7361b41302447c30e62edf7757215d34..6f3405901de818de98b37c95fde306bfe1147829 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -2013,6 +2013,9 @@ class NetworkInference(BaseConfig): Whether to aggregate the statistics of the different models in ``state_files``. Useful to aggregate the results of mutliple model runs in cross validation. The default is `False`. + save : `bool` + Whether to save the model evaluations to a file. The default is + `False`. ds : `dict` The dataset configuration dictionary passed to :py:class:`pysegcnn.core.trainer.DatasetConfig` when evaluating on @@ -2084,6 +2087,7 @@ class NetworkInference(BaseConfig): domain: str = 'src' test: object = False aggregate: bool = False + save: bool = False ds: dict = dataclasses.field(default_factory={}) ds_split: dict = dataclasses.field(default_factory={}) drive_path: str = '' @@ -2668,7 +2672,7 @@ class NetworkInference(BaseConfig): log.init_log('Evaluating model: {}.'.format(state)) # check whether model was already evaluated - if self.eval_file(state).exists(): + if self.save and self.eval_file(state).exists(): LOGGER.info('Found existing model evaluation: {}.' .format(self.eval_file(state))) @@ -2740,9 +2744,10 @@ class NetworkInference(BaseConfig): outpath=self.perfmc_path) # save model predictions to file - LOGGER.info('Saving model evaluation: {}' - .format(self.eval_file(state))) - torch.save(output, self.eval_file(state)) + if self.save: + LOGGER.info('Saving model evaluation: {}' + .format(self.eval_file(state))) + torch.save(output, self.eval_file(state)) # save model predictions to list inference[state.stem] = output diff --git a/pysegcnn/main/eval.py b/pysegcnn/main/eval.py index 1d09aba543d1076d0784a137d1e8e6f268e7db56..c4ac3d38266b427deb3dcfb4f6011c99c2ba9142 100644 --- a/pysegcnn/main/eval.py +++ b/pysegcnn/main/eval.py @@ -71,6 +71,7 @@ if __name__ == '__main__': domain=args.domain, test=args.subset, aggregate=args.aggregate, + save=args.save, ds=ds, ds_split=ds_split, drive_path=args.dataset_path,