From 46fec41ad7279d6cce17a93b2d992320d8272ebd Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Mon, 17 Aug 2020 10:18:30 +0200
Subject: [PATCH] Adjusted prediction modules to changes in core.trainer.py

---
 pysegcnn/core/predict.py | 28 +++++--------------------
 pysegcnn/main/eval.py    | 45 +++++++++++++---------------------------
 2 files changed, 19 insertions(+), 54 deletions(-)

diff --git a/pysegcnn/core/predict.py b/pysegcnn/core/predict.py
index ae53905..eebade8 100644
--- a/pysegcnn/core/predict.py
+++ b/pysegcnn/core/predict.py
@@ -31,8 +31,7 @@ def get_scene_tiles(ds, scene_id):
     return indices
 
 
-def predict_samples(ds, model, optimizer, state_file, cm=False,
-                    plot=False, **kwargs):
+def predict_samples(ds, model, cm=False, plot=False, **kwargs):
 
     # check whether the dataset is a valid subset, i.e.
     # an instance of pysegcnn.core.split.SceneSubset or
@@ -41,24 +40,16 @@ def predict_samples(ds, model, optimizer, state_file, cm=False,
         raise TypeError('ds should be an instance of {} or of {}.'
                         .format(repr(RandomSubset), repr(SceneSubset)))
 
-    # convert state file to pathlib.Path object
-    state_file = pathlib.Path(state_file)
-
     # the device to compute on, use gpu if available
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-    # load the pretrained model state
-    if not state_file.exists():
-        raise FileNotFoundError('{} does not exist.'.format(state_file))
-    _ = model.load(state_file.name, optimizer, state_file.parent)
-
     # set the model to evaluation mode
     LOGGER.info('Setting model to evaluation mode ...')
     model.eval()
     model.to(device)
 
     # base filename for each sample
-    fname = state_file.name.split('.pt')[0]
+    fname = model.state_file.name.split('.pt')[0]
 
     # initialize confusion matrix
     cmm = np.zeros(shape=(model.nclasses, model.nclasses))
@@ -107,8 +98,7 @@ def predict_samples(ds, model, optimizer, state_file, cm=False,
     return output, cmm
 
 
-def predict_scenes(ds, model, optimizer, state_file,
-                   scene_id=None, cm=False, plot_scenes=False, **kwargs):
+def predict_scenes(ds, model, scene_id=None, cm=False, plot=False, **kwargs):
 
     # check whether the dataset is a valid subset, i.e. an instance of
     # pysegcnn.core.split.SceneSubset
@@ -116,24 +106,16 @@ def predict_scenes(ds, model, optimizer, state_file,
         raise TypeError('ds should be an instance of {}.'
                         .format(repr(SceneSubset)))
 
-    # convert state file to pathlib.Path object
-    state_file = pathlib.Path(state_file)
-
     # the device to compute on, use gpu if available
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-    # load the pretrained model state
-    if not state_file.exists():
-        raise FileNotFoundError('{} does not exist.'.format(state_file))
-    _ = model.load(state_file.name, optimizer, state_file.parent)
-
     # set the model to evaluation mode
     LOGGER.info('Setting model to evaluation mode ...')
     model.eval()
     model.to(device)
 
     # base filename for each scene
-    fname = state_file.name.split('.pt')[0]
+    fname = model.state_file.name.split('.pt')[0]
 
     # initialize confusion matrix
     cmm = np.zeros(shape=(model.nclasses, model.nclasses))
@@ -196,7 +178,7 @@ def predict_scenes(ds, model, optimizer, state_file,
         scenes[sid] = {'input': inputs, 'labels': labels, 'prediction': prdtcn}
 
         # plot current scene
-        if plot_scenes:
+        if plot:
             fig, ax = plot_sample(inputs.clip(0, 1),
                                   labels,
                                   ds.dataset.use_bands,
diff --git a/pysegcnn/main/eval.py b/pysegcnn/main/eval.py
index 2479340..82fe4e3 100644
--- a/pysegcnn/main/eval.py
+++ b/pysegcnn/main/eval.py
@@ -9,7 +9,7 @@ import os
 
 # locals
 from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig,
-                                   TrainConfig, EvalConfig)
+                                   StateConfig, EvalConfig)
 from pysegcnn.core.predict import predict_samples, predict_scenes
 from pysegcnn.main.config import (dataset_config, split_config, model_config,
                                   train_config, eval_config, HERE)
@@ -22,24 +22,21 @@ if __name__ == '__main__':
     dc = DatasetConfig(**dataset_config)
     sc = SplitConfig(**split_config)
     mc = ModelConfig(**model_config)
-    tc = TrainConfig(**train_config)
     ec = EvalConfig(**eval_config)
 
     # (ii) instanciate the dataset
     ds = dc.init_dataset()
-    ds
 
     # (iii) instanciate the training, validation and test datasets
     train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
 
-    # (iv) instanciate the model state files
-    state_file, loss_state = mc.init_state(ds, sc, tc)
+    # (iv) instanciate the model state
+    state = StateConfig(ds, sc, mc)
+    state_file, loss_state = state.init_state()
 
-    # (v) instanciate the model
-    model = mc.init_model(ds)
-
-    # (vi) instanciate the optimizer
-    optimizer = tc.init_optimizer(model)
+    # (vii) load pretrained model weights
+    model, _ = mc.load_pretrained(state_file)
+    model.state_file = state_file
 
     # plot loss and accuracy
     plot_loss(loss_state, outpath=os.path.join(HERE, '_graphics/'))
@@ -54,37 +51,23 @@ if __name__ == '__main__':
     # keyword arguments for plotting
     kwargs = {'bands': ec.plot_bands,
               'outpath': os.path.join(HERE, '_scenes/'),
-              'stretch': True,
-              'alpha': 5}
+              'alpha': ec.alpha,
+              'figsize': ec.figsize}
 
     # whether to predict each sample or each scene individually
     if ec.predict_scene:
         # reconstruct and predict the scenes in the validation/test set
-        scenes, cm = predict_scenes(ds,
-                                    model,
-                                    optimizer,
-                                    state_file,
-                                    scene_id=None,
-                                    cm=ec.cm,
-                                    plot_scenes=ec.plot_scenes,
-                                    **kwargs
-                                    )
+        scenes, cm = predict_scenes(ds, model, scene_id=None, cm=ec.cm,
+                                    plot=ec.plot_scenes, **kwargs)
 
     else:
         # predict the samples in the validation/test set
-        samples, cm = predict_samples(ds,
-                                      model,
-                                      optimizer,
-                                      state_file,
-                                      cm=ec.cm,
-                                      plot_scenes=ec.plot_scenes,
-                                      **kwargs)
+        samples, cm = predict_samples(ds, model, cm=ec.cm,
+                                      plot=ec.plot_samples, **kwargs)
 
     # whether to plot the confusion matrix
     if ec.cm:
-        plot_confusion_matrix(cm,
-                              ds.dataset.labels,
-                              normalize=True,
+        plot_confusion_matrix(cm, ds.dataset.labels,
                               state=state_file.name.replace('.pt', '.png'),
                               outpath=os.path.join(HERE, '_graphics/')
                               )
-- 
GitLab