Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
PySegCNN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
earth_observation_public
PySegCNN
Commits
5128033c
Commit
5128033c
authored
4 years ago
by
Frisinghelli Daniel
Browse files
Options
Downloads
Patches
Plain Diff
Adapted predict functions to changes in trainer.py
parent
ef5cae60
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
pysegcnn/core/predict.py
+23
-23
23 additions, 23 deletions
pysegcnn/core/predict.py
with
23 additions
and
23 deletions
pysegcnn/core/predict.py
+
23
−
23
View file @
5128033c
# builtins
# builtins
import
os
import
os
import
pathlib
# externals
# externals
import
numpy
as
np
import
numpy
as
np
...
@@ -26,30 +27,27 @@ def get_scene_tiles(ds, scene_id):
...
@@ -26,30 +27,27 @@ def get_scene_tiles(ds, scene_id):
return
indices
return
indices
def
predict_samples
(
ds
,
model
,
optimizer
,
state_path
,
state_file
,
cm
=
False
,
def
predict_samples
(
ds
,
model
,
optimizer
,
state_file
,
cm
=
False
,
plot
=
False
,
**
kwargs
):
plot
=
False
,
**
kwargs
):
# check whether the dataset is a valid subset, i.e.
# check whether the dataset is a valid subset, i.e.
# an instance of pysegcnn.core.split.SceneSubset or
# an instance of pysegcnn.core.split.SceneSubset or
# an instance of pysegcnn.core.split.RandomSubset
# an instance of pysegcnn.core.split.RandomSubset
_name
=
type
(
ds
).
__name__
_name
=
type
(
ds
).
__name__
if
_name
is
not
RandomSubset
.
__name__
or
_name
is
not
SceneSubset
.
__name__
:
if
not
isinstance
(
ds
,
RandomSubset
)
or
not
isinstance
(
ds
,
SceneSubset
):
raise
TypeError
(
'
ds should be an instance of {} or of {}
'
raise
TypeError
(
'
ds should be an instance of {} or of {}.
'
.
format
(
'
.
'
.
join
([
RandomSubset
.
__module__
,
.
format
(
repr
(
RandomSubset
),
repr
(
SceneSubset
)))
RandomSubset
.
__name__
]),
'
.
'
.
join
([
SceneSubset
.
__module__
,
# convert state file to pathlib.Path object
SceneSubset
.
__name__
])
state_file
=
pathlib
.
Path
(
state_file
)
)
)
# the device to compute on, use gpu if available
# the device to compute on, use gpu if available
device
=
torch
.
device
(
"
cuda:0
"
if
torch
.
cuda
.
is_available
()
else
"
cpu
"
)
device
=
torch
.
device
(
"
cuda:0
"
if
torch
.
cuda
.
is_available
()
else
"
cpu
"
)
# load the pretrained model state
# load the pretrained model state
state
=
os
.
path
.
join
(
state_path
,
state_file
)
if
not
state_file
.
exists
():
if
not
os
.
path
.
exists
(
state
):
raise
FileNotFoundError
(
'
{} does not exist.
'
.
format
(
state_file
))
raise
FileNotFoundError
(
'
{} does not exist.
'
.
format
(
state
))
_
=
model
.
load
(
state_file
.
name
,
optimizer
,
state_file
.
parent
)
state
=
model
.
load
(
state_file
,
optimizer
,
state_path
)
# set the model to evaluation mode
# set the model to evaluation mode
print
(
'
Setting model to evaluation mode ...
'
)
print
(
'
Setting model to evaluation mode ...
'
)
...
@@ -57,7 +55,7 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
...
@@ -57,7 +55,7 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
model
.
to
(
device
)
model
.
to
(
device
)
# base filename for each sample
# base filename for each sample
fname
=
state_file
.
split
(
'
.pt
'
)[
0
]
fname
=
state_file
.
name
.
split
(
'
.pt
'
)[
0
]
# initialize confusion matrix
# initialize confusion matrix
cmm
=
np
.
zeros
(
shape
=
(
model
.
nclasses
,
model
.
nclasses
))
cmm
=
np
.
zeros
(
shape
=
(
model
.
nclasses
,
model
.
nclasses
))
...
@@ -107,23 +105,25 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
...
@@ -107,23 +105,25 @@ def predict_samples(ds, model, optimizer, state_path, state_file, cm=False,
return
output
,
cmm
return
output
,
cmm
def
predict_scenes
(
ds
,
model
,
optimizer
,
state_path
,
state_file
,
def
predict_scenes
(
ds
,
model
,
optimizer
,
state_file
,
scene_id
=
None
,
cm
=
False
,
plot_scenes
=
False
,
**
kwargs
):
scene_id
=
None
,
cm
=
False
,
plot_scenes
=
False
,
**
kwargs
):
# check whether the dataset is a valid subset, i.e. an instance of
# check whether the dataset is a valid subset, i.e. an instance of
# pysegcnn.core.split.SceneSubset
# pysegcnn.core.split.SceneSubset
if
not
type
(
ds
).
__name__
is
SceneSubset
.
__name__
:
if
not
isinstance
(
ds
,
SceneSubset
):
raise
TypeError
(
'
ds should be an instance of {}
'
.
format
(
raise
TypeError
(
'
ds should be an instance of {}.
'
'
.
'
.
join
([
SceneSubset
.
__module__
,
SceneSubset
.
__name__
])))
.
format
(
repr
(
SceneSubset
)))
# convert state file to pathlib.Path object
state_file
=
pathlib
.
Path
(
state_file
)
# the device to compute on, use gpu if available
# the device to compute on, use gpu if available
device
=
torch
.
device
(
"
cuda:0
"
if
torch
.
cuda
.
is_available
()
else
"
cpu
"
)
device
=
torch
.
device
(
"
cuda:0
"
if
torch
.
cuda
.
is_available
()
else
"
cpu
"
)
# load the pretrained model state
# load the pretrained model state
state
=
os
.
path
.
join
(
state_path
,
state_file
)
if
not
state_file
.
exists
():
if
not
os
.
path
.
exists
(
state
):
raise
FileNotFoundError
(
'
{} does not exist.
'
.
format
(
state_file
))
raise
FileNotFoundError
(
'
{} does not exist.
'
.
format
(
state
))
_
=
model
.
load
(
state_file
.
name
,
optimizer
,
state_file
.
parent
)
state
=
model
.
load
(
state_file
,
optimizer
,
state_path
)
# set the model to evaluation mode
# set the model to evaluation mode
print
(
'
Setting model to evaluation mode ...
'
)
print
(
'
Setting model to evaluation mode ...
'
)
...
@@ -131,7 +131,7 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
...
@@ -131,7 +131,7 @@ def predict_scenes(ds, model, optimizer, state_path, state_file,
model
.
to
(
device
)
model
.
to
(
device
)
# base filename for each scene
# base filename for each scene
fname
=
state_file
.
split
(
'
.pt
'
)[
0
]
fname
=
state_file
.
name
.
split
(
'
.pt
'
)[
0
]
# initialize confusion matrix
# initialize confusion matrix
cmm
=
np
.
zeros
(
shape
=
(
model
.
nclasses
,
model
.
nclasses
))
cmm
=
np
.
zeros
(
shape
=
(
model
.
nclasses
,
model
.
nclasses
))
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment