PySegCNN · Commit f50e3a2b
Authored 4 years ago by Frisinghelli Daniel

    Increased modularity

Parent: f881ac7d
No related branches, tags, or merge requests found.

Showing 1 changed file: pysegcnn/core/trainer.py (+455 additions, −290 deletions)
# !/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 12 10:24:34 2020

@author: Daniel
"""
# builtins
import os
import dataclasses
import pathlib

# externals
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import Optimizer

# locals
from pysegcnn.core.dataset import SupportedDatasets, ImageDataset
from pysegcnn.core.transforms import Augment
from pysegcnn.core.utils import img2np, item_in_enum, accuracy_function
from pysegcnn.core.split import SupportedSplits
from pysegcnn.core.models import (SupportedModels, SupportedOptimizers,
                                  SupportedLossFunctions, Network)
from pysegcnn.core.layers import Conv2dSame
from pysegcnn.main.config import HERE
@dataclasses.dataclass
class BaseConfig:

    def __post_init__(self):
        # check input types
        for field in dataclasses.fields(self):

            # the value of the current field
            value = getattr(self, field.name)

            # check whether the value is of the correct type
            if not isinstance(value, field.type):
                # try to convert the value to the correct type
                try:
                    setattr(self, field.name, field.type(value))
                except TypeError:
                    # raise an exception if the conversion fails
                    raise TypeError('Expected {} to be {}, got {}.'
                                    .format(field.name, field.type,
                                            type(value)))
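
A minimal sketch of the type coercion above, assuming the refactored module is importable as pysegcnn.core.trainer; ToyConfig is a hypothetical subclass used only for illustration:

    import dataclasses
    import pathlib

    from pysegcnn.core.trainer import BaseConfig


    @dataclasses.dataclass
    class ToyConfig(BaseConfig):
        # the field annotations drive the checks in BaseConfig.__post_init__
        root_dir: pathlib.Path
        seed: int


    # the string path and the string seed are converted to pathlib.Path and
    # int; a value that cannot be converted raises TypeError instead
    c = ToyConfig(root_dir='/tmp/dataset', seed='42')
    print(type(c.root_dir), type(c.seed))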
@dataclasses.dataclass
class DatasetConfig(BaseConfig):

    dataset_name: str
    root_dir: pathlib.Path
    bands: list
    tile_size: int
    gt_pattern: str
    seed: int
    sort: bool = False
    transforms: list = dataclasses.field(default_factory=list)
    pad: bool = False
    cval: int = 99

    def __post_init__(self):
        # check input types
        super().__post_init__()

        # check whether the dataset is currently supported
        self.dataset_class = item_in_enum(self.dataset_name,
                                          SupportedDatasets)

        # check whether the root directory exists
        if not self.root_dir.exists():
            raise FileNotFoundError('{} does not exist.'
                                    .format(self.root_dir))

        # check whether the transformations inherit from the correct class
        if not all([isinstance(t, Augment) for t in self.transforms
                    if self.transforms]):
            raise TypeError('Each transformation is expected to be an instance '
                            'of {}.'.format('.'.join([Augment.__module__,
                                                      Augment.__name__])))

        # check whether the constant padding value is within the valid range
        if not 0 < self.cval < 255:
            raise ValueError('Expecting 0 <= cval <= 255, got cval={}.'
                             .format(self.cval))
    def init_dataset(self):

        # instantiate the dataset
        dataset = self.dataset_class(
                    root_dir=str(self.root_dir),
                    use_bands=self.bands,
                    tile_size=self.tile_size,
                    seed=self.seed,
                    sort=self.sort,
                    transforms=self.transforms,
                    pad=self.pad,
                    cval=self.cval,
                    gt_pattern=self.gt_pattern
                    )

        return dataset
@dataclasses.dataclass
class SplitConfig(BaseConfig):

    split_mode: str
    ttratio: float
    tvratio: float
    date: str = 'yyyymmdd'
    dateformat: str = '%Y%m%d'
    drop: float = 0

    def __post_init__(self):
        # check input types
        super().__post_init__()

        # check if the split mode is valid
        self.split_class = item_in_enum(self.split_mode, SupportedSplits)
    # function to drop samples with a fraction of pixels equal to the constant
    # padding value self.cval >= self.drop
    def _drop_samples(self, ds, drop_threshold=1):

        # iterate over the scenes returned by self.compose_scenes()
        dropped = []
        for pos, i in enumerate(ds.indices):

            # the current scene
            s = ds.dataset.scenes[i]

            # the current tile in the ground truth
            tile_gt = img2np(s['gt'], ds.dataset.tile_size, s['tile'],
                             ds.dataset.pad, ds.dataset.cval)

            # percent of pixels equal to the constant padding value
            npixels = (tile_gt[tile_gt == ds.dataset.cval].size /
                       tile_gt.size)

            # drop samples where npixels >= self.drop
            if npixels >= drop_threshold:
                print('Skipping scene {}, tile {}: {:.2f}% padded pixels ...'
                      .format(s['id'], s['tile'], npixels * 100))
                dropped.append(s)
                _ = ds.indices.pop(pos)

        return dropped
    def train_val_test_split(self, ds):

        if not isinstance(ds, ImageDataset):
            raise TypeError('Expected "ds" to be {}.'
                            .format('.'.join([ImageDataset.__module__,
                                              ImageDataset.__name__])))

        if self.split_mode == 'random' or self.split_mode == 'scene':
            subset = self.split_class(ds,
                                      self.ttratio,
                                      self.tvratio,
                                      ds.seed)
        else:
            subset = self.split_class(ds, self.date, self.dateformat)

        # the training, validation and test dataset
        train_ds, valid_ds, test_ds = subset.split()

        # whether to drop training samples with a fraction of pixels equal to
        # the constant padding value cval >= drop
        if ds.pad and self.drop > 0:
            self.dropped = self._drop_samples(train_ds, self.drop)

        return train_ds, valid_ds, test_ds
    @staticmethod
    def dataloaders(*args, **kwargs):
        # check whether each dataset in args has the correct type
        loaders = []
        for ds in args:
            if not isinstance(ds, Dataset):
                raise TypeError('Expected {}, got {}.'
                                .format(repr(Dataset), type(ds)))

            # check if the dataset is not empty
            if len(ds) > 0:
                # build the dataloader
                loader = DataLoader(ds, **kwargs)
            else:
                loader = None
            loaders.append(loader)

        return loaders
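
A minimal sketch of the static dataloaders helper, assuming the module is importable as pysegcnn.core.trainer; it shows that an empty dataset yields None instead of a DataLoader:

    import torch
    from torch.utils.data import TensorDataset

    from pysegcnn.core.trainer import SplitConfig

    # toy datasets standing in for the training and test subsets
    full = TensorDataset(torch.randn(8, 3, 16, 16), torch.zeros(8, 16, 16))
    empty = TensorDataset(torch.empty(0, 3, 16, 16), torch.empty(0, 16, 16))

    train_dl, test_dl = SplitConfig.dataloaders(full, empty, batch_size=4,
                                                shuffle=True, drop_last=False)
    print(train_dl.batch_size, test_dl)  # 4 None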
@dataclasses.dataclass
class ModelConfig(BaseConfig):

    model_name: str
    filters: list
    torch_seed: int
    skip_connection: bool = True
    kwargs: dict = dataclasses.field(
        default_factory=lambda: {'kernel_size': 3, 'stride': 1, 'dilation': 1})
    state_path: pathlib.Path = pathlib.Path(HERE).joinpath('_models/')
    batch_size: int = 64
    checkpoint: bool = False
    pretrained: bool = False
    pretrained_model: str = ''

    def __post_init__(self):
        # check input types
        super().__post_init__()

        # check whether the model is currently supported
        self.model_class = item_in_enum(self.model_name, SupportedModels)
    def init_state(self, ds, sc, tc):

        # file to save model state to:
        # network_dataset_optim_split_splitparams_tilesize_batchsize_bands.pt

        # model state filename
        state_file = '{}_{}_{}_{}Split_{}_t{}_b{}_{}.pt'

        # get the band numbers
        bformat = ''.join(band[0] +
                          str(ds.sensor.__members__[band].value)
                          for band in ds.use_bands)

        # check which split mode was used
        if sc.split_mode == 'date':
            # store the date that was used to split the dataset
            state_file = state_file.format(self.model_class.__name__,
                                           ds.__class__.__name__,
                                           tc.optim_name,
                                           sc.split_mode.capitalize(),
                                           sc.date,
                                           ds.tile_size,
                                           self.batch_size,
                                           bformat)
        else:
            # store the random split parameters
            split_params = 's{}_t{}v{}'.format(
                ds.seed, str(sc.ttratio).replace('.', ''),
                str(sc.tvratio).replace('.', ''))

            # model state filename
            state_file = state_file.format(self.model_class.__name__,
                                           ds.__class__.__name__,
                                           tc.optim_name,
                                           sc.split_mode.capitalize(),
                                           split_params,
                                           ds.tile_size,
                                           self.batch_size,
                                           bformat)

        # check whether a pretrained model was used and change state filename
        # accordingly
        if self.pretrained:
            # add the configuration of the pretrained model to the state name
            state_file = (state_file.replace('.pt', '_') +
                          'pretrained_' + self.pretrained_model)

        # path to model state
        state = self.state_path.joinpath(state_file)

        # path to model loss/accuracy
        loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt'))

        return state, loss_state
    def init_model(self, ds):

        # case (1): build a new model
        if not self.pretrained:

            # set the random seed for reproducibility
            torch.manual_seed(self.torch_seed)

            # instantiate the model
            model = self.model_class(in_channels=len(ds.use_bands),
                                     nclasses=len(ds.labels),
                                     filters=self.filters,
                                     skip=self.skip_connection,
                                     **self.kwargs)

        # case (2): load a pretrained model
        else:

            # load pretrained model
            model = self.load_pretrained()

        return model
    def load_checkpoint(self, state_file, loss_state, model, optimizer):

        # initial accuracy on the validation set
        max_accuracy = 0

        # set the model checkpoint to None, overwritten when resuming
        # training from an existing model checkpoint
        checkpoint_state = {}

        # whether to resume training from an existing model
        if self.checkpoint:

            # check if a model checkpoint exists
            if not state_file.exists():
                raise FileNotFoundError('Model checkpoint {} does not exist.'
                                        .format(state_file))

            # load the model state
            state = model.load(state_file.name, optimizer, self.state_path)
            print('Found checkpoint: {}'.format(state))
            print('Resuming training from checkpoint ...'.format(state))
            print('Model epoch: {:d}'.format(model.epoch))

            # load the model loss and accuracy
            checkpoint_state = torch.load(loss_state)

            # get all non-zero elements, i.e. get number of epochs trained
            # before the early stop
            checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
                                for k, v in checkpoint_state.items()}

            # maximum accuracy on the validation set
            max_accuracy = checkpoint_state['va'][:, -1].mean().item()

        return checkpoint_state, max_accuracy
    def load_pretrained(self, ds):

        # load the pretrained model
        model_state = self.state_path.joinpath(self.pretrained_model)
        if not model_state.exists():
            raise FileNotFoundError('Pretrained model {} does not exist.'
                                    .format(model_state))
...
...
@@ -61,23 +346,24 @@ class NetworkTrainer(object):
        filters = model_state['params']['filters']

        # check whether the current dataset uses the correct spectral bands
        if ds.use_bands != bands:
            raise ValueError('The bands of the pretrained network do not '
                             'match the specified bands: {}'.format(bands))

        # instantiate pretrained model architecture
        model = self.model_class(**model_state['params'],
                                 **model_state['kwargs'])

        # load pretrained model weights
        model.load(self.pretrained_model, inpath=str(self.state_path))

        # reset model epoch to 0, since the model is trained on a different
        # dataset
        model.epoch = 0

        # adjust the number of classes in the model
        model.nclasses = len(ds.labels)

        # adjust the classification layer to the number of classes of the
        # current dataset
...
...
@@ -85,49 +371,103 @@ class NetworkTrainer(object):
                                      out_channels=model.nclasses,
                                      kernel_size=1)

        return model
@dataclasses.dataclass
class TrainConfig(BaseConfig):

    optim_name: str
    loss_name: str
    lr: float = 0.001
    early_stop: bool = False
    mode: str = 'max'
    delta: float = 0
    patience: int = 10
    epochs: int = 50
    nthreads: int = torch.get_num_threads()
    save: bool = True

    def __post_init__(self):
        super().__post_init__()

        # check whether the optimizer is currently supported
        self.optim_class = item_in_enum(self.optim_name, SupportedOptimizers)

        # check whether the loss function is currently supported
        self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions)

    def init_optimizer(self, model):

        # initialize the optimizer for the specified model
        optimizer = self.optim_class(model.parameters(), self.lr)

        return optimizer

    def init_loss_function(self):

        loss_function = self.loss_class()

        return loss_function
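
A minimal sketch of how TrainConfig resolves its enum members, assuming the module is importable as pysegcnn.core.trainer; 'Adam' and 'CrossEntropy' are placeholder names and must match members of SupportedOptimizers and SupportedLossFunctions in pysegcnn.core.models:

    import torch.nn as nn

    from pysegcnn.core.trainer import TrainConfig

    # placeholder member names, see pysegcnn.core.models
    tc = TrainConfig(optim_name='Adam', loss_name='CrossEntropy', lr=0.001)

    model = nn.Conv2d(3, 2, kernel_size=1)   # any torch module with parameters
    optimizer = tc.init_optimizer(model)     # optimizer over model.parameters()
    loss_function = tc.init_loss_function()  # instantiated loss function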
@dataclasses.dataclass
class EvalConfig(BaseConfig):

    test: object
    predict_scene: bool = False
    plot_samples: bool = False
    plot_scenes: bool = False
    plot_bands: list = dataclasses.field(
        default_factory=lambda: ['nir', 'red', 'green'])
    cm: bool = True

    def __post_init__(self):
        super().__post_init__()

        # check whether the test input parameter is correctly specified
        if self.test not in [None, False, True]:
            raise TypeError('Expected "test" to be None, True or False, got '
                            '{}.'.format(self.test))
@dataclasses.dataclass
class NetworkTrainer(BaseConfig):

    model: Network
    optimizer: Optimizer
    loss_function: nn.Module
    train_dl: DataLoader
    valid_dl: DataLoader
    state_file: pathlib.Path
    loss_state: pathlib.Path
    epochs: int = 1
    nthreads: int = torch.get_num_threads()
    early_stop: bool = False
    mode: str = 'max'
    delta: float = 0
    patience: int = 10
    max_accuracy: float = 0
    checkpoint_state: dict = dataclasses.field(default_factory=dict)
    save: bool = True

    def __post_init__(self):
        super().__post_init__()

        # whether to use the gpu
        self.device = torch.device("cuda:0" if torch.cuda.is_available()
                                   else "cpu")

        # whether to use early stopping
        self.es = None
        if self.early_stop:
            self.es = EarlyStopping(self.mode, self.delta, self.patience)
    def train(self):

        print('------------------------- Training ---------------------------')

        # set the number of threads
        print('Device: {}'.format(self.device))
        print('Number of cpu threads: {}'.format(self.nthreads))
        torch.set_num_threads(self.nthreads)

        # create dictionary of the observed losses and accuracies on the
        # training and validation dataset
        tshape = (len(self.train_dl), self.epochs)
...
...
@@ -207,36 +547,22 @@ class NetworkTrainer(object):
                epoch_acc = vacc.squeeze().mean()

                # whether the model improved with respect to the previous
                # epoch
                if self.es.increased(epoch_acc, self.max_accuracy,
                                     self.delta):
                    self.max_accuracy = epoch_acc

                    # save model state if the model improved with
                    # respect to the previous epoch
                    self.save_state(training_state)

                # whether the early stopping criterion is met
                if self.es.stop(epoch_acc):
                    break

            else:
                # if no early stopping is required, the model state is
                # saved after each epoch
                self.save_state(training_state)

        return training_state
...
...
@@ -283,217 +609,37 @@ class NetworkTrainer(object):
                  .format(batch + 1, len(self.valid_dl), acc))

        # calculate overall accuracy on the validation/test set
        print('Epoch {:d}, Overall accuracy: {:.2f}%.'
              .format(self.model.epoch, accuracies.mean() * 100))

        return accuracies, losses
    def save_state(self, training_state):

        # whether to save the model state
        if self.save:

            # save model state
            state = self.model.save(self.state_file.name,
                                    self.optimizer,
                                    self.train_dl.dataset.dataset.use_bands,
                                    self.state_file.parent)

            # save losses and accuracy
            self._save_loss(training_state)

    def _save_loss(self, training_state):

        # save losses and accuracy
        state = training_state
        if self.checkpoint_state:

            # append values from checkpoint to current training state
            state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
                     zip(self.checkpoint_state.items(),
                         training_state.items()) if k1 == k2}

        # save the model loss and accuracies to file
        torch.save(state, self.loss_state)

    def __repr__(self):
...
...
@@ -502,26 +648,37 @@ class NetworkTrainer(object):
        # dataset
        fs += '(dataset):\n'
        fs += ''.join(repr(self.train_dl.dataset.dataset)).replace('\n', '\n')

        # batch size
        fs += '\n(batch):\n'
        fs += '- batch size: {}\n'.format(self.train_dl.batch_size)
        fs += '- mini-batch shape (b, c, h, w): {}'.format(
            (self.train_dl.batch_size,
             len(self.train_dl.dataset.dataset.use_bands),
             self.train_dl.dataset.dataset.tile_size,
             self.train_dl.dataset.dataset.tile_size))

        # dataset split
        fs += '\n(split):'
        fs += '\n' + repr(self.train_dl.dataset)
        fs += '\n' + repr(self.valid_dl.dataset)

        # model
        fs += '\n(model):\n'
        fs += ''.join(repr(self.model)).replace('\n', '\n')

        # optimizer
        fs += '\n(optimizer):\n'
        fs += ''.join(repr(self.optimizer)).replace('\n', '\n')

        # early stopping
        fs += '\n(early stop):\n'
        fs += ''.join(repr(self.es)).replace('\n', '\n')

        fs += '\n)'

        return fs
...
...
@@ -555,6 +712,9 @@ class EarlyStopping(object):
        # initialize early stopping flag
        self.early_stop = False

        # initialize the early stop counter
        self.counter = 0

    def stop(self, metric):

        if self.best is not None:
...
...
@@ -584,3 +744,8 @@ class EarlyStopping(object):
    def increased(self, metric, best, min_delta):
        return metric > best + min_delta

    def __repr__(self):
        fs = (self.__class__.__name__ +
              '(mode={}, delta={}, patience={})'.format(self.mode,
                                                        self.min_delta,
                                                        self.patience))
        return fs
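
Taken together, the configuration classes introduced in this commit are meant to be composed into a training run. The sketch below wires them up end to end under stated assumptions: the dataset, model, optimizer and loss names, the band list, the ground-truth pattern and the paths are placeholders that must match the members and conventions defined in pysegcnn, and the overall composition is an assumed usage pattern rather than code from this commit:

    import pathlib

    from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig,
                                       TrainConfig, NetworkTrainer)

    # placeholder names: 'Sparcs', 'Unet', 'Adam' and 'CrossEntropy' must match
    # members of SupportedDatasets, SupportedModels, SupportedOptimizers and
    # SupportedLossFunctions; the paths, bands and gt_pattern are placeholders
    dc = DatasetConfig(dataset_name='Sparcs',
                       root_dir=pathlib.Path('/data/sparcs'),
                       bands=['red', 'green', 'blue', 'nir'], tile_size=128,
                       gt_pattern='*_mask.tif', seed=0)
    sc = SplitConfig(split_mode='scene', ttratio=1, tvratio=0.8)
    mc = ModelConfig(model_name='Unet', filters=[32, 64, 128], torch_seed=0)
    tc = TrainConfig(optim_name='Adam', loss_name='CrossEntropy', epochs=50)

    # dataset, split and dataloaders
    ds = dc.init_dataset()
    train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
    train_dl, valid_dl, test_dl = SplitConfig.dataloaders(
        train_ds, valid_ds, test_ds, batch_size=mc.batch_size, shuffle=True,
        drop_last=False)

    # model, optimizer, loss function and state files
    model = mc.init_model(ds)
    optimizer = tc.init_optimizer(model)
    loss_function = tc.init_loss_function()
    state_file, loss_state = mc.init_state(ds, sc, tc)

    # train the network
    trainer = NetworkTrainer(model, optimizer, loss_function, train_dl,
                             valid_dl, state_file, loss_state,
                             epochs=tc.epochs, early_stop=tc.early_stop,
                             save=tc.save)
    training_state = trainer.train()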