Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
PySegCNN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
earth_observation_public
PySegCNN
Commits
5a99371d
Commit
5a99371d
authored
4 years ago
by
Frisinghelli Daniel
Browse files
Options
Downloads
Patches
Plain Diff
Created distinct class to split dataset into training, validation and test set
parent
6faa33d1
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
pysegcnn/core/split.py
+127
-41
127 additions, 41 deletions
pysegcnn/core/split.py
pysegcnn/core/trainer.py
+13
-21
13 additions, 21 deletions
pysegcnn/core/trainer.py
with
140 additions
and
62 deletions
pysegcnn/core/split.py
+
127
−
41
View file @
5a99371d
...
...
@@ -11,6 +11,9 @@ import datetime
import
numpy
as
np
from
torch.utils.data.dataset
import
Subset
# the names of the subsets
SUBSET_NAMES
=
[
'
train
'
,
'
valid
'
,
'
test
'
]
# function calculating number of samples in a dataset given a ratio
def
_ds_len
(
ds
,
ratio
):
...
...
@@ -37,29 +40,26 @@ def random_tile_split(ds, tvratio, ttratio=1, seed=0):
# length of the training dataset
# number of samples: (ttratio * tvratio * len(ds))
train_len
=
_ds_len
(
trav_indices
,
tvratio
)
train_ind
ices
=
trav_indices
[:
train_len
]
train_ind
=
trav_indices
[:
train_len
]
# length of the validation dataset
# number of samples: (ttratio * (1- tvratio) * len(ds))
valid_ind
ices
=
trav_indices
[
train_len
:]
valid_ind
=
trav_indices
[
train_len
:]
# length of the test dataset
# number of samples: ((1 - ttratio) * len(ds))
test_ind
ices
=
indices
[
trav_len
:]
test_ind
=
indices
[
trav_len
:]
# get the tiles of the scenes of each dataset
subsets
=
[]
for
dataset
in
[
train_indices
,
valid_indices
,
test_indices
]:
# build the subset: store the scenes
sbst
=
Subset
(
dataset
=
ds
,
indices
=
list
(
dataset
))
sbst
.
scenes
=
[
ds
.
scenes
[
i
]
for
i
in
dataset
]
subsets
=
{}
for
name
,
dataset
in
enumerate
([
train_ind
,
valid_ind
,
test_ind
]):
# add to list of subsets
subsets
.
append
(
sbst
)
# store the indices and corresponding tiles of the current subset to
# dictionary
subsets
[
SUBSET_NAMES
[
name
]]
=
{
k
:
ds
.
scenes
[
k
]
for
k
in
dataset
}
# check if the splits are disjoint
assert
pairwise_disjoint
([
s
.
indices
for
s
in
subsets
])
assert
pairwise_disjoint
([
s
.
keys
()
for
s
in
subsets
.
values
()
])
return
subsets
...
...
@@ -95,28 +95,16 @@ def random_scene_split(ds, tvratio, ttratio=1, seed=0):
test_scenes
=
scene_ids
[
trav_len
:]
# get the tiles of the scenes of each dataset
subsets
=
[]
for
dataset
in
[
train_scenes
,
valid_scenes
,
test_scenes
]:
# the indices of the scenes in the dataset
indices
=
[]
tiles
=
[]
# iterate over the scenes of the whole dataset
for
i
,
scene
in
enumerate
(
ds
.
scenes
):
if
scene
[
'
id
'
]
in
dataset
:
indices
.
append
(
i
)
tiles
.
append
(
scene
)
# build the subset: store scene ids
sbst
=
Subset
(
dataset
=
ds
,
indices
=
indices
)
sbst
.
scenes
=
tiles
sbst
.
ids
=
dataset
# add to list of subsets
subsets
.
append
(
sbst
)
subsets
=
{}
for
name
,
dataset
in
enumerate
([
train_scenes
,
valid_scenes
,
test_scenes
]):
# store the indices and corresponding tiles of the current subset to
# dictionary
subsets
[
SUBSET_NAMES
[
name
]]
=
{
k
:
v
for
k
,
v
in
enumerate
(
ds
.
scenes
)
if
v
[
'
id
'
]
in
dataset
}
# check if the splits are disjoint
assert
pairwise_disjoint
([
s
.
indices
for
s
in
subsets
])
assert
pairwise_disjoint
([
s
.
keys
()
for
s
in
subsets
.
values
()
])
return
subsets
...
...
@@ -135,18 +123,16 @@ def date_scene_split(ds, date, dateformat='%Y%m%d'):
test_scenes
=
{}
# build the training and test datasets
subsets
=
[]
for
scenes
in
[
train_scenes
,
valid_scenes
,
test_scenes
]:
# build the subset: store the scenes
sbst
=
Subset
(
dataset
=
ds
,
indices
=
list
(
scenes
.
keys
()))
sbst
.
scenes
=
list
(
scenes
.
values
())
sbst
.
ids
=
np
.
unique
([
s
[
'
id
'
]
for
s
in
scenes
.
values
()])
subsets
=
{}
for
name
,
scenes
in
enumerate
([
train_scenes
,
valid_scenes
,
test_scenes
]):
# add to list of subsets
subsets
.
append
(
sbst
)
# store the indices and corresponding tiles of the current subset to
# dictionary
subsets
[
SUBSET_NAMES
[
name
]]
=
scenes
# sbst.ids = np.unique([s['id'] for s in scenes.values()])
# check if the splits are disjoint
assert
pairwise_disjoint
([
s
.
indices
for
s
in
subsets
])
assert
pairwise_disjoint
([
s
.
keys
()
for
s
in
subsets
.
values
()
])
return
subsets
...
...
@@ -155,3 +141,103 @@ def pairwise_disjoint(sets):
union
=
set
().
union
(
*
sets
)
n
=
sum
(
len
(
u
)
for
u
in
sets
)
return
n
==
len
(
union
)
class Split(object):
    """Split a dataset into training, validation and test subsets.

    The ``mode`` selects both the function computing the split and the
    ``Subset`` class used to wrap each resulting partition. Only the keyword
    arguments required by the selected mode are kept; any additional keyword
    arguments passed to the constructor are silently dropped.
    """

    # the valid modes
    valid_modes = ['random', 'scene', 'date']

    def __init__(self, ds, mode, **kwargs):
        """Store the configuration and resolve the split strategy.

        Parameters
        ----------
        ds :
            The dataset to split. The split functions read ``ds.scenes``.
        mode : str
            One of ``Split.valid_modes``.
        **kwargs :
            Keyword arguments of the split function selected by ``mode``.

        Raises
        ------
        ValueError
            If ``mode`` is not one of ``Split.valid_modes``.
        TypeError
            If a keyword argument required by ``mode`` is missing.
        """
        # reject unknown modes before storing anything on the instance
        if mode not in self.valid_modes:
            raise ValueError('{} is not supported. Valid modes are {}, see '
                             'pysegcnn.main.config.py for a description of '
                             'each mode.'.format(mode, self.valid_modes))

        # the split mode, the dataset to split and the keyword arguments
        self.mode = mode
        self.ds = ds
        self.kwargs = kwargs

        # resolve the subset class, split function and required kwargs
        self._init_split()

    def _init_split(self):
        """Select subset class, split function and required kwargs."""
        if self.mode == 'random':
            self.subset = RandomSubset
            self.split_function = random_tile_split
            self.allowed_kwargs = ['tvratio', 'ttratio', 'seed']
        elif self.mode == 'scene':
            self.subset = SceneSubset
            self.split_function = random_scene_split
            self.allowed_kwargs = ['tvratio', 'ttratio', 'seed']
        elif self.mode == 'date':
            self.subset = SceneSubset
            self.split_function = date_scene_split
            self.allowed_kwargs = ['date', 'dateformat']

        # validate and filter the user-supplied keyword arguments
        self._check_kwargs()

    def _check_kwargs(self):
        """Require all kwargs of the mode and discard every other kwarg."""
        required = self.allowed_kwargs

        # every required keyword argument has to be provided
        if any(key not in self.kwargs for key in required):
            raise TypeError('__init__() expecting keyword arguments: {}.'
                            .format(', '.join(required)))

        # keep only the keyword arguments the split function accepts
        self.kwargs = {key: self.kwargs[key] for key in required}

    def split(self):
        """Apply the split and wrap each partition in the subset class.

        Returns
        -------
        list
            One subset instance per entry of the dictionary returned by the
            split function, in the iteration order of that dictionary.
        """
        # mapping: subset name -> {dataset index: scene/tile dictionary}
        partitions = self.split_function(self.ds, **self.kwargs)

        # each subset receives the dataset, its indices, its name, its scenes
        # and the unique scene identifiers found among its scenes
        return [
            self.subset(self.ds,
                        list(tiles),
                        label,
                        list(tiles.values()),
                        np.unique([tile['id'] for tile in tiles.values()]))
            for label, tiles in partitions.items()
        ]
class SceneSubset(Subset):
    """A ``Subset`` that also carries its name, scenes and scene ids."""

    def __init__(self, ds, indices, name, scenes, scene_ids):
        """Build the subset and attach the split metadata.

        Parameters
        ----------
        ds :
            The dataset the subset is drawn from.
        indices :
            The indices of ``ds`` belonging to this subset.
        name :
            The name of the subset.
        scenes :
            The scenes in the subset.
        scene_ids :
            The names of the scenes.
        """
        super().__init__(dataset=ds, indices=indices)

        # metadata describing the subset: its name, the scenes it contains
        # and the names of those scenes
        self.name, self.scenes, self.ids = name, scenes, scene_ids
class RandomSubset(Subset):
    """A ``Subset`` that also carries its name and its scenes.

    ``Split.split`` instantiates the configured subset class with five
    positional arguments ``(ds, indices, name, scenes, ids)`` for every split
    mode. The constructor therefore accepts an optional ``scene_ids``
    argument; without it, ``mode='random'`` fails with a ``TypeError``.
    """

    def __init__(self, ds, indices, name, scenes, scene_ids=None):
        """Build the subset and attach the split metadata.

        Parameters
        ----------
        ds :
            The dataset the subset is drawn from.
        indices :
            The indices of ``ds`` belonging to this subset.
        name :
            The name of the subset.
        scenes :
            The scenes in the subset.
        scene_ids : optional
            The names of the scenes. Defaults to ``None``, which keeps the
            previous four-argument call signature working.
        """
        super().__init__(dataset=ds, indices=indices)

        # the name of the subset
        self.name = name

        # the scene in the subset
        self.scenes = scenes

        # the names of the scenes (None when not provided)
        self.ids = scene_ids
This diff is collapsed.
Click to expand it.
pysegcnn/core/trainer.py
+
13
−
21
View file @
5a99371d
...
...
@@ -17,9 +17,8 @@ from torch.utils.data import DataLoader
# locals
from
pysegcnn.core.dataset
import
SupportedDatasets
from
pysegcnn.core.layers
import
Conv2dSame
from
pysegcnn.core.utils
import
img2np
from
pysegcnn.core.split
import
(
random_tile_split
,
random_scene_split
,
date_scene_split
)
from
pysegcnn.core.utils
import
img2np
,
accuracy_function
from
pysegcnn.core.split
import
Split
class
NetworkTrainer
(
object
):
...
...
@@ -244,7 +243,7 @@ class NetworkTrainer(object):
return
training_state
def
predict
(
self
,
pretrained
=
False
,
confusion
=
False
):
def
predict
(
self
):
print
(
'
------------------------ Predicting --------------------------
'
)
...
...
@@ -341,18 +340,16 @@ class NetworkTrainer(object):
'
\n
'
.
join
(
name
for
name
,
_
in
SupportedDatasets
.
__members__
.
items
()))
# the training, validation and dataset
if
self
.
split_mode
==
'
random
'
:
self
.
train_ds
,
self
.
valid_ds
,
self
.
test_ds
=
random_tile_split
(
self
.
dataset
,
self
.
tvratio
,
self
.
ttratio
,
self
.
seed
)
if
self
.
split_mode
==
'
scene
'
:
self
.
train_ds
,
self
.
valid_ds
,
self
.
test_ds
=
random_scene_split
(
self
.
dataset
,
self
.
tvratio
,
self
.
ttratio
,
self
.
seed
)
# instantiate the Split class handling the dataset split
self
.
subset
=
Split
(
self
.
dataset
,
self
.
split_mode
,
tvratio
=
self
.
tvratio
,
ttratio
=
self
.
ttratio
,
seed
=
self
.
seed
,
date
=
self
.
date
,
dateformat
=
self
.
dateformat
)
if
self
.
split_mode
==
'
date
'
:
self
.
train_ds
,
self
.
valid_ds
,
self
.
test_ds
=
date_scene_split
(
self
.
dataset
,
self
.
date
)
# the training, validation and test dataset
self
.
train_ds
,
self
.
valid_ds
,
self
.
test_ds
=
self
.
subset
.
split
()
# whether to drop training samples with a fraction of pixels equal to
# the constant padding value self.cval >= self.drop
...
...
@@ -409,7 +406,7 @@ class NetworkTrainer(object):
for
pos
,
i
in
enumerate
(
ds
.
indices
):
# the current scene
s
=
s
elf
.
dataset
.
scenes
[
i
]
s
=
d
s
.
dataset
.
scenes
[
i
]
# the current tile in the ground truth
tile_gt
=
img2np
(
s
[
'
gt
'
],
self
.
tile_size
,
s
[
'
tile
'
],
...
...
@@ -537,8 +534,3 @@ class EarlyStopping(object):
def
increased
(
self
,
metric
,
best
,
min_delta
):
return
metric
>
best
+
min_delta
def accuracy_function(outputs, labels):
    """Return the fraction of entries where outputs and labels are equal.

    Both arguments are converted with ``np.asarray`` and compared
    element-wise; the mean of the resulting boolean array is returned.
    """
    predicted = np.asarray(outputs)
    expected = np.asarray(labels)

    # element-wise agreement between prediction and ground truth
    matches = predicted == expected
    return matches.mean()
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment