Skip to content
Snippets Groups Projects
Commit f8a96422 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Included type handling and improved console prints

parent 5a99371d
No related branches found
No related tags found
No related merge requests found
...@@ -9,8 +9,9 @@ from torch.utils.data.dataset import Subset ...@@ -9,8 +9,9 @@ from torch.utils.data.dataset import Subset
import torch.nn.functional as F import torch.nn.functional as F
# locals # locals
from pysegcnn.core.utils import reconstruct_scene from pysegcnn.core.utils import reconstruct_scene, accuracy_function
from pysegcnn.core.graphics import plot_sample from pysegcnn.core.graphics import plot_sample
from pysegcnn.core.split import RandomSubset, SceneSubset
def get_scene_tiles(ds, scene_id): def get_scene_tiles(ds, scene_id):
...@@ -28,9 +29,18 @@ def get_scene_tiles(ds, scene_id): ...@@ -28,9 +29,18 @@ def get_scene_tiles(ds, scene_id):
def predict_samples(ds, model, optimizer, state_path, state_file, cm=False, def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
plot=False, **kwargs): plot=False, **kwargs):
# check whether the dataset is a subset # check whether the dataset is a valid subset, i.e.
if not isinstance(ds, Subset): # an instance of pysegcnn.core.split.SceneSubset or
raise TypeError('ds should be of type {}'.format(Subset)) # an instance of pysegcnn.core.split.RandomSubset
_name = type(ds).__name__
if _name is not RandomSubset.__name__ or _name is not SceneSubset.__name__:
raise TypeError('ds should be an instance of {} or of {}'
.format('.'.join([RandomSubset.__module__,
RandomSubset.__name__]),
'.'.join([SceneSubset.__module__,
SceneSubset.__name__])
)
)
# the device to compute on, use gpu if available # the device to compute on, use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
...@@ -58,6 +68,7 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False, ...@@ -58,6 +68,7 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
# iterate over the samples and plot inputs, ground truth and # iterate over the samples and plot inputs, ground truth and
# model predictions # model predictions
output = {} output = {}
print('Predicting samples of the {} dataset ...'.format(ds.name))
for batch, (inputs, labels) in enumerate(dataloader): for batch, (inputs, labels) in enumerate(dataloader):
# send inputs and labels to device # send inputs and labels to device
...@@ -71,6 +82,10 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False, ...@@ -71,6 +82,10 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
# store output for current batch # store output for current batch
output[batch] = {'input': inputs, 'labels': labels, 'prediction': prd} output[batch] = {'input': inputs, 'labels': labels, 'prediction': prd}
print('Sample: {:d}/{:d}, Accuracy: {:.2f}'
.format(batch + 1, len(dataloader),
accuracy_function(prd, labels)))
# update confusion matrix # update confusion matrix
if cm: if cm:
for ytrue, ypred in zip(labels.view(-1), prd.view(-1)): for ytrue, ypred in zip(labels.view(-1), prd.view(-1)):
...@@ -95,9 +110,11 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False, ...@@ -95,9 +110,11 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
def predict_scenes(ds, model, optimizer, state_path, state_file, def predict_scenes(ds, model, optimizer, state_path, state_file,
scene_id=None, cm=False, plot_scenes=False, **kwargs): scene_id=None, cm=False, plot_scenes=False, **kwargs):
# check if the dataset is an instance of torch.data.dataset.Subset # check whether the dataset is a valid subset, i.e. an instance of
if not isinstance(ds, Subset): # pysegcnn.core.split.SceneSubset
raise TypeError('ds should be of type {}'.format(Subset)) if not type(ds).__name__ is SceneSubset.__name__:
raise TypeError('ds should be an instance of {}'.format(
'.'.join([SceneSubset.__module__, SceneSubset.__name__])))
# the device to compute on, use gpu if available # the device to compute on, use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
...@@ -121,16 +138,8 @@ def predict_scenes(ds, model, optimizer, state_path, state_file, ...@@ -121,16 +138,8 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
# check whether a scene id is provided # check whether a scene id is provided
if scene_id is None: if scene_id is None:
scene_ids = ds.ids
# get the names of the scenes
try:
scene_ids = ds.ids
except AttributeError:
raise TypeError('predict_scenes does only work for datasets split '
'by "scene" or by "date".')
else: else:
# the name of the selected scene # the name of the selected scene
scene_ids = [scene_id] scene_ids = [scene_id]
...@@ -138,12 +147,12 @@ def predict_scenes(ds, model, optimizer, state_path, state_file, ...@@ -138,12 +147,12 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
scene_size = (ds.dataset.height, ds.dataset.width) scene_size = (ds.dataset.height, ds.dataset.width)
# iterate over the scenes # iterate over the scenes
print('Predicting scenes of the subset ...') print('Predicting scenes of the {} dataset ...'.format(ds.name))
scene = {} scenes = {}
for sid in scene_ids: for i, sid in enumerate(scene_ids):
# filename for the current scene # filename for the current scene
sname = fname + '_' + sid + '.pt' sname = fname + '_{}_{}.pt'.format(ds.name, sid)
# get the indices of the tiles of the scene # get the indices of the tiles of the scene
indices = get_scene_tiles(ds, sid) indices = get_scene_tiles(ds, sid)
...@@ -157,10 +166,7 @@ def predict_scenes(ds, model, optimizer, state_path, state_file, ...@@ -157,10 +166,7 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
shuffle=False, drop_last=False) shuffle=False, drop_last=False)
# predict the current scene # predict the current scene
for i, (inp, lab) in enumerate(scene_dl): for b, (inp, lab) in enumerate(scene_dl):
print('Predicting scene ({}/{}), id: {}'.format(i + 1,
len(scene_ids),
sid))
# send inputs and labels to device # send inputs and labels to device
inp = inp.to(device) inp = inp.to(device)
...@@ -180,8 +186,12 @@ def predict_scenes(ds, model, optimizer, state_path, state_file, ...@@ -180,8 +186,12 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
labels = reconstruct_scene(lab, scene_size, nbands=1) labels = reconstruct_scene(lab, scene_size, nbands=1)
prdtcn = reconstruct_scene(prd, scene_size, nbands=1) prdtcn = reconstruct_scene(prd, scene_size, nbands=1)
# print progress
print('Scene {:d}/{:d}, Id: {}, Accuracy: {:.2f}'.format(
i + 1, len(scene_ids), sid, accuracy_function(prdtcn, labels)))
# save outputs to dictionary # save outputs to dictionary
scene[sid] = {'input': inputs, 'labels': labels, 'prediction': prdtcn} scenes[sid] = {'input': inputs, 'labels': labels, 'prediction': prdtcn}
# plot current scene # plot current scene
if plot_scenes: if plot_scenes:
...@@ -193,4 +203,4 @@ def predict_scenes(ds, model, optimizer, state_path, state_file, ...@@ -193,4 +203,4 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
state=sname, state=sname,
**kwargs) **kwargs)
return scene, cmm return scenes, cmm
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment