Skip to content
Snippets Groups Projects
Commit 5128033c authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Adapted predict functions to changes in trainer.py

parent ef5cae60
No related branches found
No related tags found
No related merge requests found
# builtins
import os
import pathlib
# externals
import numpy as np
......@@ -26,30 +27,27 @@ def get_scene_tiles(ds, scene_id):
return indices
def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
def predict_samples(ds, model, optimizer, state_file, cm=False,
plot=False, **kwargs):
# check whether the dataset is a valid subset, i.e.
# an instance of pysegcnn.core.split.SceneSubset or
# an instance of pysegcnn.core.split.RandomSubset
_name = type(ds).__name__
if _name is not RandomSubset.__name__ or _name is not SceneSubset.__name__:
raise TypeError('ds should be an instance of {} or of {}'
.format('.'.join([RandomSubset.__module__,
RandomSubset.__name__]),
'.'.join([SceneSubset.__module__,
SceneSubset.__name__])
)
)
if not isinstance(ds, RandomSubset) or not isinstance(ds, SceneSubset):
raise TypeError('ds should be an instance of {} or of {}.'
.format(repr(RandomSubset), repr(SceneSubset)))
# convert state file to pathlib.Path object
state_file = pathlib.Path(state_file)
# the device to compute on, use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load the pretrained model state
state = os.path.join(state_path, state_file)
if not os.path.exists(state):
raise FileNotFoundError('{} does not exist.'.format(state))
state = model.load(state_file, optimizer, state_path)
if not state_file.exists():
raise FileNotFoundError('{} does not exist.'.format(state_file))
_ = model.load(state_file.name, optimizer, state_file.parent)
# set the model to evaluation mode
print('Setting model to evaluation mode ...')
......@@ -57,7 +55,7 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
model.to(device)
# base filename for each sample
fname = state_file.split('.pt')[0]
fname = state_file.name.split('.pt')[0]
# initialize confusion matrix
cmm = np.zeros(shape=(model.nclasses, model.nclasses))
......@@ -107,23 +105,25 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
return output, cmm
def predict_scenes(ds, model, optimizer, state_path, state_file,
def predict_scenes(ds, model, optimizer, state_file,
scene_id=None, cm=False, plot_scenes=False, **kwargs):
# check whether the dataset is a valid subset, i.e. an instance of
# pysegcnn.core.split.SceneSubset
if not type(ds).__name__ is SceneSubset.__name__:
raise TypeError('ds should be an instance of {}'.format(
'.'.join([SceneSubset.__module__, SceneSubset.__name__])))
if not isinstance(ds, SceneSubset):
raise TypeError('ds should be an instance of {}.'
.format(repr(SceneSubset)))
# convert state file to pathlib.Path object
state_file = pathlib.Path(state_file)
# the device to compute on, use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load the pretrained model state
state = os.path.join(state_path, state_file)
if not os.path.exists(state):
raise FileNotFoundError('{} does not exist.'.format(state))
state = model.load(state_file, optimizer, state_path)
if not state_file.exists():
raise FileNotFoundError('{} does not exist.'.format(state_file))
_ = model.load(state_file.name, optimizer, state_file.parent)
# set the model to evaluation mode
print('Setting model to evaluation mode ...')
......@@ -131,7 +131,7 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
model.to(device)
# base filename for each scene
fname = state_file.split('.pt')[0]
fname = state_file.name.split('.pt')[0]
# initialize confusion matrix
cmm = np.zeros(shape=(model.nclasses, model.nclasses))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment