From 5128033c9f24ea07815c9665d30051d90673c7bc Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Thu, 13 Aug 2020 16:36:27 +0200 Subject: [PATCH] Adapted predict functions to changes in trainer.py --- pysegcnn/core/predict.py | 46 ++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/pysegcnn/core/predict.py b/pysegcnn/core/predict.py index b1e69cd..01b07cf 100644 --- a/pysegcnn/core/predict.py +++ b/pysegcnn/core/predict.py @@ -1,5 +1,6 @@ # builtins import os +import pathlib # externals import numpy as np @@ -26,30 +27,27 @@ def get_scene_tiles(ds, scene_id): return indices -def predict_samples(ds, model, optimizer, state_path, state_file, cm=False, +def predict_samples(ds, model, optimizer, state_file, cm=False, plot=False, **kwargs): # check whether the dataset is a valid subset, i.e. # an instance of pysegcnn.core.split.SceneSubset or # an instance of pysegcnn.core.split.RandomSubset _name = type(ds).__name__ - if _name is not RandomSubset.__name__ or _name is not SceneSubset.__name__: - raise TypeError('ds should be an instance of {} or of {}' - .format('.'.join([RandomSubset.__module__, - RandomSubset.__name__]), - '.'.join([SceneSubset.__module__, - SceneSubset.__name__]) - ) - ) + if not isinstance(ds, RandomSubset) or not isinstance(ds, SceneSubset): + raise TypeError('ds should be an instance of {} or of {}.' + .format(repr(RandomSubset), repr(SceneSubset))) + + # convert state file to pathlib.Path object + state_file = pathlib.Path(state_file) # the device to compute on, use gpu if available device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # load the pretrained model state - state = os.path.join(state_path, state_file) - if not os.path.exists(state): - raise FileNotFoundError('{} does not exist.'.format(state)) - state = model.load(state_file, optimizer, state_path) + if not state_file.exists(): + raise FileNotFoundError('{} does not exist.'.format(state_file)) + _ = model.load(state_file.name, optimizer, state_file.parent) # set the model to evaluation mode print('Setting model to evaluation mode ...') @@ -57,7 +55,7 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False, model.to(device) # base filename for each sample - fname = state_file.split('.pt')[0] + fname = state_file.name.split('.pt')[0] # initialize confusion matrix cmm = np.zeros(shape=(model.nclasses, model.nclasses)) @@ -107,23 +105,25 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False, return output, cmm -def predict_scenes(ds, model, optimizer, state_path, state_file, +def predict_scenes(ds, model, optimizer, state_file, scene_id=None, cm=False, plot_scenes=False, **kwargs): # check whether the dataset is a valid subset, i.e. an instance of # pysegcnn.core.split.SceneSubset - if not type(ds).__name__ is SceneSubset.__name__: - raise TypeError('ds should be an instance of {}'.format( - '.'.join([SceneSubset.__module__, SceneSubset.__name__]))) + if not isinstance(ds, SceneSubset): + raise TypeError('ds should be an instance of {}.' + .format(repr(SceneSubset))) + + # convert state file to pathlib.Path object + state_file = pathlib.Path(state_file) # the device to compute on, use gpu if available device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # load the pretrained model state - state = os.path.join(state_path, state_file) - if not os.path.exists(state): - raise FileNotFoundError('{} does not exist.'.format(state)) - state = model.load(state_file, optimizer, state_path) + if not state_file.exists(): + raise FileNotFoundError('{} does not exist.'.format(state_file)) + _ = model.load(state_file.name, optimizer, state_file.parent) # set the model to evaluation mode print('Setting model to evaluation mode ...') @@ -131,7 +131,7 @@ def predict_scenes(ds, model, optimizer, state_path, state_file, model.to(device) # base filename for each scene - fname = state_file.split('.pt')[0] + fname = state_file.name.split('.pt')[0] # initialize confusion matrix cmm = np.zeros(shape=(model.nclasses, model.nclasses)) -- GitLab