From f8a96422f7b126d86171d0b0639804443e4a8d85 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 30 Jul 2020 16:49:21 +0200
Subject: [PATCH] Included type handling and improved console prints

---
 pysegcnn/core/predict.py | 62 +++++++++++++++++++++++-----------------
 1 file changed, 36 insertions(+), 26 deletions(-)

diff --git a/pysegcnn/core/predict.py b/pysegcnn/core/predict.py
index fe9a33e..b1e69cd 100644
--- a/pysegcnn/core/predict.py
+++ b/pysegcnn/core/predict.py
@@ -9,8 +9,9 @@ from torch.utils.data.dataset import Subset
 import torch.nn.functional as F
 
 # locals
-from pysegcnn.core.utils import reconstruct_scene
+from pysegcnn.core.utils import reconstruct_scene, accuracy_function
 from pysegcnn.core.graphics import plot_sample
+from pysegcnn.core.split import RandomSubset, SceneSubset
 
 
 def get_scene_tiles(ds, scene_id):
@@ -28,9 +29,18 @@ def get_scene_tiles(ds, scene_id):
 def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
                     plot=False, **kwargs):
 
-    # check whether the dataset is a subset
-    if not isinstance(ds, Subset):
-        raise TypeError('ds should be of type {}'.format(Subset))
+    # check whether the dataset is a valid subset, i.e.
+    # an instance of pysegcnn.core.split.SceneSubset or
+    # an instance of pysegcnn.core.split.RandomSubset
+    _name = type(ds).__name__
+    if _name is not RandomSubset.__name__ or _name is not SceneSubset.__name__:
+        raise TypeError('ds should be an instance of {} or of {}'
+                        .format('.'.join([RandomSubset.__module__,
+                                         RandomSubset.__name__]),
+                                '.'.join([SceneSubset.__module__,
+                                         SceneSubset.__name__])
+                                )
+                        )
 
     # the device to compute on, use gpu if available
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -58,6 +68,7 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
     # iterate over the samples and plot inputs, ground truth and
     # model predictions
     output = {}
+    print('Predicting samples of the {} dataset ...'.format(ds.name))
     for batch, (inputs, labels) in enumerate(dataloader):
 
         # send inputs and labels to device
@@ -71,6 +82,10 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
         # store output for current batch
         output[batch] = {'input': inputs, 'labels': labels, 'prediction': prd}
 
+        print('Sample: {:d}/{:d}, Accuracy: {:.2f}'
+              .format(batch + 1, len(dataloader),
+                      accuracy_function(prd, labels)))
+
         # update confusion matrix
         if cm:
             for ytrue, ypred in zip(labels.view(-1), prd.view(-1)):
@@ -95,9 +110,11 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
 def predict_scenes(ds, model, optimizer, state_path, state_file,
                    scene_id=None, cm=False, plot_scenes=False, **kwargs):
 
-    # check if the dataset is an instance of torch.data.dataset.Subset
-    if not isinstance(ds, Subset):
-        raise TypeError('ds should be of type {}'.format(Subset))
+    # check whether the dataset is a valid subset, i.e. an instance of
+    # pysegcnn.core.split.SceneSubset
+    if not type(ds).__name__ is SceneSubset.__name__:
+        raise TypeError('ds should be an instance of {}'.format(
+            '.'.join([SceneSubset.__module__, SceneSubset.__name__])))
 
     # the device to compute on, use gpu if available
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -121,16 +138,8 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
 
     # check whether a scene id is provided
     if scene_id is None:
-
-        # get the names of the scenes
-        try:
-            scene_ids = ds.ids
-        except AttributeError:
-            raise TypeError('predict_scenes does only work for datasets split '
-                            'by "scene" or by "date".')
-
+        scene_ids = ds.ids
     else:
-
         # the name of the selected scene
         scene_ids = [scene_id]
 
@@ -138,12 +147,12 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
     scene_size = (ds.dataset.height, ds.dataset.width)
 
     # iterate over the scenes
-    print('Predicting scenes of the subset ...')
-    scene = {}
-    for sid in scene_ids:
+    print('Predicting scenes of the {} dataset ...'.format(ds.name))
+    scenes = {}
+    for i, sid in enumerate(scene_ids):
 
         # filename for the current scene
-        sname = fname + '_' + sid + '.pt'
+        sname = fname + '_{}_{}.pt'.format(ds.name, sid)
 
         # get the indices of the tiles of the scene
         indices = get_scene_tiles(ds, sid)
@@ -157,10 +166,7 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
                               shuffle=False, drop_last=False)
 
         # predict the current scene
-        for i, (inp, lab) in enumerate(scene_dl):
-            print('Predicting scene ({}/{}), id: {}'.format(i + 1,
-                                                            len(scene_ids),
-                                                            sid))
+        for b, (inp, lab) in enumerate(scene_dl):
 
             # send inputs and labels to device
             inp = inp.to(device)
@@ -180,8 +186,12 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
         labels = reconstruct_scene(lab, scene_size, nbands=1)
         prdtcn = reconstruct_scene(prd, scene_size, nbands=1)
 
+        # print progress
+        print('Scene {:d}/{:d}, Id: {}, Accuracy: {:.2f}'.format(
+            i + 1, len(scene_ids), sid, accuracy_function(prdtcn, labels)))
+
         # save outputs to dictionary
-        scene[sid] = {'input': inputs, 'labels': labels, 'prediction': prdtcn}
+        scenes[sid] = {'input': inputs, 'labels': labels, 'prediction': prdtcn}
 
         # plot current scene
         if plot_scenes:
@@ -193,4 +203,4 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
                                   state=sname,
                                   **kwargs)
 
-    return scene, cmm
+    return scenes, cmm
-- 
GitLab