earth_observation_public / PySegCNN / Commits

Commit a1669e70, authored 4 years ago by Frisinghelli Daniel

    Extended generic ImageDataset class; added Cloud95 support

Parent: 9a1d8772
No related branches, tags, or merge requests found.

Changes: 1 changed file

pytorch/dataset.py  (+260 additions, −154 deletions)
...

@@ -28,28 +28,136 @@ from torch.utils.data import Dataset

 # generic image dataset class
 class ImageDataset(Dataset):

-    def __init__(self, root_dir):
+    def __init__(self, root_dir, use_bands, tile_size):
         super().__init__()

         # the root directory: path to the image dataset
         self.root = root_dir

-    # this function should return the length of the image dataset
-    # __len__() is used by pytorch to determine the total number of samples in
-    # the dataset, has to be implemented by a class inheriting from the
-    # ImageDataset class
+        # the size of a scene/patch in the dataset
+        self.size = self.get_size()
+
+        # the available spectral bands in the dataset
+        self.bands = self.get_bands()
+
+        # the class labels
+        self.labels = self.get_labels()
+
+        # check which bands to use
+        self.use_bands = (use_bands if use_bands else [*self.bands.values()])
+
+        # each scene is divided into (tile_size x tile_size) blocks
+        # each of these blocks is treated as a single sample
+        self.tile_size = tile_size
+
+        # calculate number of resulting tiles and check whether the images are
+        # evenly divisible in square tiles of size (tile_size x tile_size)
+        if self.tile_size is None:
+            self.tiles = None
+        else:
+            self.tiles = self.is_divisible(self.size, self.tile_size)
+
+        # the samples of the dataset
+        self.scenes = []
+
+    # the __len__() method returns the number of samples in the dataset
     def __len__(self):
-        raise NotImplementedError('Inherit the ImageDataset class and '
-                                  'implement the method.')
+        # number of (tiles x channels x height x width) patches after each
+        # scene is decomposed to tiles blocks
+        return len(self.scenes)

-    # this function should return a single sample of the dataset given an
+    # the __getitem__() method returns a single sample of the dataset given an
     # index, i.e. an array/tensor of shape (channels x height x width)
-    # it has to be implemented by a class inheriting from the
-    # ImageDataset class
     def __getitem__(self, idx):
+        # select a scene
+        scene = self.read_scene(idx)
+
+        # get samples: (tiles x channels x height x width)
+        data, gt = self.build_samples(scene)
+
+        # convert to torch tensors
+        x = torch.tensor(data, dtype=torch.float32)
+        y = torch.tensor(gt, dtype=torch.uint8)
+
+        return x, y
+
+    # the compose_scenes() method has to be implemented by the class inheriting
+    # the ImageDataset class
+    # compose_scenes() should return a list of dictionaries, where each
+    # dictionary represents one sample of the dataset, a scene or a tile
+    # of a scene, etc.
+    # the dictionaries should have the following (key, value) pairs:
+    #     - (band_1, path_to_band_1.tif)
+    #     - (band_2, path_to_band_2.tif)
+    #     - ...
+    #     - (band_n, path_to_band_n.tif)
+    #     - (gt, path_to_ground_truth.tif)
+    #     - (tile, None or int)
+    def compose_scenes(self, *args, **kwargs):
+        raise NotImplementedError('Inherit the ImageDataset class and '
+                                  'implement the method.')
+
+    # the get_size() method has to be implemented by the class inheriting
+    # the ImageDataset class
+    # get_size() should return the image size as tuple, (height, width)
+    def get_size(self, *args, **kwargs):
+        raise NotImplementedError('Inherit the ImageDataset class and '
+                                  'implement the method.')
+
+    # the get_bands() method has to be implemented by the class inheriting
+    # the ImageDataset class
+    # get_bands() should return a dictionary with the following
+    # (key: int, value: str) pairs:
+    #     - (1, band_1_name)
+    #     - (2, band_2_name)
+    #     - ...
+    #     - (n, band_n_name)
+    def get_bands(self, *args, **kwargs):
+        raise NotImplementedError('Inherit the ImageDataset class and '
+                                  'implement the method.')
+
+    # the get_labels() method has to be implemented by the class inheriting
+    # the ImageDataset class
+    # get_labels() should return a dictionary with the following
+    # (key: int, value: str) pairs:
+    #     - (0, label_1_name)
+    #     - (1, label_2_name)
+    #     - ...
+    #     - (n, label_n_name)
+    # where the keys should be the values representing the values of the
+    # corresponding label in the ground truth mask
+    # the labels in the dictionary determine the classes to be segmented
+    def get_labels(self, *args, **kwargs):
+        raise NotImplementedError('Inherit the ImageDataset class and '
+                                  'implement the method.')
+
+    # _read_scene() reads all the bands and the ground truth mask in a
+    # scene/tile to a numpy array and returns a dictionary with
+    # (key, value) = ('band_name', np.ndarray(band_data))
+    def read_scene(self, idx):
+        # select a scene from the root directory
+        scene = self.scenes[idx]
+
+        # read each band of the scene into a numpy array
+        scene_data = {key: (self.img2np(value, tile_size=self.tile_size,
+                                        tile=scene['tile'])
+                            if key != 'tile' else value)
+                      for key, value in scene.items()}
+
+        return scene_data
+
+    # _build_samples() stacks all bands of a scene/tile into a
+    # numpy array of shape (bands x height x width)
+    def build_samples(self, scene):
+        # iterate over the channels to stack
+        stack = np.stack([scene[band] for band in self.use_bands], axis=0)
+        gt = scene['gt']
+        return stack, gt
+
+    # the following functions are utility functions for common image
+    # manipulation operations

...
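The extended ImageDataset now acts as a template: __len__() and __getitem__() are concrete and delegate to the hooks get_size(), get_bands(), get_labels() and compose_scenes(), which a subclass has to fill in. The following is a minimal sketch of that contract, not part of the commit; ToyDataset, its band names and file names are hypothetical and only illustrate the (key, value) structures documented in the comments above (note that the concrete subclasses below return {'label', 'color'} dictionaries from get_labels()).

# a minimal sketch of the subclass contract, assuming a hypothetical dataset
# with two spectral bands and a binary ground truth mask
class ToyDataset(ImageDataset):

    # (height, width) of a single scene
    def get_size(self):
        return (256, 256)

    # band number -> band name
    def get_bands(self):
        return {1: 'red', 2: 'nir'}

    # ground truth value -> class label and plot color
    def get_labels(self):
        return {0: {'label': 'Background', 'color': 'black'},
                1: {'label': 'Target', 'color': 'white'}}

    # one dictionary per sample: band name -> path to the band's TIFF file,
    # plus the ground truth path and an optional tile id (file names made up)
    def compose_scenes(self, scene):
        return [{'red': os.path.join(scene, 'scene_01_red.tif'),
                 'nir': os.path.join(scene, 'scene_01_nir.tif'),
                 'gt': os.path.join(scene, 'scene_01_gt.tif'),
                 'tile': None}]

Reading and stacking the files is then handled entirely by the generic read_scene() and build_samples() methods above.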
@@ -194,117 +302,103 @@ class ImageDataset(Dataset):

         return norm

+    # plot_sample() plots a false color composite of the scene/tile together
+    # with the model prediction and the corresponding ground truth
+    def plot_sample(self, x, y, y_pred=None, figsize=(10, 10),
+                    bands=['red', 'green', 'blue'], stretch=False, **kwargs):
+        # check whether to apply contrast stretching
+        if kwargs:
+            stretch = True
+        func = self.contrast_stretching if stretch else lambda x: x
+
+        # create an rgb stack
+        rgb = np.dstack([func(x[self.use_bands.index(band)], **kwargs)
+                         for band in bands])
+
+        # get labels and corresponding colors
+        labels = [label['label'] for label in self.labels.values()]
+        colors = [label['color'] for label in self.labels.values()]
+
+        # create a ListedColormap
+        cmap = ListedColormap(colors)
+        boundaries = [*self.labels.keys(), cmap.N]
+        norm = BoundaryNorm(boundaries, cmap.N)
+
+        # create figure: check whether to plot model prediction
+        if y_pred is not None:
+            fig, ax = plt.subplots(1, 3, figsize=figsize)
+            ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm)
+            ax[2].set_title('Prediction', pad=20)
+        else:
+            fig, ax = plt.subplots(1, 2, figsize=figsize)
+
+        # plot false color composite
+        ax[0].imshow(rgb)
+        ax[0].set_title('R = {}, G = {}, B = {}'.format(*bands), pad=20)
+
+        # plot ground truth mask
+        ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm)
+        ax[1].set_title('Ground truth', pad=20)
+
+        # create a patch (proxy artist) for every color
+        patches = [mpatches.Patch(color=c, label=l)
+                   for c, l in zip(colors, labels)]
+
+        # plot patches as legend
+        plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
+                   frameon=False)
+
+        return fig, ax
+

 # SparcsDataset class: inherits from the generic ImageDataset class
 class SparcsDataset(ImageDataset):

-    def __init__(self, root_dir, bands=['red', 'green', 'blue'],
-                 tile_size=None):
-        super().__init__(root_dir)
-
-        # Landsat 8 bands in the SPARCS dataset
-        self.sparcs_bands = {1: 'violet', 2: 'blue', 3: 'green', 4: 'red',
-                             5: 'nir', 6: 'swir1', 7: 'swir2', 8: 'pan',
-                             9: 'cirrus', 10: 'tir'}
-
-        # class labels and corresponding color map
-        self.labels = {0: 'Shadow', 1: 'Shadow over Water', 2: 'Water',
-                       3: 'Snow', 4: 'Land', 5: 'Cloud', 6: 'Flooded'}
-        self.colors = {0: 'black', 1: 'darkblue', 2: 'blue', 3: 'lightblue',
-                       4: 'grey', 5: 'white', 6: 'yellow'}
-
-        # image size of the SPARCS dataset: height x width
-        self.size = (1000, 1000)
-
-        # check which bands to use
-        if bands == -1:
-            # in case bands=-1, use all bands of the sparcs dataset
-            self.bands = [*self.sparcs_bands.values()]
-        else:
-            self.bands = bands
-
-        # each scene is divided into (tile_size x tile_size) blocks
-        # each of these blocks is treated as a single sample
-        self.tile_size = tile_size
-
-        # calculate number of resulting tiles and check whether the images are
-        # evenly divisible in square tiles of size (tile_size x tile_size)
-        if self.tile_size is None:
-            self.tiles = None
-        else:
-            self.tiles = self.is_divisible(self.size, self.tile_size)
-
-        # list of all scenes in the root directory
-        # each scene is divided into tiles blocks
-        self.scenes = []
-        for scene in os.listdir(root_dir):
-            self.scenes += self._compose_scenes(os.path.join(root_dir, scene))
-
-    # the __len__() method returns the number of samples in the Sparcs dataset
-    def __len__(self):
-        # number of (tiles x channels x height x width) patches after each
-        # scene is decomposed to tiles blocks
-        return len(self.scenes)
-
-    # the __getitem__() method returns a sample of the Sparcs dataset
-    # __getitem__() is implicitly used by pytorch to draw samples during
-    # the training process
-    def __getitem__(self, idx):
-        # select a scene
-        scene = self._read_scene(idx)
-
-        # get samples: (tiles x channels x height x width)
-        data, gt = self._build_samples(scene)
-
-        # convert to torch tensors
-        x = torch.tensor(data, dtype=torch.float32)
-        y = torch.tensor(gt, dtype=torch.uint8)
-
-        return x, y
-
-    # returns the band number of the preprocessed Sparcs Tiff files
-    def _get_band_number(self, x):
-        return int(os.path.basename(x).split('_')[2].replace('B', ''))
-
-    # _store_bands() writes the paths to the data of each scene to a dictionary
-    # only the bands of interest are stored
-    def _store_bands(self, bands, gt):
-        # store the bands of interest in a dictionary
-        scene_data = {}
-        for i, b in enumerate(bands):
-            band = self.sparcs_bands[self._get_band_number(b)]
-            if band in self.bands:
-                scene_data[band] = b
-
-        # store ground truth
-        scene_data['gt'] = gt
-
-        return scene_data
-
+    def __init__(self, root_dir, use_bands=['red', 'green', 'blue'],
+                 tile_size=None):
+        # initialize super class ImageDataset
+        super().__init__(root_dir, use_bands, tile_size)
+
+        # list of all scenes in the root directory
+        # each scene is divided into tiles blocks
+        self.scenes = []
+        for scene in os.listdir(self.root):
+            self.scenes += self.compose_scenes(os.path.join(self.root, scene))
+
+    # image size of the Sparcs dataset: (height, width)
+    def get_size(self):
+        return (1000, 1000)
+
+    # Landsat 8 bands of the Sparcs dataset
+    def get_bands(self):
+        return {1: 'violet', 2: 'blue', 3: 'green', 4: 'red', 5: 'nir',
+                6: 'swir1', 7: 'swir2', 8: 'pan', 9: 'cirrus', 10: 'tir'}
+
+    # class labels of the Sparcs dataset
+    def get_labels(self):
+        labels = ['Shadow', 'Shadow over Water', 'Water', 'Snow', 'Land',
+                  'Cloud', 'Flooded']
+        colors = ['black', 'darkblue', 'blue', 'lightblue', 'grey', 'white',
+                  'yellow']
+        lc = {}
+        for i, (l, c) in enumerate(zip(labels, colors)):
+            lc[i] = {'label': l, 'color': c}
+        return lc
+
     # _compose_scenes() creates a list of dictionaries containing the paths
     # to the files of each scene
     # if the scenes are divided into tiles, each tile has its own entry
     # with corresponding tile id
-    def _compose_scenes(self, scene):
+    def compose_scenes(self, scene):
         # list the spectral bands of the scene
         bands = glob.glob(os.path.join(scene, '*B*.tif'))

...
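Because __len__() and __getitem__() now live in the base class, the refactored SparcsDataset plugs directly into a torch DataLoader. The following is a usage sketch, not part of the commit; the path is a placeholder, and the exact tensor shapes depend on tile_size (see the (tiles x channels x height x width) comment in __getitem__ above).

# usage sketch (not part of the commit): wrap SparcsDataset in a DataLoader
from torch.utils.data import DataLoader

sparcs_path = '/path/to/Sparcs'  # placeholder path
sparcs = SparcsDataset(sparcs_path, use_bands=['nir', 'red', 'green'],
                       tile_size=125)

# each item is the (x, y) pair returned by ImageDataset.__getitem__()
loader = DataLoader(sparcs, batch_size=8, shuffle=True)
x, y = next(iter(loader))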
@@ -346,84 +440,96 @@ class SparcsDataset(ImageDataset):

         return scene_data

-    # _read_scene() reads all the bands and the ground truth mask in a
-    # scene/tile to a numpy array and returns a dictionary with
-    # (key, value) = ('band_name', np.ndarray(band_data))
-    def _read_scene(self, idx):
-        # select a scene from the root directory
-        scene = self.scenes[idx]
-
-        # read each band of the scene into a numpy array
-        scene_data = {key: (self.img2np(value, tile_size=self.tile_size,
-                                        tile=scene['tile'])
-                            if key != 'tile' else value)
-                      for key, value in scene.items()}
-
-        return scene_data
-
-    # _build_samples() stacks all bands of a scene/tile into a
-    # numpy array of shape (bands x height x width)
-    def _build_samples(self, scene):
-        # iterate over the channels to stack
-        stack = np.stack([scene[band] for band in self.bands], axis=0)
-        gt = scene['gt']
-        return stack, gt
-
-    # plot_sample() plots a false color composite of the scene/tile together
-    # with the model prediction and the corresponding ground truth
-    def plot_sample(self, x, y, y_pred=None, figsize=(10, 10),
-                    bands=['nir', 'red', 'green'], stretch=False, **kwargs):
-        # check whether to apply contrast stretching
-        func = self.contrast_stretching if stretch else lambda x: x
-
-        # create an rgb stack
-        rgb = np.dstack([func(x[self.bands.index(band)], **kwargs)
-                         for band in bands])
-
-        # create a ListedColormap
-        cmap = ListedColormap(self.colors.values())
-        boundaries = [*self.colors.keys(), cmap.N]
-        norm = BoundaryNorm(boundaries, cmap.N)
-
-        # create figure: check whether to plot model prediction
-        if y_pred is not None:
-            fig, ax = plt.subplots(1, 3, figsize=figsize)
-            ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm)
-            ax[2].set_title('Prediction', pad=20)
-        else:
-            fig, ax = plt.subplots(1, 2, figsize=figsize)
-
-        # plot false color composite
-        ax[0].imshow(rgb)
-        ax[0].set_title('R = {}, G = {}, B = {}'.format(*bands), pad=20)
-
-        # plot ground truth mask
-        ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm)
-        ax[1].set_title('Ground truth', pad=20)
-
-        # create a patch (proxy artist) for every color
-        patches = [mpatches.Patch(color=c, label=l) for c, l in
-                   zip(self.colors.values(), self.labels.values())]
-
-        # plot patches as legend
-        plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
-                   frameon=False)
-
-        return fig, ax
-
+    # returns the band number of the preprocessed Sparcs Tiff files
+    def _get_band_number(self, x):
+        return int(os.path.basename(x).split('_')[2].replace('B', ''))
+
+    # _store_bands() writes the paths to the data of each scene to a dictionary
+    # only the bands of interest are stored
+    def _store_bands(self, bands, gt):
+        # store the bands of interest in a dictionary
+        scene_data = {}
+        for i, b in enumerate(bands):
+            band = self.bands[self._get_band_number(b)]
+            if band in self.use_bands:
+                scene_data[band] = b
+
+        # store ground truth
+        scene_data['gt'] = gt
+
+        return scene_data
+
+
+class Cloud95(ImageDataset):
+
+    def __init__(self, root_dir, use_bands=[], tile_size=None):
+        # initialize super class ImageDataset
+        super().__init__(root_dir, use_bands, tile_size)
+
+        # list of all scenes in the root directory
+        # each scene is divided into tiles blocks
+        self.scenes = self.compose_scenes(self.root)
+
+    # image size of the Cloud-95 dataset: (height, width)
+    def get_size(self):
+        return (384, 384)
+
+    # Landsat 8 bands in the Cloud-95 dataset
+    def get_bands(self):
+        return {1: 'red', 2: 'green', 3: 'blue', 4: 'nir'}
+
+    # class labels of the Cloud-95 dataset
+    def get_labels(self):
+        return {0: {'label': 'Clear', 'color': 'azure'},
+                1: {'label': 'Cloud', 'color': 'white'}}
+
+    def compose_scenes(self, root_dir):
+        # get the names of the directories containing the TIFF files of
+        # the bands of interest
+        band_dirs = {}
+        for dirpath, dirname, files in os.walk(root_dir):
+            # check if the current directory path includes the name of a band
+            # or the name of the ground truth mask
+            cband = [band for band in self.bands + ['gt'] if band in dirpath
+                     and os.path.isdir(dirpath)]
+
+            # add path to current band files to dictionary
+            if cband:
+                band_dirs[cband] = dirpath
+
+        # create empty list to store all patches to
+        scenes = []
+
+        # iterate over all the patches of the following band
+        biter = self.bands[0]
+        for file in os.listdir(band_dirs[biter]):
+
+            # initialize dictionary to store bands of current patch
+            scene = {}
+
+            # iterate over the bands of interest
+            for band in band_dirs.keys():
+                # save path to current band TIFF file to dictionary
+                scene[band] = os.path.join(band_dirs[band],
+                                           file.replace(biter, band))
+
+            # append patch to list of all patches
+            scenes.append(scene)
+
+        return scenes
+

 if __name__ == '__main__':

     # path to the preprocessed sparcs dataset
-    sparcs_path = "C:/Eurac/2020/Tutorial/Datasets/Sparcs"
+    sparcs_path = "C:/Eurac/2020/_Datasets/Sparcs"
     # sparcs_path = "/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/Datasets/Sparcs"

     # instantiate the SparcsDataset class
-    sparcs_dataset = SparcsDataset(sparcs_path, tile_size=None, bands=-1)
+    sparcs_dataset = SparcsDataset(sparcs_path, tile_size=125,
+                                   use_bands=['nir', 'red', 'green'])

     # randomly sample an integer from [0, nsamples]
     sample = np.random.randint(len(sparcs_dataset), size=1).item()

...
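The new Cloud95 class is exercised the same way as the SparcsDataset in the __main__ block above. The following sketch is not part of the commit; the path is a placeholder and the band names are the ones its get_bands() defines.

# sketch mirroring the __main__ example above (placeholder path)
cloud_path = "/path/to/Cloud95"

# instantiate the Cloud95 dataset with the four Landsat 8 bands it defines
cloud_dataset = Cloud95(cloud_path, use_bands=['red', 'green', 'blue', 'nir'],
                        tile_size=None)

# randomly sample an integer from [0, len(cloud_dataset))
sample = np.random.randint(len(cloud_dataset), size=1).item()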