Skip to content
Snippets Groups Projects
Commit 5128033c authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Adapted predict functions to changes in trainer.py

parent ef5cae60
No related branches found
No related tags found
No related merge requests found
# builtins # builtins
import os import os
import pathlib
# externals # externals
import numpy as np import numpy as np
...@@ -26,30 +27,27 @@ def get_scene_tiles(ds, scene_id): ...@@ -26,30 +27,27 @@ def get_scene_tiles(ds, scene_id):
return indices return indices
def predict_samples(ds, model, optimizer, state_path, state_file, cm=False, def predict_samples(ds, model, optimizer, state_file, cm=False,
plot=False, **kwargs): plot=False, **kwargs):
# check whether the dataset is a valid subset, i.e. # check whether the dataset is a valid subset, i.e.
# an instance of pysegcnn.core.split.SceneSubset or # an instance of pysegcnn.core.split.SceneSubset or
# an instance of pysegcnn.core.split.RandomSubset # an instance of pysegcnn.core.split.RandomSubset
_name = type(ds).__name__ _name = type(ds).__name__
if _name is not RandomSubset.__name__ or _name is not SceneSubset.__name__: if not isinstance(ds, RandomSubset) or not isinstance(ds, SceneSubset):
raise TypeError('ds should be an instance of {} or of {}' raise TypeError('ds should be an instance of {} or of {}.'
.format('.'.join([RandomSubset.__module__, .format(repr(RandomSubset), repr(SceneSubset)))
RandomSubset.__name__]),
'.'.join([SceneSubset.__module__, # convert state file to pathlib.Path object
SceneSubset.__name__]) state_file = pathlib.Path(state_file)
)
)
# the device to compute on, use gpu if available # the device to compute on, use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load the pretrained model state # load the pretrained model state
state = os.path.join(state_path, state_file) if not state_file.exists():
if not os.path.exists(state): raise FileNotFoundError('{} does not exist.'.format(state_file))
raise FileNotFoundError('{} does not exist.'.format(state)) _ = model.load(state_file.name, optimizer, state_file.parent)
state = model.load(state_file, optimizer, state_path)
# set the model to evaluation mode # set the model to evaluation mode
print('Setting model to evaluation mode ...') print('Setting model to evaluation mode ...')
...@@ -57,7 +55,7 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False, ...@@ -57,7 +55,7 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
model.to(device) model.to(device)
# base filename for each sample # base filename for each sample
fname = state_file.split('.pt')[0] fname = state_file.name.split('.pt')[0]
# initialize confusion matrix # initialize confusion matrix
cmm = np.zeros(shape=(model.nclasses, model.nclasses)) cmm = np.zeros(shape=(model.nclasses, model.nclasses))
...@@ -107,23 +105,25 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False, ...@@ -107,23 +105,25 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
return output, cmm return output, cmm
def predict_scenes(ds, model, optimizer, state_path, state_file, def predict_scenes(ds, model, optimizer, state_file,
scene_id=None, cm=False, plot_scenes=False, **kwargs): scene_id=None, cm=False, plot_scenes=False, **kwargs):
# check whether the dataset is a valid subset, i.e. an instance of # check whether the dataset is a valid subset, i.e. an instance of
# pysegcnn.core.split.SceneSubset # pysegcnn.core.split.SceneSubset
if not type(ds).__name__ is SceneSubset.__name__: if not isinstance(ds, SceneSubset):
raise TypeError('ds should be an instance of {}'.format( raise TypeError('ds should be an instance of {}.'
'.'.join([SceneSubset.__module__, SceneSubset.__name__]))) .format(repr(SceneSubset)))
# convert state file to pathlib.Path object
state_file = pathlib.Path(state_file)
# the device to compute on, use gpu if available # the device to compute on, use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load the pretrained model state # load the pretrained model state
state = os.path.join(state_path, state_file) if not state_file.exists():
if not os.path.exists(state): raise FileNotFoundError('{} does not exist.'.format(state_file))
raise FileNotFoundError('{} does not exist.'.format(state)) _ = model.load(state_file.name, optimizer, state_file.parent)
state = model.load(state_file, optimizer, state_path)
# set the model to evaluation mode # set the model to evaluation mode
print('Setting model to evaluation mode ...') print('Setting model to evaluation mode ...')
...@@ -131,7 +131,7 @@ def predict_scenes(ds, model, optimizer, state_path, state_file, ...@@ -131,7 +131,7 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
model.to(device) model.to(device)
# base filename for each scene # base filename for each scene
fname = state_file.split('.pt')[0] fname = state_file.name.split('.pt')[0]
# initialize confusion matrix # initialize confusion matrix
cmm = np.zeros(shape=(model.nclasses, model.nclasses)) cmm = np.zeros(shape=(model.nclasses, model.nclasses))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment