Skip to content
Snippets Groups Projects
Commit 7cda6542 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Removed option to only predict some samples; not required

parent 405feffb
No related branches found
No related tags found
No related merge requests found
......@@ -25,9 +25,12 @@ def get_scene_tiles(ds, scene_id):
return indices
def predict_samples(ds, model, optimizer, state_path, state_file, nsamples,
seed, batch_size=None, cm=False, plot_samples=False,
**kwargs):
def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
plot=False, **kwargs):
# check whether the dataset is a subset
if not isinstance(ds, Subset):
raise TypeError('ds should be of type {}'.format(Subset))
# the device to compute on, use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
......@@ -49,26 +52,13 @@ def predict_samples(ds, model, optimizer, state_path, state_file, nsamples,
# initialize confusion matrix
cmm = np.zeros(shape=(model.nclasses, model.nclasses))
# set random seed for reproducibility
np.random.seed(seed)
# draw a number of samples from the dataset
samples = np.arange(0, len(ds))
if nsamples > 0:
batch_size = nsamples
samples = np.random.choice(samples, size=min(nsamples, len(ds)))
# create a subset of the dataset
smpl_subset = Subset(ds, samples.tolist())
if batch_size is None:
raise ValueError('If you specify "nsamples"=-1, you have to provide '
'a batch size, e.g. trainer.batch_size.')
smpl_loader = DataLoader(smpl_subset, batch_size=batch_size, shuffle=False)
# create the dataloader
dataloader = DataLoader(ds, batch_size=1, shuffle=False, drop_last=False)
# iterate over the samples and plot inputs, ground truth and
# model predictions
output = {}
for batch, (inputs, labels) in enumerate(smpl_loader):
for batch, (inputs, labels) in enumerate(dataloader):
# send inputs and labels to device
inputs = inputs.to(device)
......@@ -87,22 +77,14 @@ def predict_samples(ds, model, optimizer, state_path, state_file, nsamples,
cmm[ytrue.long(), ypred.long()] += 1
# save plot of current batch to disk
if plot_samples:
# check whether the dataset is a subset
if isinstance(ds, Subset):
use_bands = ds.dataset.use_bands
ds_labels = ds.dataset.labels
else:
use_bands = ds.use_bands
ds_labels = ds.labels
if plot:
# plot inputs, ground truth and model predictions
sname = fname + '_sample_{}.pt'.format(batch)
sname = fname + '_{}_{}.pt'.format(ds.name, batch)
fig, ax = plot_sample(inputs.numpy().clip(0, 1),
labels,
use_bands,
ds_labels,
ds.dataset.use_bands,
ds.dataset.labels,
y_pred=prd,
state=sname,
**kwargs)
......@@ -113,6 +95,10 @@ def predict_samples(ds, model, optimizer, state_path, state_file, nsamples,
def predict_scenes(ds, model, optimizer, state_path, state_file,
scene_id=None, cm=False, plot_scenes=False, **kwargs):
# check if the dataset is an instance of torch.data.dataset.Subset
if not isinstance(ds, Subset):
raise TypeError('ds should be of type {}'.format(Subset))
# the device to compute on, use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
......@@ -136,12 +122,6 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
# check whether a scene id is provided
if scene_id is None:
# check if the dataset is an instance of torch.data.dataset.Subset
if not isinstance(ds, Subset):
raise TypeError('ds should be of type {}'.format(Subset))
print('Predicting scenes of the subset ...')
# get the names of the scenes
try:
scene_ids = ds.ids
......@@ -158,6 +138,7 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
scene_size = (ds.dataset.height, ds.dataset.width)
# iterate over the scenes
print('Predicting scenes of the subset ...')
scene = {}
for sid in scene_ids:
......@@ -173,7 +154,7 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
# create the dataloader
scene_dl = DataLoader(scene_ds, batch_size=len(scene_ds),
shuffle=False)
shuffle=False, drop_last=False)
# predict the current scene
for i, (inp, lab) in enumerate(scene_dl):
......
......@@ -261,12 +261,6 @@ config = {
# split_mode="date"
'predict_scene': True,
# number of samples to validate model performance on
# if nsamples': -1, the model is evaluated on all samples of the validation
# set or test set
# only takes effect if predict_scene=False
'nsamples': -1,
# whether to save plots of (input, ground truth, prediction) of the
# samples from the validation/test dataset to disk
# output path is: pysegcnn/main/_samples/
......
......@@ -49,9 +49,6 @@ if __name__ == '__main__':
trainer.optimizer,
trainer.state_path,
trainer.state_file,
trainer.nsamples,
trainer.seed,
trainer.batch_size,
trainer.cm,
trainer.plot_samples,
bands=trainer.plot_bands,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment