Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
PySegCNN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
earth_observation_public
PySegCNN
Commits
451e1467
Commit
451e1467
authored
4 years ago
by
Frisinghelli Daniel
Browse files
Options
Downloads
Patches
Plain Diff
Added a function to plot the confusion matrix
parent
c5887fd3
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
pytorch/dataset.py
+114
-18
114 additions, 18 deletions
pytorch/dataset.py
with
114 additions
and
18 deletions
pytorch/dataset.py
+
114
−
18
View file @
451e1467
...
...
@@ -14,6 +14,7 @@ your custom dataset.
# builtins
import
os
import
glob
import
itertools
# externals
import
gdal
...
...
@@ -22,6 +23,7 @@ import torch
import
matplotlib.pyplot
as
plt
import
matplotlib.patches
as
mpatches
from
matplotlib.colors
import
ListedColormap
,
BoundaryNorm
from
matplotlib
import
cm
as
colormap
from
torch.utils.data
import
Dataset
...
...
@@ -76,9 +78,10 @@ class ImageDataset(Dataset):
# get samples: (tiles x channels x height x width)
data
,
gt
=
self
.
build_samples
(
scene
)
# convert to torch tensors
x
=
torch
.
tensor
(
data
,
dtype
=
torch
.
float32
)
y
=
torch
.
tensor
(
gt
,
dtype
=
torch
.
uint8
)
# preprocess input and return torch tensors of shape:
# x : (bands, height, width)
# y : (height, width)
x
,
y
=
self
.
preprocess
(
data
,
gt
)
return
x
,
y
...
...
@@ -132,6 +135,16 @@ class ImageDataset(Dataset):
raise
NotImplementedError
(
'
Inherit the ImageDataset class and
'
'
implement the method.
'
)
# the preprocess() method has to be implemented by the class inheriting
# the ImageDataset class
# preprocess() should return two torch.tensors:
# - input data: tensor of shape (bands, height, width)
# - ground truth: tensor of shape (height, width)
def preprocess(self, data, gt):
    """Placeholder for the dataset-specific preprocessing step.

    Classes inheriting from ImageDataset have to override this method.
    An implementation should return two torch tensors:
      - the input data, of shape (bands, height, width)
      - the ground truth, of shape (height, width)

    Raises
    ------
    NotImplementedError
        Always, since the generic class defines no preprocessing.
    """
    message = ('Inherit the ImageDataset class and '
               'implement the method.')
    raise NotImplementedError(message)
# _read_scene() reads all the bands and the ground truth mask in a
# scene/tile to a numpy array and returns a dictionary with
# (key, value) = ('band_name', np.ndarray(band_data))
...
...
@@ -350,6 +363,60 @@ class ImageDataset(Dataset):
return
fig
,
ax
# plot_confusion_matrix() plots the confusion matrix of the validation/test
# set returned by the pytorch.predict function
def plot_confusion_matrix(cm, labels, normalize=True,
                          figsize=(10, 10), cmap='Blues'):
    """Plot a confusion matrix as an annotated heat map.

    Parameters
    ----------
    cm : array-like
        Two-dimensional confusion matrix (anything ``np.asarray`` accepts).
        Rows are drawn along the axis labelled 'True', columns along the
        axis labelled 'Predicted'.
    labels : sequence
        Class names used as tick labels on both axes. Its length is the
        number of classes that get annotated.
    normalize : bool, optional
        If True (default), every row is divided by its sum before plotting.
        Rows that sum to zero are left as zeros instead of becoming NaN.
    figsize : tuple, optional
        Figure size passed to ``plt.subplots``.
    cmap : str, optional
        Name of the matplotlib colormap used for the heat map.

    Returns
    -------
    fig, ax
        The matplotlib figure and the axes holding the heat map.
    """
    # work on a numpy view so that .sum(), .max() and indexing behave
    # the same for every array-like input
    cm = np.asarray(cm)

    # number of classes
    nclasses = len(labels)

    # check whether to normalize the confusion matrix
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # a class without any true sample has a zero row sum: skip the
        # division there, otherwise the row turns into NaN and the NaN
        # propagates into the colour threshold below
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    # string format to plot values of confusion matrix: the integer format
    # is only valid for integer matrices, everything else is shown as float
    fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'

    # create figure
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # get colormap, resampled to 256 colors
    # (matplotlib.cm.get_cmap was removed in Matplotlib 3.9, the pyplot
    # function offers the same name/lut interface)
    cmap = plt.get_cmap(cmap, 256)

    # plot confusion matrix
    im = ax.imshow(cm, cmap=cmap)

    # threshold determining the color of the values
    thresh = (cm.max() + cm.min()) / 2

    # first/last color of the current colormap
    cmap_min, cmap_max = im.cmap(0), im.cmap(im.cmap.N - 1)

    # plot values of confusion matrix: cells below the threshold are
    # annotated with the last colormap color, the others with the first,
    # so that the text contrasts with the cell background
    for i, j in itertools.product(range(nclasses), repeat=2):
        ax.text(j, i, format(cm[i, j], fmt), ha='center', va='center',
                color=cmap_max if cm[i, j] < thresh else cmap_min)

    # axes properties and labels
    ax.set(xticks=np.arange(nclasses),
           yticks=np.arange(nclasses),
           xticklabels=labels,
           yticklabels=labels,
           ylabel='True',
           xlabel='Predicted')

    # add colorbar axes to the right of the heat map, spanning its height
    pos = ax.get_position()
    cax = fig.add_axes([pos.x1 + 0.025, pos.y0, 0.05, pos.y1 - pos.y0])
    fig.colorbar(im, cax=cax)

    return fig, ax
# SparcsDataset class: inherits from the generic ImageDataset class
class
SparcsDataset
(
ImageDataset
):
...
...
@@ -394,6 +461,16 @@ class SparcsDataset(ImageDataset):
lc
[
i
]
=
{
'
label
'
:
l
,
'
color
'
:
c
}
return
lc
# preprocessing of the Sparcs dataset
def preprocess(self, data, gt):
    """Turn a Sparcs sample into a pair of torch tensors.

    No dataset-specific preprocessing is applied here; if it is not done
    externally, this is the place to implement it.

    Returns
    -------
    tuple
        ``(x, y)``: the input data as a ``torch.float32`` tensor and the
        ground truth as a ``torch.uint8`` tensor.
    """
    return (torch.tensor(data, dtype=torch.float32),
            torch.tensor(gt, dtype=torch.uint8))
# _compose_scenes() creates a list of dictionaries containing the paths
# to the files of each scene
# if the scenes are divided into tiles, each tile has its own entry
...
...
@@ -447,7 +524,7 @@ class SparcsDataset(ImageDataset):
return
scene_data
class
Cloud95
(
ImageDataset
):
class
Cloud95
Dataset
(
ImageDataset
):
def
__init__
(
self
,
root_dir
,
use_bands
=
[],
tile_size
=
None
):
# initialize super class ImageDataset
...
...
@@ -467,9 +544,21 @@ class Cloud95(ImageDataset):
# class labels of the Cloud-95 dataset
def
get_labels
(
self
):
return
{
0
:
{
'
label
'
:
'
Clear
'
,
'
color
'
:
'
azur
e
'
},
return
{
0
:
{
'
label
'
:
'
Clear
'
,
'
color
'
:
'
skyblu
e
'
},
1
:
{
'
label
'
:
'
Cloud
'
,
'
color
'
:
'
white
'
}}
# preprocess Cloud-95 dataset
def preprocess(self, data, gt):
    """Normalize a Cloud-95 sample and convert it to torch tensors.

    The normalization follows the authors of Cloud-95, i.e. Mohajerani
    and Saeedi (2019, 2020): the input data is divided by 65535 and the
    ground truth by 255.

    Returns
    -------
    tuple
        ``(x, y)``: the scaled input data as a ``torch.float32`` tensor
        and the scaled ground truth as a ``torch.uint8`` tensor.
    """
    # NOTE(review): 65535 and 255 presumably are the maximum values of the
    # band data and of the mask, respectively -- confirm against the data
    scaled_bands = data / 65535
    scaled_mask = gt / 255
    return (torch.tensor(scaled_bands, dtype=torch.float32),
            torch.tensor(scaled_mask, dtype=torch.uint8))
def
compose_scenes
(
self
,
root_dir
):
# get the names of the directories containing the TIFF files of
...
...
@@ -478,7 +567,7 @@ class Cloud95(ImageDataset):
for
dirpath
,
dirname
,
files
in
os
.
walk
(
root_dir
):
# check if the current directory path includes the name of a band
# or the name of the ground truth mask
cband
=
[
band
for
band
in
[
*
self
.
bands
.
values
()]
+
[
'
gt
'
]
if
cband
=
[
band
for
band
in
self
.
use_
bands
+
[
'
gt
'
]
if
dirpath
.
endswith
(
band
)
and
os
.
path
.
isdir
(
dirpath
)]
# add path to current band files to dictionary
...
...
@@ -527,23 +616,30 @@ if __name__ == '__main__':
cloud_path
=
os
.
path
.
join
(
wd
,
'
_Datasets/Cloud95/Training
'
)
# instantiate the Cloud-95 dataset
cloud_dataset
=
Cloud95
(
cloud_path
)
cloud_dataset
=
Cloud95
Dataset
(
cloud_path
,
tile_size
=
192
)
# instantiate the SparcsDataset class
sparcs_dataset
=
SparcsDataset
(
sparcs_path
,
tile_size
=
None
,
use_bands
=
[
'
nir
'
,
'
red
'
,
'
green
'
])
# randomly sample an integer from [0, nsamples)
sample
=
np
.
random
.
randint
(
len
(
sparcs_dataset
),
size
=
1
).
item
()
# a sample from the sparcs dataset
sample_x
,
sample_y
=
sparcs_dataset
[
sample
]
sample_s
=
np
.
random
.
randint
(
len
(
sparcs_dataset
),
size
=
1
).
item
()
s_x
,
s_y
=
sparcs_dataset
[
sample_s
]
fig
,
ax
=
sparcs_dataset
.
plot_sample
(
s_x
,
s_y
,
bands
=
[
'
nir
'
,
'
red
'
,
'
green
'
])
# print shape of the sample
print
(
'
A sample from the Sparcs dataset:
'
)
print
(
'
Shape of input data: {}
'
.
format
(
sample_x
.
shape
))
print
(
'
Shape of ground truth: {}
'
.
format
(
sample_y
.
shape
))
# a sample from the cloud dataset
sample_c
=
np
.
random
.
randint
(
len
(
cloud_dataset
),
size
=
1
).
item
()
c_x
,
c_y
=
cloud_dataset
[
sample_c
]
fig
,
ax
=
cloud_dataset
.
plot_sample
(
c_x
,
c_y
,
bands
=
[
'
nir
'
,
'
red
'
,
'
green
'
])
# plot the sample
fig
,
ax
=
sparcs_dataset
.
plot_sample
(
sample_x
,
sample_y
,
bands
=
[
'
nir
'
,
'
red
'
,
'
green
'
])
# print shape of the sample
for
i
,
l
,
d
in
zip
([
s_x
,
c_x
],
[
s_y
,
c_y
],
[
sparcs_dataset
,
cloud_dataset
]):
print
(
'
A sample from the {}:
'
.
format
(
d
.
__class__
.
__name__
))
print
(
'
Shape of input data: {}
'
.
format
(
i
.
shape
))
print
(
'
Shape of ground truth: {}
'
.
format
(
l
.
shape
))
# show figures
plt
.
show
()
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment