From 7cda65423e7a59fe19530736265c6f12ec235892 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 30 Jul 2020 10:40:17 +0200
Subject: [PATCH] Removed option to only predict some samples; not required

---
 pysegcnn/core/predict.py | 57 ++++++++++++++--------------------------
 pysegcnn/main/config.py  |  6 -----
 pysegcnn/main/eval.py    |  3 ---
 3 files changed, 19 insertions(+), 47 deletions(-)

diff --git a/pysegcnn/core/predict.py b/pysegcnn/core/predict.py
index 874d978..fe9a33e 100644
--- a/pysegcnn/core/predict.py
+++ b/pysegcnn/core/predict.py
@@ -25,9 +25,12 @@ def get_scene_tiles(ds, scene_id):
     return indices
 
 
-def predict_samples(ds, model, optimizer, state_path, state_file, nsamples,
-                    seed, batch_size=None, cm=False, plot_samples=False,
-                    **kwargs):
+def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
+                    plot=False, **kwargs):
+
+    # check whether the dataset is a subset
+    if not isinstance(ds, Subset):
+        raise TypeError('ds should be of type {}'.format(Subset))
 
     # the device to compute on, use gpu if available
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -49,26 +52,13 @@ def predict_samples(ds, model, optimizer, state_path, state_file, nsamples,
     # initialize confusion matrix
     cmm = np.zeros(shape=(model.nclasses, model.nclasses))
 
-    # set random seed for reproducibility
-    np.random.seed(seed)
-
-    # draw a number of samples from the dataset
-    samples = np.arange(0, len(ds))
-    if nsamples > 0:
-        batch_size = nsamples
-        samples = np.random.choice(samples, size=min(nsamples, len(ds)))
-
-    # create a subset of the dataset
-    smpl_subset = Subset(ds, samples.tolist())
-    if batch_size is None:
-        raise ValueError('If you specify "nsamples"=-1, you have to provide '
-                         'a batch size, e.g. trainer.batch_size.')
-    smpl_loader = DataLoader(smpl_subset, batch_size=batch_size, shuffle=False)
+    # create the dataloader
+    dataloader = DataLoader(ds, batch_size=1, shuffle=False, drop_last=False)
 
     # iterate over the samples and plot inputs, ground truth and
     # model predictions
     output = {}
-    for batch, (inputs, labels) in enumerate(smpl_loader):
+    for batch, (inputs, labels) in enumerate(dataloader):
 
         # send inputs and labels to device
         inputs = inputs.to(device)
@@ -87,22 +77,14 @@ def predict_samples(ds, model, optimizer, state_path, state_file, nsamples,
                 cmm[ytrue.long(), ypred.long()] += 1
 
         # save plot of current batch to disk
-        if plot_samples:
-
-            # check whether the dataset is a subset
-            if isinstance(ds, Subset):
-                use_bands = ds.dataset.use_bands
-                ds_labels = ds.dataset.labels
-            else:
-                use_bands = ds.use_bands
-                ds_labels = ds.labels
+        if plot:
 
             # plot inputs, ground truth and model predictions
-            sname = fname + '_sample_{}.pt'.format(batch)
+            sname = fname + '_{}_{}.pt'.format(ds.name, batch)
             fig, ax = plot_sample(inputs.numpy().clip(0, 1),
                                   labels,
-                                  use_bands,
-                                  ds_labels,
+                                  ds.dataset.use_bands,
+                                  ds.dataset.labels,
                                   y_pred=prd,
                                   state=sname,
                                   **kwargs)
@@ -113,6 +95,10 @@ def predict_samples(ds, model, optimizer, state_path, state_file, nsamples,
 def predict_scenes(ds, model, optimizer, state_path, state_file,
                    scene_id=None, cm=False, plot_scenes=False, **kwargs):
 
+    # check if the dataset is an instance of torch.utils.data.Subset
+    if not isinstance(ds, Subset):
+        raise TypeError('ds should be of type {}'.format(Subset))
+
     # the device to compute on, use gpu if available
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
@@ -136,12 +122,6 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
     # check whether a scene id is provided
     if scene_id is None:
 
-        # check if the dataset is an instance of torch.data.dataset.Subset
-        if not isinstance(ds, Subset):
-            raise TypeError('ds should be of type {}'.format(Subset))
-
-        print('Predicting scenes of the subset ...')
-
         # get the names of the scenes
         try:
             scene_ids = ds.ids
@@ -158,6 +138,7 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
     scene_size = (ds.dataset.height, ds.dataset.width)
 
     # iterate over the scenes
+    print('Predicting scenes of the subset ...')
     scene = {}
     for sid in scene_ids:
 
@@ -173,7 +154,7 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
 
         # create the dataloader
         scene_dl = DataLoader(scene_ds, batch_size=len(scene_ds),
-                              shuffle=False)
+                              shuffle=False, drop_last=False)
 
         # predict the current scene
         for i, (inp, lab) in enumerate(scene_dl):
diff --git a/pysegcnn/main/config.py b/pysegcnn/main/config.py
index d453fa8..61bdaaa 100644
--- a/pysegcnn/main/config.py
+++ b/pysegcnn/main/config.py
@@ -261,12 +261,6 @@ config = {
     #       split_mode="date"
     'predict_scene': True,
 
-    # number of samples to validate model performance on
-    # if nsamples': -1, the model is evaluated on all samples of the validation
-    # set or test set
-    # only takes effect if predict_scene=False
-    'nsamples': -1,
-
     # whether to save plots of (input, ground truth, prediction) of the
     # samples from the validation/test dataset to disk
     # output path is: pysegcnn/main/_samples/
diff --git a/pysegcnn/main/eval.py b/pysegcnn/main/eval.py
index f082c35..85a5351 100644
--- a/pysegcnn/main/eval.py
+++ b/pysegcnn/main/eval.py
@@ -49,9 +49,6 @@ if __name__ == '__main__':
                                       trainer.optimizer,
                                       trainer.state_path,
                                       trainer.state_file,
-                                      trainer.nsamples,
-                                      trainer.seed,
-                                      trainer.batch_size,
                                       trainer.cm,
                                       trainer.plot_samples,
                                       bands=trainer.plot_bands,
-- 
GitLab