Skip to content
Snippets Groups Projects
Commit 46fec41a authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Adjusted prediction modules to changes in core.trainer.py

parent fc9b837e
No related branches found
No related tags found
No related merge requests found
......@@ -31,8 +31,7 @@ def get_scene_tiles(ds, scene_id):
return indices
def predict_samples(ds, model, optimizer, state_file, cm=False,
plot=False, **kwargs):
def predict_samples(ds, model, cm=False, plot=False, **kwargs):
# check whether the dataset is a valid subset, i.e.
# an instance of pysegcnn.core.split.SceneSubset or
......@@ -41,24 +40,16 @@ def predict_samples(ds, model, optimizer, state_file, cm=False,
raise TypeError('ds should be an instance of {} or of {}.'
.format(repr(RandomSubset), repr(SceneSubset)))
# convert state file to pathlib.Path object
state_file = pathlib.Path(state_file)
# the device to compute on, use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load the pretrained model state
if not state_file.exists():
raise FileNotFoundError('{} does not exist.'.format(state_file))
_ = model.load(state_file.name, optimizer, state_file.parent)
# set the model to evaluation mode
LOGGER.info('Setting model to evaluation mode ...')
model.eval()
model.to(device)
# base filename for each sample
fname = state_file.name.split('.pt')[0]
fname = model.state_file.name.split('.pt')[0]
# initialize confusion matrix
cmm = np.zeros(shape=(model.nclasses, model.nclasses))
......@@ -107,8 +98,7 @@ def predict_samples(ds, model, optimizer, state_file, cm=False,
return output, cmm
def predict_scenes(ds, model, optimizer, state_file,
scene_id=None, cm=False, plot_scenes=False, **kwargs):
def predict_scenes(ds, model, scene_id=None, cm=False, plot=False, **kwargs):
# check whether the dataset is a valid subset, i.e. an instance of
# pysegcnn.core.split.SceneSubset
......@@ -116,24 +106,16 @@ def predict_scenes(ds, model, optimizer, state_file,
raise TypeError('ds should be an instance of {}.'
.format(repr(SceneSubset)))
# convert state file to pathlib.Path object
state_file = pathlib.Path(state_file)
# the device to compute on, use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load the pretrained model state
if not state_file.exists():
raise FileNotFoundError('{} does not exist.'.format(state_file))
_ = model.load(state_file.name, optimizer, state_file.parent)
# set the model to evaluation mode
LOGGER.info('Setting model to evaluation mode ...')
model.eval()
model.to(device)
# base filename for each scene
fname = state_file.name.split('.pt')[0]
fname = model.state_file.name.split('.pt')[0]
# initialize confusion matrix
cmm = np.zeros(shape=(model.nclasses, model.nclasses))
......@@ -196,7 +178,7 @@ def predict_scenes(ds, model, optimizer, state_file,
scenes[sid] = {'input': inputs, 'labels': labels, 'prediction': prdtcn}
# plot current scene
if plot_scenes:
if plot:
fig, ax = plot_sample(inputs.clip(0, 1),
labels,
ds.dataset.use_bands,
......
......@@ -9,7 +9,7 @@ import os
# locals
from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig,
TrainConfig, EvalConfig)
StateConfig, EvalConfig)
from pysegcnn.core.predict import predict_samples, predict_scenes
from pysegcnn.main.config import (dataset_config, split_config, model_config,
train_config, eval_config, HERE)
......@@ -22,24 +22,21 @@ if __name__ == '__main__':
dc = DatasetConfig(**dataset_config)
sc = SplitConfig(**split_config)
mc = ModelConfig(**model_config)
tc = TrainConfig(**train_config)
ec = EvalConfig(**eval_config)
# (ii) instanciate the dataset
ds = dc.init_dataset()
ds
# (iii) instanciate the training, validation and test datasets
train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
# (iv) instanciate the model state files
state_file, loss_state = mc.init_state(ds, sc, tc)
# (iv) instanciate the model state
state = StateConfig(ds, sc, mc)
state_file, loss_state = state.init_state()
# (v) instanciate the model
model = mc.init_model(ds)
# (vi) instanciate the optimizer
optimizer = tc.init_optimizer(model)
# (vii) load pretrained model weights
model, _ = mc.load_pretrained(state_file)
model.state_file = state_file
# plot loss and accuracy
plot_loss(loss_state, outpath=os.path.join(HERE, '_graphics/'))
......@@ -54,37 +51,23 @@ if __name__ == '__main__':
# keyword arguments for plotting
kwargs = {'bands': ec.plot_bands,
'outpath': os.path.join(HERE, '_scenes/'),
'stretch': True,
'alpha': 5}
'alpha': ec.alpha,
'figsize': ec.figsize}
# whether to predict each sample or each scene individually
if ec.predict_scene:
# reconstruct and predict the scenes in the validation/test set
scenes, cm = predict_scenes(ds,
model,
optimizer,
state_file,
scene_id=None,
cm=ec.cm,
plot_scenes=ec.plot_scenes,
**kwargs
)
scenes, cm = predict_scenes(ds, model, scene_id=None, cm=ec.cm,
plot=ec.plot_scenes, **kwargs)
else:
# predict the samples in the validation/test set
samples, cm = predict_samples(ds,
model,
optimizer,
state_file,
cm=ec.cm,
plot_scenes=ec.plot_scenes,
**kwargs)
samples, cm = predict_samples(ds, model, cm=ec.cm,
plot=ec.plot_samples, **kwargs)
# whether to plot the confusion matrix
if ec.cm:
plot_confusion_matrix(cm,
ds.dataset.labels,
normalize=True,
plot_confusion_matrix(cm, ds.dataset.labels,
state=state_file.name.replace('.pt', '.png'),
outpath=os.path.join(HERE, '_graphics/')
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment