Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
PySegCNN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
earth_observation_public
PySegCNN
Commits
f8a96422
Commit
f8a96422
authored
4 years ago
by
Frisinghelli Daniel
Browse files
Options
Downloads
Patches
Plain Diff
Included type handling and improved console prints
parent
5a99371d
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
pysegcnn/core/predict.py
+36
-26
36 additions, 26 deletions
pysegcnn/core/predict.py
with
36 additions
and
26 deletions
pysegcnn/core/predict.py
+
36
−
26
View file @
f8a96422
...
...
@@ -9,8 +9,9 @@ from torch.utils.data.dataset import Subset
import
torch.nn.functional
as
F
# locals
from
pysegcnn.core.utils
import
reconstruct_scene
from
pysegcnn.core.utils
import
reconstruct_scene
,
accuracy_function
from
pysegcnn.core.graphics
import
plot_sample
from
pysegcnn.core.split
import
RandomSubset
,
SceneSubset
def
get_scene_tiles
(
ds
,
scene_id
):
...
...
@@ -28,9 +29,18 @@ def get_scene_tiles(ds, scene_id):
def
predict_samples
(
ds
,
model
,
optimizer
,
state_path
,
state_file
,
cm
=
False
,
plot
=
False
,
**
kwargs
):
# check whether the dataset is a subset
if
not
isinstance
(
ds
,
Subset
):
raise
TypeError
(
'
ds should be of type {}
'
.
format
(
Subset
))
# check whether the dataset is a valid subset, i.e.
# an instance of pysegcnn.core.split.SceneSubset or
# an instance of pysegcnn.core.split.RandomSubset
_name
=
type
(
ds
).
__name__
if
_name
is
not
RandomSubset
.
__name__
or
_name
is
not
SceneSubset
.
__name__
:
raise
TypeError
(
'
ds should be an instance of {} or of {}
'
.
format
(
'
.
'
.
join
([
RandomSubset
.
__module__
,
RandomSubset
.
__name__
]),
'
.
'
.
join
([
SceneSubset
.
__module__
,
SceneSubset
.
__name__
])
)
)
# the device to compute on, use gpu if available
device
=
torch
.
device
(
"
cuda:0
"
if
torch
.
cuda
.
is_available
()
else
"
cpu
"
)
...
...
@@ -58,6 +68,7 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
# iterate over the samples and plot inputs, ground truth and
# model predictions
output
=
{}
print
(
'
Predicting samples of the {} dataset ...
'
.
format
(
ds
.
name
))
for
batch
,
(
inputs
,
labels
)
in
enumerate
(
dataloader
):
# send inputs and labels to device
...
...
@@ -71,6 +82,10 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
# store output for current batch
output
[
batch
]
=
{
'
input
'
:
inputs
,
'
labels
'
:
labels
,
'
prediction
'
:
prd
}
print
(
'
Sample: {:d}/{:d}, Accuracy: {:.2f}
'
.
format
(
batch
+
1
,
len
(
dataloader
),
accuracy_function
(
prd
,
labels
)))
# update confusion matrix
if
cm
:
for
ytrue
,
ypred
in
zip
(
labels
.
view
(
-
1
),
prd
.
view
(
-
1
)):
...
...
@@ -95,9 +110,11 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
def
predict_scenes
(
ds
,
model
,
optimizer
,
state_path
,
state_file
,
scene_id
=
None
,
cm
=
False
,
plot_scenes
=
False
,
**
kwargs
):
# check if the dataset is an instance of torch.data.dataset.Subset
if
not
isinstance
(
ds
,
Subset
):
raise
TypeError
(
'
ds should be of type {}
'
.
format
(
Subset
))
# check whether the dataset is a valid subset, i.e. an instance of
# pysegcnn.core.split.SceneSubset
if
not
type
(
ds
).
__name__
is
SceneSubset
.
__name__
:
raise
TypeError
(
'
ds should be an instance of {}
'
.
format
(
'
.
'
.
join
([
SceneSubset
.
__module__
,
SceneSubset
.
__name__
])))
# the device to compute on, use gpu if available
device
=
torch
.
device
(
"
cuda:0
"
if
torch
.
cuda
.
is_available
()
else
"
cpu
"
)
...
...
@@ -121,16 +138,8 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
# check whether a scene id is provided
if
scene_id
is
None
:
# get the names of the scenes
try
:
scene_ids
=
ds
.
ids
except
AttributeError
:
raise
TypeError
(
'
predict_scenes does only work for datasets split
'
'
by
"
scene
"
or by
"
date
"
.
'
)
scene_ids
=
ds
.
ids
else
:
# the name of the selected scene
scene_ids
=
[
scene_id
]
...
...
@@ -138,12 +147,12 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
scene_size
=
(
ds
.
dataset
.
height
,
ds
.
dataset
.
width
)
# iterate over the scenes
print
(
'
Predicting scenes of the
sub
set ...
'
)
scene
=
{}
for
sid
in
scene_ids
:
print
(
'
Predicting scenes of the
{} data
set ...
'
.
format
(
ds
.
name
)
)
scene
s
=
{}
for
i
,
sid
in
enumerate
(
scene_ids
)
:
# filename for the current scene
sname
=
fname
+
'
_
'
+
sid
+
'
.pt
'
sname
=
fname
+
'
_
{}_{}.pt
'
.
format
(
ds
.
name
,
sid
)
# get the indices of the tiles of the scene
indices
=
get_scene_tiles
(
ds
,
sid
)
...
...
@@ -157,10 +166,7 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
shuffle
=
False
,
drop_last
=
False
)
# predict the current scene
for
i
,
(
inp
,
lab
)
in
enumerate
(
scene_dl
):
print
(
'
Predicting scene ({}/{}), id: {}
'
.
format
(
i
+
1
,
len
(
scene_ids
),
sid
))
for
b
,
(
inp
,
lab
)
in
enumerate
(
scene_dl
):
# send inputs and labels to device
inp
=
inp
.
to
(
device
)
...
...
@@ -180,8 +186,12 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
labels
=
reconstruct_scene
(
lab
,
scene_size
,
nbands
=
1
)
prdtcn
=
reconstruct_scene
(
prd
,
scene_size
,
nbands
=
1
)
# print progress
print
(
'
Scene {:d}/{:d}, Id: {}, Accuracy: {:.2f}
'
.
format
(
i
+
1
,
len
(
scene_ids
),
sid
,
accuracy_function
(
prdtcn
,
labels
)))
# save outputs to dictionary
scene
[
sid
]
=
{
'
input
'
:
inputs
,
'
labels
'
:
labels
,
'
prediction
'
:
prdtcn
}
scene
s
[
sid
]
=
{
'
input
'
:
inputs
,
'
labels
'
:
labels
,
'
prediction
'
:
prdtcn
}
# plot current scene
if
plot_scenes
:
...
...
@@ -193,4 +203,4 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
state
=
sname
,
**
kwargs
)
return
scene
,
cmm
return
scene
s
,
cmm
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment