Skip to content
Snippets Groups Projects
Commit f8a96422 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Included type handling and improved console prints

parent 5a99371d
No related branches found
No related tags found
No related merge requests found
......@@ -9,8 +9,9 @@ from torch.utils.data.dataset import Subset
import torch.nn.functional as F
# locals
from pysegcnn.core.utils import reconstruct_scene
from pysegcnn.core.utils import reconstruct_scene, accuracy_function
from pysegcnn.core.graphics import plot_sample
from pysegcnn.core.split import RandomSubset, SceneSubset
def get_scene_tiles(ds, scene_id):
......@@ -28,9 +29,18 @@ def get_scene_tiles(ds, scene_id):
def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
plot=False, **kwargs):
# check whether the dataset is a subset
if not isinstance(ds, Subset):
raise TypeError('ds should be of type {}'.format(Subset))
# check whether the dataset is a valid subset, i.e.
# an instance of pysegcnn.core.split.SceneSubset or
# an instance of pysegcnn.core.split.RandomSubset
_name = type(ds).__name__
if _name is not RandomSubset.__name__ or _name is not SceneSubset.__name__:
raise TypeError('ds should be an instance of {} or of {}'
.format('.'.join([RandomSubset.__module__,
RandomSubset.__name__]),
'.'.join([SceneSubset.__module__,
SceneSubset.__name__])
)
)
# the device to compute on, use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
......@@ -58,6 +68,7 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
# iterate over the samples and plot inputs, ground truth and
# model predictions
output = {}
print('Predicting samples of the {} dataset ...'.format(ds.name))
for batch, (inputs, labels) in enumerate(dataloader):
# send inputs and labels to device
......@@ -71,6 +82,10 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
# store output for current batch
output[batch] = {'input': inputs, 'labels': labels, 'prediction': prd}
print('Sample: {:d}/{:d}, Accuracy: {:.2f}'
.format(batch + 1, len(dataloader),
accuracy_function(prd, labels)))
# update confusion matrix
if cm:
for ytrue, ypred in zip(labels.view(-1), prd.view(-1)):
......@@ -95,9 +110,11 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
def predict_scenes(ds, model, optimizer, state_path, state_file,
scene_id=None, cm=False, plot_scenes=False, **kwargs):
# check if the dataset is an instance of torch.data.dataset.Subset
if not isinstance(ds, Subset):
raise TypeError('ds should be of type {}'.format(Subset))
# check whether the dataset is a valid subset, i.e. an instance of
# pysegcnn.core.split.SceneSubset
if not type(ds).__name__ is SceneSubset.__name__:
raise TypeError('ds should be an instance of {}'.format(
'.'.join([SceneSubset.__module__, SceneSubset.__name__])))
# the device to compute on, use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
......@@ -121,16 +138,8 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
# check whether a scene id is provided
if scene_id is None:
# get the names of the scenes
try:
scene_ids = ds.ids
except AttributeError:
raise TypeError('predict_scenes does only work for datasets split '
'by "scene" or by "date".')
scene_ids = ds.ids
else:
# the name of the selected scene
scene_ids = [scene_id]
......@@ -138,12 +147,12 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
scene_size = (ds.dataset.height, ds.dataset.width)
# iterate over the scenes
print('Predicting scenes of the subset ...')
scene = {}
for sid in scene_ids:
print('Predicting scenes of the {} dataset ...'.format(ds.name))
scenes = {}
for i, sid in enumerate(scene_ids):
# filename for the current scene
sname = fname + '_' + sid + '.pt'
sname = fname + '_{}_{}.pt'.format(ds.name, sid)
# get the indices of the tiles of the scene
indices = get_scene_tiles(ds, sid)
......@@ -157,10 +166,7 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
shuffle=False, drop_last=False)
# predict the current scene
for i, (inp, lab) in enumerate(scene_dl):
print('Predicting scene ({}/{}), id: {}'.format(i + 1,
len(scene_ids),
sid))
for b, (inp, lab) in enumerate(scene_dl):
# send inputs and labels to device
inp = inp.to(device)
......@@ -180,8 +186,12 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
labels = reconstruct_scene(lab, scene_size, nbands=1)
prdtcn = reconstruct_scene(prd, scene_size, nbands=1)
# print progress
print('Scene {:d}/{:d}, Id: {}, Accuracy: {:.2f}'.format(
i + 1, len(scene_ids), sid, accuracy_function(prdtcn, labels)))
# save outputs to dictionary
scene[sid] = {'input': inputs, 'labels': labels, 'prediction': prdtcn}
scenes[sid] = {'input': inputs, 'labels': labels, 'prediction': prdtcn}
# plot current scene
if plot_scenes:
......@@ -193,4 +203,4 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
state=sname,
**kwargs)
return scene, cmm
return scenes, cmm
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment