From 8ee32614a70d73edb48e0b501017f96374f0572b Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Fri, 19 Feb 2021 15:40:24 +0100
Subject: [PATCH] Added option to save or not save model evaluations.

---
 pysegcnn/core/cli.py     |  6 ++++++
 pysegcnn/core/trainer.py | 13 +++++++++----
 pysegcnn/main/eval.py    |  1 +
 3 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/pysegcnn/core/cli.py b/pysegcnn/core/cli.py
index b870fcc..855694f 100644
--- a/pysegcnn/core/cli.py
+++ b/pysegcnn/core/cli.py
@@ -162,6 +162,12 @@ def evaluation_parser():
                               .format(default)),
                         default=False, nargs='?', const=True, metavar='')
 
+    # optional argument: whether to save model evaluations
+    parser.add_argument('-sv', '--save', type=bool,
+                        help=('Save model evaluations {}.'
+                              .format(default)),
+                        default=False, nargs='?', const=True, metavar='')
+
     # optional argument: whether to overwrite existing files
     parser.add_argument('-o', '--overwrite', type=bool,
                         help=('Overwrite existing model evaluations {}.'
diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 7b13d67..6f34059 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -2013,6 +2013,9 @@ class NetworkInference(BaseConfig):
         Whether to aggregate the statistics of the different models in
         ``state_files``. Useful to aggregate the results of mutliple model
         runs in cross validation. The default is `False`.
+    save : `bool`
+        Whether to save the model evaluations to a file. The default is
+        `False`.
     ds : `dict`
         The dataset configuration dictionary passed to
         :py:class:`pysegcnn.core.trainer.DatasetConfig` when evaluating on
@@ -2084,6 +2087,7 @@ class NetworkInference(BaseConfig):
     domain: str = 'src'
     test: object = False
     aggregate: bool = False
+    save: bool = False
     ds: dict = dataclasses.field(default_factory={})
     ds_split: dict = dataclasses.field(default_factory={})
     drive_path: str = ''
@@ -2668,7 +2672,7 @@ class NetworkInference(BaseConfig):
             log.init_log('Evaluating model: {}.'.format(state))
 
             # check whether model was already evaluated
-            if self.eval_file(state).exists():
+            if self.save and self.eval_file(state).exists():
                 LOGGER.info('Found existing model evaluation: {}.'
                             .format(self.eval_file(state)))
 
@@ -2740,9 +2744,10 @@ class NetworkInference(BaseConfig):
                     outpath=self.perfmc_path)
 
             # save model predictions to file
-            LOGGER.info('Saving model evaluation: {}'
-                        .format(self.eval_file(state)))
-            torch.save(output, self.eval_file(state))
+            if self.save:
+                LOGGER.info('Saving model evaluation: {}'
+                            .format(self.eval_file(state)))
+                torch.save(output, self.eval_file(state))
 
             # save model predictions to list
             inference[state.stem] = output
diff --git a/pysegcnn/main/eval.py b/pysegcnn/main/eval.py
index 1d09aba..c4ac3d3 100644
--- a/pysegcnn/main/eval.py
+++ b/pysegcnn/main/eval.py
@@ -71,6 +71,7 @@ if __name__ == '__main__':
             domain=args.domain,
             test=args.subset,
             aggregate=args.aggregate,
+            save=args.save,
             ds=ds,
             ds_split=ds_split,
             drive_path=args.dataset_path,
-- 
GitLab