Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
PySegCNN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
earth_observation_public
PySegCNN
Commits
2037ea8f
Commit
2037ea8f
authored
4 years ago
by
Frisinghelli Daniel
Browse files
Options
Downloads
Patches
Plain Diff
Moved random seed parameter to split configurations.
parent
ad3d88eb
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
pysegcnn/core/dataset.py
+20
-17
20 additions, 17 deletions
pysegcnn/core/dataset.py
with
20 additions
and
17 deletions
pysegcnn/core/dataset.py
+
20
−
17
View file @
2037ea8f
...
@@ -67,8 +67,6 @@ class ImageDataset(Dataset):
...
@@ -67,8 +67,6 @@ class ImageDataset(Dataset):
A regural expression to match the ground truth naming convention.
A regural expression to match the ground truth naming convention.
sort : `bool`
sort : `bool`
Whether to chronologically sort the samples.
Whether to chronologically sort the samples.
seed : `int`
The random seed.
transforms : `list`
transforms : `list`
List of :py:class:`pysegcnn.core.transforms.Augment` instances.
List of :py:class:`pysegcnn.core.transforms.Augment` instances.
merge_labels : `dict` [`str`, `str`]
merge_labels : `dict` [`str`, `str`]
...
@@ -113,7 +111,7 @@ class ImageDataset(Dataset):
...
@@ -113,7 +111,7 @@ class ImageDataset(Dataset):
"""
"""
def
__init__
(
self
,
root_dir
,
use_bands
=
[],
tile_size
=
None
,
pad
=
False
,
def
__init__
(
self
,
root_dir
,
use_bands
=
[],
tile_size
=
None
,
pad
=
False
,
gt_pattern
=
'
(.*)gt
\\
.tif
'
,
sort
=
False
,
seed
=
0
,
transforms
=
[],
gt_pattern
=
'
(.*)gt
\\
.tif
'
,
sort
=
False
,
transforms
=
[],
merge_labels
=
{}):
merge_labels
=
{}):
r
"""
Initialize.
r
"""
Initialize.
...
@@ -140,10 +138,6 @@ class ImageDataset(Dataset):
...
@@ -140,10 +138,6 @@ class ImageDataset(Dataset):
sort : `bool`, optional
sort : `bool`, optional
Whether to chronologically sort the samples. Useful for time series
Whether to chronologically sort the samples. Useful for time series
data. The default is `False`.
data. The default is `False`.
seed : `int`, optional
The random seed. Used to split the dataset into training,
validation and test set. Useful for reproducibility. The default is
`0`.
transforms : `list`, optional
transforms : `list`, optional
List of :py:class:`pysegcnn.core.transforms.Augment` instances.
List of :py:class:`pysegcnn.core.transforms.Augment` instances.
Each item in ``transforms`` generates a distinct transformed
Each item in ``transforms`` generates a distinct transformed
...
@@ -167,7 +161,6 @@ class ImageDataset(Dataset):
...
@@ -167,7 +161,6 @@ class ImageDataset(Dataset):
self
.
pad
=
pad
self
.
pad
=
pad
self
.
gt_pattern
=
gt_pattern
self
.
gt_pattern
=
gt_pattern
self
.
sort
=
sort
self
.
sort
=
sort
self
.
seed
=
seed
self
.
transforms
=
transforms
self
.
transforms
=
transforms
self
.
merge_labels
=
merge_labels
self
.
merge_labels
=
merge_labels
...
@@ -711,11 +704,11 @@ class StandardEoDataset(ImageDataset):
...
@@ -711,11 +704,11 @@ class StandardEoDataset(ImageDataset):
"""
"""
def
__init__
(
self
,
root_dir
,
use_bands
=
[],
tile_size
=
None
,
pad
=
False
,
def
__init__
(
self
,
root_dir
,
use_bands
=
[],
tile_size
=
None
,
pad
=
False
,
gt_pattern
=
'
(.*)gt
\\
.tif
'
,
sort
=
False
,
seed
=
0
,
transforms
=
[],
gt_pattern
=
'
(.*)gt
\\
.tif
'
,
sort
=
False
,
transforms
=
[],
merge_labels
=
{}):
merge_labels
=
{}):
# initialize super class ImageDataset
# initialize super class ImageDataset
super
().
__init__
(
root_dir
,
use_bands
,
tile_size
,
pad
,
gt_pattern
,
super
().
__init__
(
root_dir
,
use_bands
,
tile_size
,
pad
,
gt_pattern
,
sort
,
seed
,
transforms
,
merge_labels
)
sort
,
transforms
,
merge_labels
)
def
_get_band_number
(
self
,
path
):
def
_get_band_number
(
self
,
path
):
"""
Return the band number of a scene .tif file.
"""
Return the band number of a scene .tif file.
...
@@ -804,8 +797,12 @@ class StandardEoDataset(ImageDataset):
...
@@ -804,8 +797,12 @@ class StandardEoDataset(ImageDataset):
def
compose_scenes
(
self
):
def
compose_scenes
(
self
):
"""
Build the list of samples of the dataset.
"""
"""
Build the list of samples of the dataset.
"""
# search the root directory
# initialize scene list and counter
scenes
=
[]
scenes
=
[]
nscenes
=
0
# search the root directory
for
dirpath
,
dirname
,
files
in
os
.
walk
(
self
.
root
):
for
dirpath
,
dirname
,
files
in
os
.
walk
(
self
.
root
):
# search for a ground truth in the current directory
# search for a ground truth in the current directory
...
@@ -855,9 +852,15 @@ class StandardEoDataset(ImageDataset):
...
@@ -855,9 +852,15 @@ class StandardEoDataset(ImageDataset):
# store optional transformation
# store optional transformation
data
[
'
transform
'
]
=
transf
data
[
'
transform
'
]
=
transf
# store scene counter
data
[
'
scene
'
]
=
nscenes
# append to list
# append to list
scenes
.
append
(
data
)
scenes
.
append
(
data
)
# advance scene counter
nscenes
+=
1
# sort list of scenes and ground truths in chronological order
# sort list of scenes and ground truths in chronological order
if
self
.
sort
:
if
self
.
sort
:
scenes
.
sort
(
key
=
lambda
k
:
k
[
'
date
'
])
scenes
.
sort
(
key
=
lambda
k
:
k
[
'
date
'
])
...
@@ -878,11 +881,11 @@ class SparcsDataset(StandardEoDataset):
...
@@ -878,11 +881,11 @@ class SparcsDataset(StandardEoDataset):
"""
"""
def
__init__
(
self
,
root_dir
,
use_bands
=
[],
tile_size
=
None
,
pad
=
False
,
def
__init__
(
self
,
root_dir
,
use_bands
=
[],
tile_size
=
None
,
pad
=
False
,
gt_pattern
=
'
(.*)gt
\\
.tif
'
,
sort
=
False
,
seed
=
0
,
transforms
=
[],
gt_pattern
=
'
(.*)gt
\\
.tif
'
,
sort
=
False
,
transforms
=
[],
merge_labels
=
{}):
merge_labels
=
{}):
# initialize super class StandardEoDataset
# initialize super class StandardEoDataset
super
().
__init__
(
root_dir
,
use_bands
,
tile_size
,
pad
,
gt_pattern
,
super
().
__init__
(
root_dir
,
use_bands
,
tile_size
,
pad
,
gt_pattern
,
sort
,
seed
,
transforms
,
merge_labels
)
sort
,
transforms
,
merge_labels
)
@staticmethod
@staticmethod
def
get_size
():
def
get_size
():
...
@@ -951,11 +954,11 @@ class AlcdDataset(StandardEoDataset):
...
@@ -951,11 +954,11 @@ class AlcdDataset(StandardEoDataset):
"""
"""
def
__init__
(
self
,
root_dir
,
use_bands
=
[],
tile_size
=
None
,
pad
=
False
,
def
__init__
(
self
,
root_dir
,
use_bands
=
[],
tile_size
=
None
,
pad
=
False
,
gt_pattern
=
'
(.*)gt
\\
.tif
'
,
sort
=
False
,
seed
=
0
,
transforms
=
[],
gt_pattern
=
'
(.*)gt
\\
.tif
'
,
sort
=
False
,
transforms
=
[],
merge_labels
=
{}):
merge_labels
=
{}):
# initialize super class StandardEoDataset
# initialize super class StandardEoDataset
super
().
__init__
(
root_dir
,
use_bands
,
tile_size
,
pad
,
gt_pattern
,
super
().
__init__
(
root_dir
,
use_bands
,
tile_size
,
pad
,
gt_pattern
,
sort
,
seed
,
transforms
,
merge_labels
)
sort
,
transforms
,
merge_labels
)
@staticmethod
@staticmethod
def
get_size
():
def
get_size
():
...
@@ -1024,11 +1027,11 @@ class Cloud95Dataset(ImageDataset):
...
@@ -1024,11 +1027,11 @@ class Cloud95Dataset(ImageDataset):
"""
"""
def
__init__
(
self
,
root_dir
,
use_bands
=
[],
tile_size
=
None
,
pad
=
False
,
def
__init__
(
self
,
root_dir
,
use_bands
=
[],
tile_size
=
None
,
pad
=
False
,
gt_pattern
=
'
(.*)gt
\\
.tif
'
,
sort
=
False
,
seed
=
0
,
transforms
=
[],
gt_pattern
=
'
(.*)gt
\\
.tif
'
,
sort
=
False
,
transforms
=
[],
merge_labels
=
{}):
merge_labels
=
{}):
# initialize super class StandardEoDataset
# initialize super class StandardEoDataset
super
().
__init__
(
root_dir
,
use_bands
,
tile_size
,
pad
,
gt_pattern
,
super
().
__init__
(
root_dir
,
use_bands
,
tile_size
,
pad
,
gt_pattern
,
sort
,
seed
,
transforms
,
merge_labels
)
sort
,
transforms
,
merge_labels
)
# the csv file containing the names of the informative patches
# the csv file containing the names of the informative patches
# patches with more than 80% black pixels, i.e. patches resulting from
# patches with more than 80% black pixels, i.e. patches resulting from
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment