Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
PySegCNN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
earth_observation_public
PySegCNN
Commits
9374ee3e
Commit
9374ee3e
authored
4 years ago
by
Frisinghelli Daniel
Browse files
Options
Downloads
Patches
Plain Diff
Improved train/valid/test split workflow
parent
d9a3b557
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
pysegcnn/core/split.py
+94
-63
94 additions, 63 deletions
pysegcnn/core/split.py
with
94 additions
and
63 deletions
pysegcnn/core/split.py
+
94
−
63
View file @
9374ee3e
...
...
@@ -14,6 +14,9 @@ from torch.utils.data.dataset import Subset
# the names of the subsets
SUBSET_NAMES
=
[
'
train
'
,
'
valid
'
,
'
test
'
]
# valid split modes
VALID_SPLIT_MODES
=
[
'
random
'
,
'
scene
'
,
'
date
'
]
# function calculating number of samples in a dataset given a ratio
def
_ds_len
(
ds
,
ratio
):
...
...
@@ -129,7 +132,6 @@ def date_scene_split(ds, date, dateformat='%Y%m%d'):
# store the indices and corresponding tiles of the current subset to
# dictionary
subsets
[
SUBSET_NAMES
[
name
]]
=
scenes
# sbst.ids = np.unique([s['id'] for s in scenes.values()])
# check if the splits are disjoint
assert
pairwise_disjoint
([
s
.
keys
()
for
s
in
subsets
.
values
()])
...
...
@@ -143,101 +145,130 @@ def pairwise_disjoint(sets):
return
n
==
len
(
union
)
class
S
plit
(
objec
t
):
class
S
ceneSubset
(
Subse
t
):
# the valid modes
valid_modes
=
[
'
random
'
,
'
scene
'
,
'
date
'
]
def
__init__
(
self
,
ds
,
indices
,
name
,
scenes
,
scene_ids
):
super
().
__init__
(
dataset
=
ds
,
indices
=
indices
)
def
__init__
(
self
,
ds
,
mode
,
**
kwargs
):
# the name of the subset
self
.
name
=
name
# check which mode is provided
if
mode
not
in
self
.
valid_modes
:
raise
ValueError
(
'
{} is not supported. Valid modes are {}, see
'
'
pysegcnn.main.config.py for a description of
'
'
each mode.
'
.
format
(
mode
,
self
.
valid_modes
))
self
.
mode
=
mode
# the scene in the subset
self
.
scenes
=
scenes
# the
dataset to split
self
.
ds
=
ds
# the
names of the scenes
self
.
i
ds
=
scene_i
ds
# the keyword arguments
self
.
kwargs
=
kwargs
# initialize split
self
.
_init_split
()
class
RandomSubset
(
Subset
):
def
_init_split
(
self
):
def
__init__
(
self
,
ds
,
indices
,
name
,
scenes
,
scene_ids
):
super
().
__init__
(
dataset
=
ds
,
indices
=
indices
)
if
self
.
mode
==
'
random
'
:
self
.
subset
=
RandomSubset
self
.
split_function
=
random_tile_split
self
.
allowed_kwargs
=
[
'
tvratio
'
,
'
ttratio
'
,
'
seed
'
]
# the name of the subset
self
.
name
=
name
if
self
.
mode
==
'
scene
'
:
self
.
subset
=
SceneSubset
self
.
split_function
=
random_scene_split
self
.
allowed_kwargs
=
[
'
tvratio
'
,
'
ttratio
'
,
'
seed
'
]
# the scene in the subset
self
.
scenes
=
scenes
if
self
.
mode
==
'
date
'
:
self
.
subset
=
SceneSubset
self
.
split_function
=
date_scene_split
self
.
allowed_kwargs
=
[
'
date
'
,
'
dateformat
'
]
self
.
_check_kwargs
()
class
Split
(
object
):
def
_
check_kwargs
(
self
):
def
_
_init__
(
self
,
ds
):
# check if the correct keyword arguments are provided
if
not
set
(
self
.
allowed_kwargs
).
issubset
(
self
.
kwargs
.
keys
()):
raise
TypeError
(
'
__init__() expecting keyword arguments: {}.
'
.
format
(
'
,
'
.
join
(
kwa
for
kwa
in
self
.
allowed_kwargs
)))
# select the correct keyword arguments
self
.
kwargs
=
{
k
:
self
.
kwargs
[
k
]
for
k
in
self
.
allowed_kwargs
}
# the dataset to split
self
.
ds
=
ds
# function apply the split
def
split
(
self
):
# create the subsets
subsets
=
self
.
split_function
(
self
.
ds
,
**
self
.
kwargs
)
# build the subsets
ds_split
=
[]
for
name
,
sub
in
subsets
.
items
():
for
name
,
sub
in
self
.
subsets
()
.
items
():
# the scene identifiers of the current subset
ids
=
np
.
unique
([
s
[
'
id
'
]
for
s
in
sub
.
values
()])
# build the subset
s
u
bs
e
t
=
self
.
subset
(
self
.
ds
,
list
(
sub
.
keys
()),
name
,
list
(
sub
.
values
()),
ids
)
ds_split
.
append
(
s
u
bs
e
t
)
sbst
=
self
.
subset
_type
()
(
self
.
ds
,
list
(
sub
.
keys
()),
name
,
list
(
sub
.
values
()),
ids
)
ds_split
.
append
(
sbst
)
return
ds_split
@property
def
subsets
(
self
):
raise
NotImplementedError
class
SceneSubset
(
Subset
):
def
subset_type
(
self
):
raise
NotImplementedError
def
__init__
(
self
,
ds
,
indices
,
name
,
scenes
,
scene_ids
):
super
().
__init__
(
dataset
=
ds
,
indices
=
indices
)
def
__repr__
(
self
):
#
the name of the subse
t
self
.
name
=
name
#
representation string to prin
t
fs
=
self
.
__class__
.
__name__
+
'
(
\n
'
# the scene in the subset
self
.
scenes
=
scenes
# dataset split
fs
+=
'
\n
'
.
join
(
'
- {}: {:d} batches ({:.2f}%)
'
.
format
(
k
,
len
(
v
),
len
(
v
)
*
100
/
len
(
self
.
ds
))
for
k
,
v
in
self
.
subsets
().
items
())
fs
+=
'
\n
)
'
return
fs
# the names of the scenes
self
.
ids
=
scene_ids
class
DateSplit
(
Split
):
def
__init__
(
self
,
ds
,
date
,
dateformat
):
super
().
__init__
(
ds
)
class
RandomSubset
(
Subset
):
# the date to split the dataset
# before: training set
# after : validation set
self
.
date
=
date
def
__init__
(
self
,
ds
,
indices
,
name
,
scenes
):
s
uper
().
__init__
(
dataset
=
ds
,
indices
=
indices
)
# the format of the date
s
elf
.
dateformat
=
dateformat
# the name of the subset
self
.
name
=
name
def
subsets
(
self
):
return
date_scene_split
(
self
.
ds
,
self
.
date
,
self
.
dateformat
)
# the scene in the subset
self
.
scenes
=
scenes
def
subset_type
(
self
):
return
SceneSubset
class
RandomSplit
(
Split
):
def
__init__
(
self
,
ds
,
ttratio
,
tvratio
,
seed
):
super
().
__init__
(
ds
)
# the training, validation and test set ratios
self
.
ttratio
=
ttratio
self
.
tvratio
=
tvratio
# the random seed: useful for reproducibility
self
.
seed
=
seed
class
RandomTileSplit
(
RandomSplit
):
def
__init__
(
self
,
ds
,
ttratio
,
tvratio
,
seed
):
super
().
__init__
(
ds
,
ttratio
,
tvratio
,
seed
)
def
subsets
(
self
):
return
random_tile_split
(
self
.
ds
,
self
.
tvratio
,
self
.
ttratio
,
self
.
seed
)
def
subset_type
(
self
):
return
RandomSubset
class
RandomSceneSplit
(
RandomSplit
):
def
__init__
(
self
,
ds
,
ttratio
,
tvratio
,
seed
):
super
().
__init__
(
ds
,
ttratio
,
tvratio
,
seed
)
def
subsets
(
self
):
return
random_scene_split
(
self
.
ds
,
self
.
tvratio
,
self
.
ttratio
,
self
.
seed
)
def
subset_type
(
self
):
return
SceneSubset
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment