Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
PySegCNN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
earth_observation_public
PySegCNN
Commits
e554e653
Commit
e554e653
authored
4 years ago
by
Frisinghelli Daniel
Browse files
Options
Downloads
Patches
Plain Diff
Made the ground truth mask in plot_sample optional
parent
239f3e57
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
pysegcnn/core/graphics.py
+43
-37
43 additions, 37 deletions
pysegcnn/core/graphics.py
pysegcnn/core/predict.py
+2
-2
2 additions, 2 deletions
pysegcnn/core/predict.py
with
45 additions
and
39 deletions
pysegcnn/core/graphics.py
+
43
−
37
View file @
e554e653
...
...
@@ -32,8 +32,6 @@ from pysegcnn.core.trainer import accuracy_function
from
pysegcnn.main.config
import
HERE
# this function applies percentile stretching at the alpha level
# can be used to increase constrast for visualization
def
contrast_stretching
(
image
,
alpha
=
5
):
"""
Apply percentile stretching to an image to increase constrast.
...
...
@@ -85,9 +83,7 @@ def running_mean(x, w):
return
(
cumsum
[
w
:]
-
cumsum
[:
-
w
])
/
w
# plot_sample() plots a false color composite of the scene/tile together
# with the model prediction and the corresponding ground truth
def
plot_sample
(
x
,
y
,
use_bands
,
labels
,
y_pred
=
None
,
figsize
=
(
10
,
10
),
def
plot_sample
(
x
,
use_bands
,
labels
,
y
=
None
,
y_pred
=
None
,
figsize
=
(
10
,
10
),
bands
=
[
'
nir
'
,
'
red
'
,
'
green
'
],
state
=
None
,
outpath
=
os
.
path
.
join
(
HERE
,
'
_samples/
'
),
alpha
=
0
):
"""
Plot false color composite (FCC), ground truth and model prediction.
...
...
@@ -96,8 +92,6 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
----------
x : `numpy.ndarray` or `torch.Tensor`, (b, h, w)
Array containing the raw data of the tile, shape=(bands, height, width)
y : `numpy.ndarray` or `torch.Tensor`, (h, w)
Array containing the ground truth of tile ``x``, shape=(height, width)
use_bands : `list` of `str`
List describing the order of the bands in ``x``.
labels : `dict` [`int`, `dict`]
...
...
@@ -107,9 +101,12 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
A named color (`str`).
``
'
label
'
``
The name of the class label (`str`).
y_pred : `numpy.ndarray` or `None`, optional
y : `numpy.ndarray` or `torch.Tensor` or `None`, optional
Array containing the ground truth of tile ``x``, shape=(height, width).
The default is None.
y_pred : `numpy.ndarray` or `torch.Tensor` or `None`, optional
Array containing the prediction for tile ``x``, shape=(height, width).
The default is None
, i.e. only FCC and ground truth are plotted
.
The default is None.
figsize : `tuple`, optional
The figure size in centimeters. The default is (10, 10).
bands : `list` [`str`], optional
...
...
@@ -119,7 +116,7 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
state file ending with
'
.pt
'
. The default is None, i.e. plot is not
saved to disk.
outpath : `str` or `pathlib.Path`, optional
Output path. The default is
os.path.join(HERE,
'
_samples
/
'
)
.
Output path. The default is
'
pysegcnn/main/
_samples
'
.
alpha : `int`, optional
The level of the percentiles to increase constrast in the FCC.
The default is 0, i.e. no stretching.
...
...
@@ -128,8 +125,8 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
-------
fig : `matplotlib.figure.Figure`
The figure handle.
ax : `matplotlib.axes._subplots.AxesSubplot`
T
he axes handle.
ax :
`numpy.ndarray` [
`matplotlib.axes._subplots.AxesSubplot`
]
An array of t
he axes handle
s
.
"""
# check whether to apply constrast stretching
...
...
@@ -145,35 +142,46 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
boundaries
=
[
*
labels
.
keys
(),
cmap
.
N
]
norm
=
BoundaryNorm
(
boundaries
,
cmap
.
N
)
# create figure: check whether to plot model prediction
if
y_pred
is
not
None
:
# compute accuracy
acc
=
accuracy_function
(
y_pred
,
y
)
# plot model prediction
fig
,
ax
=
plt
.
subplots
(
1
,
3
,
figsize
=
figsize
)
ax
[
2
].
imshow
(
y_pred
,
cmap
=
cmap
,
interpolation
=
'
nearest
'
,
norm
=
norm
)
ax
[
2
].
set_title
(
'
Prediction ({:.2f}%)
'
.
format
(
acc
*
100
),
pad
=
15
)
# create a patch (proxy artist) for every color
patches
=
[
mpatches
.
Patch
(
color
=
c
,
label
=
l
)
for
c
,
l
in
zip
(
colors
,
ulabels
)]
else
:
fig
,
ax
=
plt
.
subplots
(
1
,
2
,
figsize
=
figsize
)
# initialize figure
fig
,
ax
=
plt
.
subplots
(
1
,
3
,
figsize
=
figsize
)
# plot false color composite
ax
[
0
].
imshow
(
rgb
)
ax
[
0
].
set_title
(
'
R = {}, G = {}, B = {}
'
.
format
(
*
bands
),
pad
=
15
)
# plot ground thruth mask
ax
[
1
].
imshow
(
y
,
cmap
=
cmap
,
interpolation
=
'
nearest
'
,
norm
=
norm
)
ax
[
1
].
set_title
(
'
Ground truth
'
,
pad
=
15
)
# check whether to plot ground truth
acc
=
None
if
y
is
None
:
# remove axis to plot ground truth from figure
fig
.
delaxes
(
ax
[
1
])
else
:
# plot ground thruth mask
ax
[
1
].
imshow
(
y
,
cmap
=
cmap
,
interpolation
=
'
nearest
'
,
norm
=
norm
)
ax
[
1
].
set_title
(
'
Ground truth
'
,
pad
=
15
)
# check whether to plot model prediction
if
y_pred
is
None
:
# remove axis to plot model prediction from figure
fig
.
delaxes
(
ax
[
2
])
else
:
# plot model prediction
ax
[
2
].
imshow
(
y_pred
,
cmap
=
cmap
,
interpolation
=
'
nearest
'
,
norm
=
norm
)
# create a patch (proxy artist) for every color
patches
=
[
mpatches
.
Patch
(
color
=
c
,
label
=
l
)
for
c
,
l
in
zip
(
colors
,
ulabels
)]
# set title
title
=
'
Prediction
'
if
y
is
not
None
:
acc
=
accuracy_function
(
y_pred
,
y
)
title
+=
'
({:.2f}%)
'
.
format
(
acc
*
100
)
ax
[
2
].
set_title
(
title
,
pad
=
15
)
# plot patches as legend
plt
.
legend
(
handles
=
patches
,
bbox_to_anchor
=
(
1.05
,
1
),
loc
=
2
,
frameon
=
False
)
# if a ground truth or a model prediction is plotted, add legend
if
len
(
fig
.
axes
)
>
1
:
plt
.
legend
(
handles
=
patches
,
bbox_to_anchor
=
(
1.05
,
1
),
loc
=
2
,
frameon
=
False
)
# save figure
if
state
is
not
None
:
...
...
@@ -184,8 +192,6 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
return
fig
,
ax
# plot_confusion_matrix() plots the confusion matrix of the validation/test
# set returned by the pytorch.predict function
def
plot_confusion_matrix
(
cm
,
labels
,
normalize
=
True
,
figsize
=
(
10
,
10
),
cmap
=
'
Blues
'
,
state
=
None
,
outpath
=
os
.
path
.
join
(
HERE
,
'
_graphics/
'
)):
...
...
@@ -213,7 +219,7 @@ def plot_confusion_matrix(cm, labels, normalize=True,
state file ending with
'
.pt
'
. The default is None, i.e. plot is not
saved to disk.
outpath : `str` or `pathlib.Path`, optional
Output path. The default is
os.path.join(HERE,
'
_graphics/
'
)
.
Output path. The default is
'
pysegcnn/main/
_graphics/
'
.
Returns
-------
...
...
@@ -306,7 +312,7 @@ def plot_loss(state_file, figsize=(10, 10), step=5,
A list of four named colors supported by `matplotlib`.
The default is [
'
lightgreen
'
,
'
green
'
,
'
skyblue
'
,
'
steelblue
'
].
outpath : `str` or `pathlib.Path`, optional
Output path. The default is
os.path.join(HERE,
'
_graphics/
'
)
.
Output path. The default is
'
pysegcnn/main/
_graphics/
'
.
Returns
-------
...
...
This diff is collapsed.
Click to expand it.
pysegcnn/core/predict.py
+
2
−
2
View file @
e554e653
...
...
@@ -168,9 +168,9 @@ def predict_samples(ds, model, cm=False, plot=False, **kwargs):
# plot inputs, ground truth and model predictions
sname
=
fname
+
'
_{}_{}.pt
'
.
format
(
ds
.
name
,
batch
)
fig
,
ax
=
plot_sample
(
inputs
.
numpy
().
clip
(
0
,
1
),
labels
,
ds
.
dataset
.
use_bands
,
ds
.
dataset
.
labels
,
y
=
labels
,
y_pred
=
prd
,
state
=
sname
,
**
kwargs
)
...
...
@@ -298,9 +298,9 @@ def predict_scenes(ds, model, scene_id=None, cm=False, plot=False, **kwargs):
# plot current scene
if
plot
:
fig
,
ax
=
plot_sample
(
inputs
.
clip
(
0
,
1
),
labels
,
ds
.
dataset
.
use_bands
,
ds
.
dataset
.
labels
,
y
=
labels
,
y_pred
=
prdtcn
,
state
=
sname
,
**
kwargs
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment