"""Functions for model inference.
License
-------
Copyright (c) 2020 Daniel Frisinghelli
This source code is licensed under the GNU General Public License v3.
See the LICENSE file in the repository's root directory.
"""
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# builtins
import logging
# externals
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
import torch.nn.functional as F
# locals
from pysegcnn.core.utils import reconstruct_scene, accuracy_function
from pysegcnn.core.graphics import plot_sample
from pysegcnn.core.split import RandomSubset, SceneSubset
# module level logger
LOGGER = logging.getLogger(__name__)
def _get_scene_tiles(ds, scene_id):
"""Return the tiles of the scene with id = ``scene_id``.
Parameters
----------
ds : `pysegcnn.core.split.RandomSubset` or
`pysegcnn.core.split.SceneSubset`
An instance of `~pysegcnn.core.split.RandomSubset` or
`~pysegcnn.core.split.SceneSubset`.
scene_id : `str`
A valid scene identifier.
Raises
------
ValueError
Raised if ``scene_id`` is not a valid scene identifier for the dataset
``ds``.
Returns
-------
indices : `list` [`int`]
List of indices of the tiles from scene with id ``scene_id`` in ``ds``.
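
    Examples
    --------
    A sketch, assuming ``ds`` is a `~pysegcnn.core.split.SceneSubset` and
    ``scene_id`` a valid identifier of one of its scenes:

    >>> # positions within ``ds`` of all tiles belonging to the scene
    >>> indices = _get_scene_tiles(ds, scene_id)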
"""
# check if the scene id is valid
scene_meta = ds.dataset.parse_scene_id(scene_id)
if scene_meta is None:
raise ValueError('{} is not a valid scene identifier'.format(scene_id))
# iterate over the scenes of the dataset
indices = []
for i, scene in enumerate(ds.scenes):
# if the scene id matches a given id, save the index of the scene
if scene['id'] == scene_id:
indices.append(i)
return indices
def predict_samples(ds, model, cm=False, plot=False, **kwargs):
"""Classify each sample in ``ds`` with model ``model``.
Parameters
----------
ds : `pysegcnn.core.split.RandomSubset` or
`pysegcnn.core.split.SceneSubset`
An instance of `~pysegcnn.core.split.RandomSubset` or
`~pysegcnn.core.split.SceneSubset`.
model : `pysegcnn.core.models.Network`
An instance of `~pysegcnn.core.models.Network`.
cm : `bool`, optional
Whether to compute the confusion matrix. The default is False.
plot : `bool`, optional
Whether to plot a false color composite, ground truth and model
prediction for each sample. The default is False.
**kwargs
Additional keyword arguments passed to
`pysegcnn.core.graphics.plot_sample`.
Raises
------
TypeError
Raised if ``ds`` is not an instance of
`~pysegcnn.core.split.RandomSubset` or
`~pysegcnn.core.split.SceneSubset`.
Returns
-------
    output : `dict`
        Output dictionary with one entry per sample. Each entry is a
        dictionary with keys:
            ``'input'``
                Model input data.
            ``'labels'``
                The ground truth.
            ``'prediction'``
                Model prediction.
    conf_mat : `numpy.ndarray`
        The confusion matrix. Note that ``conf_mat`` is only filled if
        ``cm`` is True; otherwise an array of zeros is returned.
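
    Examples
    --------
    A minimal usage sketch, assuming ``ds`` is a
    `~pysegcnn.core.split.SceneSubset` of a dataset and ``model`` a trained
    `~pysegcnn.core.models.Network` restored from its state file:

    >>> output, conf_mat = predict_samples(ds, model, cm=True, plot=False)
    >>> output[0]['prediction']  # predicted class labels of the first sample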
"""
# check whether the dataset is a valid subset, i.e.
# an instance of pysegcnn.core.split.SceneSubset or
# an instance of pysegcnn.core.split.RandomSubset
    if not isinstance(ds, (RandomSubset, SceneSubset)):
        raise TypeError('ds should be an instance of {} or of {}.'
                        .format(repr(RandomSubset), repr(SceneSubset)))
# the device to compute on, use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# set the model to evaluation mode
LOGGER.info('Setting model to evaluation mode ...')
model.eval()
model.to(device)
# base filename for each sample
fname = model.state_file.name.split('.pt')[0]
# initialize confusion matrix
conf_mat = np.zeros(shape=(model.nclasses, model.nclasses))
# create the dataloader
dataloader = DataLoader(ds, batch_size=1, shuffle=False, drop_last=False)
    # iterate over the samples: compute the model predictions and, if
    # requested, plot the inputs, ground truth and predictions
output = {}
LOGGER.info('Predicting samples of the {} dataset ...'.format(ds.name))
for batch, (inputs, labels) in enumerate(dataloader):
# send inputs and labels to device
inputs = inputs.to(device)
labels = labels.to(device)
# compute model predictions
with torch.no_grad():
prd = F.softmax(model(inputs), dim=1).argmax(dim=1).squeeze()
# store output for current batch
output[batch] = {'input': inputs, 'labels': labels, 'prediction': prd}
LOGGER.info('Sample: {:d}/{:d}, Accuracy: {:.2f}'.format(
batch + 1, len(dataloader), accuracy_function(prd, labels)))
# update confusion matrix
if cm:
for ytrue, ypred in zip(labels.view(-1), prd.view(-1)):
conf_mat[ytrue.long(), ypred.long()] += 1
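            # NOTE: a vectorized sketch of the same update (an assumption:
            # the flattened labels and predictions are integer class indices
            # that fit in host memory):
            #   np.add.at(conf_mat,
            #             (labels.view(-1).cpu().numpy().astype(np.int64),
            #              prd.view(-1).cpu().numpy().astype(np.int64)), 1)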
# save plot of current batch to disk
if plot:
# plot inputs, ground truth and model predictions
sname = fname + '_{}_{}.pt'.format(ds.name, batch)
            fig, ax = plot_sample(inputs.cpu().numpy().clip(0, 1),
ds.dataset.use_bands,
ds.dataset.labels,
y=labels,
y_pred=prd,
state=sname,
**kwargs)
return output, conf_mat
def predict_scenes(ds, model, scene_id=None, cm=False, plot=False, **kwargs):
"""Classify each scene in ``ds`` with model ``model``.
Parameters
----------
ds : `pysegcnn.core.split.SceneSubset`
An instance of `~pysegcnn.core.split.SceneSubset`.
model : `pysegcnn.core.models.Network`
An instance of `~pysegcnn.core.models.Network`.
    scene_id : `str` or `None`, optional
        A valid scene identifier. If `None`, all scenes in ``ds`` are
        classified. The default is `None`.
cm : `bool`, optional
Whether to compute the confusion matrix. The default is False.
plot : `bool`, optional
Whether to plot a false color composite, ground truth and model
prediction for each scene. The default is False.
**kwargs
Additional keyword arguments passed to
`pysegcnn.core.graphics.plot_sample`.
Raises
------
TypeError
Raised if ``ds`` is not an instance of
`~pysegcnn.core.split.SceneSubset`.
Returns
-------
    output : `dict`
        Output dictionary with one entry per scene id. Each entry is a
        dictionary with keys:
            ``'input'``
                Model input data of the reconstructed scene.
            ``'labels'``
                The ground truth of the reconstructed scene.
            ``'prediction'``
                Model prediction for the reconstructed scene.
    conf_mat : `numpy.ndarray`
        The confusion matrix. Note that ``conf_mat`` is only filled if
        ``cm`` is True; otherwise an array of zeros is returned.
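
    Examples
    --------
    A minimal usage sketch, assuming ``ds`` is a
    `~pysegcnn.core.split.SceneSubset` and ``model`` a trained
    `~pysegcnn.core.models.Network`:

    >>> # classify all scenes in ds and accumulate the confusion matrix
    >>> output, conf_mat = predict_scenes(ds, model, scene_id=None, cm=True)
    >>> sid = list(output.keys())[0]  # a scene identifier
    >>> output[sid]['prediction']  # reconstructed model prediction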
"""
# check whether the dataset is a valid subset, i.e. an instance of
# pysegcnn.core.split.SceneSubset
if not isinstance(ds, SceneSubset):
raise TypeError('ds should be an instance of {}.'
.format(repr(SceneSubset)))
# the device to compute on, use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# set the model to evaluation mode
LOGGER.info('Setting model to evaluation mode ...')
model.eval()
model.to(device)
# base filename for each scene
fname = model.state_file.name.split('.pt')[0]
# initialize confusion matrix
conf_mat = np.zeros(shape=(model.nclasses, model.nclasses))
# check whether a scene id is provided
if scene_id is None:
scene_ids = ds.ids
else:
# the name of the selected scene
scene_ids = [scene_id]
# iterate over the scenes
LOGGER.info('Predicting scenes of the {} dataset ...'.format(ds.name))
output = {}
for i, sid in enumerate(scene_ids):
# filename for the current scene
sname = fname + '_{}_{}.pt'.format(ds.name, sid)
# get the indices of the tiles of the scene
indices = _get_scene_tiles(ds, sid)
indices.sort()
# create a subset of the dataset
scene_ds = Subset(ds, indices)
# create the dataloader
scene_dl = DataLoader(scene_ds, batch_size=len(scene_ds),
shuffle=False, drop_last=False)
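        # NOTE: the batch size equals the number of tiles in the scene, so
        # each scene is classified in a single forward pass and can be
        # reconstructed from its tiles below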
# predict the current scene
for b, (inp, lab) in enumerate(scene_dl):
# send inputs and labels to device
inp = inp.to(device)
lab = lab.to(device)
# apply forward pass: model prediction
with torch.no_grad():
prd = F.softmax(model(inp), dim=1).argmax(dim=1).squeeze()
# update confusion matrix
if cm:
for ytrue, ypred in zip(lab.view(-1), prd.view(-1)):
conf_mat[ytrue.long(), ypred.long()] += 1
# reconstruct the entire scene
inputs = reconstruct_scene(inp)
labels = reconstruct_scene(lab)
prdtcn = reconstruct_scene(prd)
# print progress
LOGGER.info('Scene {:d}/{:d}, Id: {}, Accuracy: {:.2f}'.format(
i + 1, len(scene_ids), sid, accuracy_function(prdtcn, labels)))
# save outputs to dictionary
output[sid] = {'input': inputs, 'labels': labels, 'prediction': prdtcn}
# plot current scene
if plot:
fig, ax = plot_sample(inputs.clip(0, 1),
ds.dataset.use_bands,
ds.dataset.labels,
y=labels,
y_pred=prdtcn,
state=sname,
**kwargs)
return output, conf_mat