earth_observation_public
PySegCNN
Commits
Commit a2814daf authored 4 years ago by Frisinghelli Daniel

Improved training initialization workflow

parent 2ac1b506

Showing 2 changed files with 229 additions and 210 deletions:

  pysegcnn/core/trainer.py  +199 −184
  pysegcnn/main/train.py    +30 −26
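In short, this commit replaces print() calls in pysegcnn/core/trainer.py with a module-level LOGGER, moves the optimizer, loss function and training hyperparameters from the removed TrainConfig into ModelConfig, factors the model state filename logic out into a new StateConfig dataclass, and reworks checkpoint handling into ModelConfig.from_checkpoint(), load_pretrained() and load_checkpoint(). EarlyStopping can now be seeded with the best metric recovered from a checkpoint, and the driver script pysegcnn/main/train.py configures file-based logging accordingly. Short illustrative sketches follow the relevant hunks below.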
pysegcnn/core/trainer.py  +199 −184
@@ -7,6 +7,7 @@ Created on Wed Aug 12 10:24:34 2020
 # builtins
 import dataclasses
 import pathlib
+import logging

 # externals
 import numpy as np
@@ -16,7 +17,6 @@ import torch.nn.functional as F
 from torch.utils.data import DataLoader, Dataset
 from torch.optim import Optimizer
-

 # locals
 from pysegcnn.core.dataset import SupportedDatasets, ImageDataset
 from pysegcnn.core.transforms import Augment
@@ -27,6 +27,9 @@ from pysegcnn.core.models import (SupportedModels, SupportedOptimizers,
 from pysegcnn.core.layers import Conv2dSame
 from pysegcnn.main.config import HERE

+# module level logger
+LOGGER = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class BaseConfig:
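The new module-level logger emits nothing until the application installs handlers. Below is a minimal sketch of how that wiring looks at the entry point; the body of log_conf here is an assumption for illustration (the real one lives in pysegcnn.core.logging and is not part of this diff), and note that dictConfig requires importing logging.config, which a bare import logging does not provide:

import logging
import logging.config  # dictConfig lives in the logging.config submodule


def log_conf(log_file):
    # assumption: a minimal dictConfig sending records to console and file
    return {
        'version': 1,
        # keep loggers created at import time, e.g. the module-level LOGGER
        'disable_existing_loggers': False,
        'formatters': {
            'brief': {'format': '%(name)s: %(levelname)s: %(message)s'}},
        'handlers': {
            'console': {'class': 'logging.StreamHandler',
                        'formatter': 'brief'},
            'file': {'class': 'logging.FileHandler',
                     'formatter': 'brief',
                     'filename': log_file}},
        'root': {'handlers': ['console', 'file'], 'level': 'INFO'}}


logging.config.dictConfig(log_conf('training.log'))
logging.getLogger('pysegcnn.core.trainer').info('LOGGER output is now visible.')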
@@ -60,7 +63,6 @@ class DatasetConfig(BaseConfig):
     sort: bool = False
     transforms: list = dataclasses.field(default_factory=list)
     pad: bool = False
-    cval: int = 99

     def __post_init__(self):
         # check input types
@@ -80,11 +82,6 @@ class DatasetConfig(BaseConfig):
                         'of {}.'.format('.'.join([Augment.__module__,
                                                   Augment.__name__])))

-        # check whether the constant padding value is within the valid range
-        if not 0 < self.cval < 255:
-            raise ValueError('Expecting 0 <= cval <= 255, got cval={}.'
-                             .format(self.cval))
-
     def init_dataset(self):
         # instanciate the dataset
@@ -96,7 +93,6 @@ class DatasetConfig(BaseConfig):
             sort=self.sort,
             transforms=self.transforms,
             pad=self.pad,
-            cval=self.cval,
             gt_pattern=self.gt_pattern
         )
@@ -121,7 +117,8 @@ class SplitConfig(BaseConfig):
     # function to drop samples with a fraction of pixels equal to the constant
     # padding value self.cval >= self.drop
-    def _drop_samples(self, ds, drop_threshold=1):
+    @staticmethod
+    def _drop_samples(ds, drop_threshold=1):

         # iterate over the scenes returned by self.compose_scenes()
         dropped = []
@@ -139,8 +136,8 @@ class SplitConfig(BaseConfig):
             # drop samples where npixels >= self.drop
             if npixels >= drop_threshold:
-                print('Skipping scene {}, tile {}: {:.2f}% padded pixels ...'
-                      .format(s['id'], s['tile'], npixels * 100))
+                LOGGER.info('Skipping scene {}, tile {}: {:.2f}% padded pixels'
+                            ' ...'.format(s['id'], s['tile'], npixels * 100))
                 dropped.append(s)
                 _ = ds.indices.pop(pos)
@@ -197,14 +194,24 @@ class ModelConfig(BaseConfig):
     model_name: str
     filters: list
     torch_seed: int
+    optim_name: str
+    loss_name: str
     skip_connection: bool = True
     kwargs: dict = dataclasses.field(
         default_factory=lambda: {'kernel_size': 3, 'stride': 1, 'dilation': 1})
     state_path: pathlib.Path = pathlib.Path(HERE).joinpath('_models/')
     batch_size: int = 64
     checkpoint: bool = False
-    pretrained: bool = False
+    transfer: bool = False
     pretrained_model: str = ''
+    lr: float = 0.001
+    early_stop: bool = False
+    mode: str = 'max'
+    delta: float = 0
+    patience: int = 10
+    epochs: int = 50
+    nthreads: int = torch.get_num_threads()
+    save: bool = True

     def __post_init__(self):
         # check input types
@@ -213,65 +220,32 @@ class ModelConfig(BaseConfig):
         # check whether the model is currently supported
         self.model_class = item_in_enum(self.model_name, SupportedModels)

-    def init_state(self, ds, sc, tc):
-
-        # file to save model state to:
-        # network_dataset_optim_split_splitparams_tilesize_batchsize_bands.pt
-
-        # model state filename
-        state_file = '{}_{}_{}_{}Split_{}_t{}_b{}_{}.pt'
-
-        # get the band numbers
-        bformat = ''.join(band[0] +
-                          str(ds.sensor.__members__[band].value) for
-                          band in ds.use_bands)
-
-        # check which split mode was used
-        if sc.split_mode == 'date':
-            # store the date that was used to split the dataset
-            state_file = state_file.format(self.model_class.__name__,
-                                           ds.__class__.__name__,
-                                           tc.optim_name,
-                                           sc.split_mode.capitalize(),
-                                           sc.date,
-                                           ds.tile_size,
-                                           self.batch_size,
-                                           bformat)
-        else:
-            # store the random split parameters
-            split_params = 's{}_t{}v{}'.format(
-                ds.seed, str(sc.ttratio).replace('.', ''),
-                str(sc.tvratio).replace('.', ''))
-
-            # model state filename
-            state_file = state_file.format(self.model_class.__name__,
-                                           ds.__class__.__name__,
-                                           tc.optim_name,
-                                           sc.split_mode.capitalize(),
-                                           split_params,
-                                           ds.tile_size,
-                                           self.batch_size,
-                                           bformat)
-
-        # check whether a pretrained model was used and change state filename
-        # accordingly
-        if self.pretrained:
-            # add the configuration of the pretrained model to the state name
-            state_file = (state_file.replace('.pt', '_') +
-                          'pretrained_' + self.pretrained_model)
-
-        # path to model state
-        state = self.state_path.joinpath(state_file)
-
-        # path to model loss/accuracy
-        loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt'))
-
-        return state, loss_state
+        # check whether the optimizer is currently supported
+        self.optim_class = item_in_enum(self.optim_name, SupportedOptimizers)
+
+        # check whether the loss function is currently supported
+        self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions)
+
+        # path to pretrained model
+        self.pretrained_path = self.state_path.joinpath(self.pretrained_model)
+
+    def init_optimizer(self, model):
+
+        # initialize the optimizer for the specified model
+        optimizer = self.optim_class(model.parameters(), self.lr)
+
+        return optimizer
+
+    def init_loss_function(self):
+
+        loss_function = self.loss_class()
+
+        return loss_function

     def init_model(self, ds):

         # case (1): build a new model
-        if not self.pretrained:
+        if not self.transfer:
             # set the random seed for reproducibility
             torch.manual_seed(self.torch_seed)
@@ -284,130 +258,172 @@ class ModelConfig(BaseConfig):
                               skip=self.skip_connection,
                               **self.kwargs)

-        # case (2): load a pretrained model
+        # case (2): load a pretrained model for transfer learning
         else:
             # load pretrained model
-            model = self.load_pretrained()
+            model, _ = self.load_pretrained(self.pretrained_path, new_ds=ds)

         return model

-    def load_checkpoint(self, state_file, loss_state, model, optimizer):
-
-        # initial accuracy on the validation set
-        max_accuracy = 0
-
-        # set the model checkpoint to None, overwritten when resuming
-        # training from an existing model checkpoint
-        checkpoint_state = {}
-
-        # whether to resume training from an existing model
-        if self.checkpoint:
-            # check if a model checkpoint exists
-            if not state_file.exists():
-                raise FileNotFoundError('Model checkpoint {} does not exist.'
-                                        .format(state_file))
-
-            # load the model state
-            state = model.load(state_file.name, optimizer, self.state_path)
-            print('Found checkpoint: {}'.format(state))
-            print('Resuming training from checkpoint ...'.format(state))
-            print('Model epoch: {:d}'.format(model.epoch))
-
-            # load the model loss and accuracy
-            checkpoint_state = torch.load(loss_state)
-
-            # get all non-zero elements, i.e. get number of epochs trained
-            # before the early stop
-            checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
-                                for k, v in checkpoint_state.items()}
-
-            # maximum accuracy on the validation set
-            max_accuracy = checkpoint_state['va'][:, -1].mean().item()
-
-        return checkpoint_state, max_accuracy
-
-    def load_pretrained(self, ds):
-
-        # load the pretrained model
-        model_state = self.state_path.joinpath(self.pretrained_model)
-        if not model_state.exists():
-            raise FileNotFoundError('Pretrained model {} does not exist.'
-                                    .format(model_state))
-
-        # load the model state
-        model_state = torch.load(model_state)
-
-        # get the input bands of the pretrained model
-        bands = model_state['bands']
-
-        # get the number of convolutional filters
-        filters = model_state['params']['filters']
-
-        # check whether the current dataset uses the correct spectral bands
-        if ds.use_bands != bands:
-            raise ValueError('The bands of the pretrained network do not '
-                             'match the specified bands: {}'.format(bands))
-
-        # instanciate pretrained model architecture
-        model = self.model_class(**model_state['params'],
-                                 **model_state['kwargs'])
-
-        # load pretrained model weights
-        model.load(self.pretrained_model, inpath=str(self.state_path))
-
-        # reset model epoch to 0, since the model is trained on a different
-        # dataset
-        model.epoch = 0
-
-        # adjust the number of classes in the model
-        model.nclasses = len(ds.labels)
-
-        # adjust the classification layer to the number of classes of the
-        # current dataset
-        model.classifier = Conv2dSame(in_channels=filters[0],
-                                      out_channels=model.nclasses,
-                                      kernel_size=1)
-
-        return model
+    def from_checkpoint(self, model, optimizer, state_file, loss_state):
+
+        # whether to resume training from an existing model checkpoint
+        checkpoint_state = {}
+        max_accuracy = 0
+        if self.checkpoint:
+            # check whether the checkpoint exists
+            if state_file.exists() and loss_state.exists():
+                # load model checkpoint
+                model, optimizer = self.load_pretrained(state_file, optimizer,
+                                                        new_ds=None)
+                (checkpoint_state, max_accuracy) = self.load_checkpoint(
+                    loss_state)
+            else:
+                LOGGER.info('Checkpoint for model {} does not exist. '
+                            'Initializing new model.'
+                            .format(state_file.name))
+
+        return model, optimizer, checkpoint_state, max_accuracy
+
+    @staticmethod
+    def load_pretrained(state_file, optimizer=None, new_ds=None):
+
+        # load the pretrained model
+        if not state_file.exists():
+            raise FileNotFoundError('Pretrained model {} does not exist.'
+                                    .format(state_file))
+
+        LOGGER.info('Loading pretrained model: {}'.format(state_file.name))
+
+        # load the model state
+        model_state = torch.load(state_file)
+
+        # the model class
+        model_class = model_state['cls']
+
+        # instanciate pretrained model architecture
+        model = model_class(**model_state['params'], **model_state['kwargs'])
+
+        # load pretrained model weights
+        _ = model.load(state_file.name, optimizer=optimizer,
+                       inpath=str(state_file.parent))
+        LOGGER.info('Model epoch: {:d}'.format(model.epoch))
+
+        # check whether to apply pretrained model on a new dataset
+        if new_ds is not None:
+            LOGGER.info('Configuring model for new dataset: {}.'
+                        .format(new_ds.__class__.__name__))
+
+            # the bands the model was trained with
+            bands = model_state['bands']
+
+            # check whether the current dataset uses the correct spectral bands
+            if new_ds.use_bands != bands:
+                raise ValueError('The pretrained network was trained with the '
+                                 'bands {}, not with: {}'
+                                 .format(bands, new_ds.use_bands))
+
+            # get the number of convolutional filters
+            filters = model_state['params']['filters']
+
+            # reset model epoch to 0, since the model is trained on a different
+            # dataset
+            model.epoch = 0
+
+            # adjust the number of classes in the model
+            model.nclasses = len(new_ds.labels)
+            LOGGER.info('Replacing classification layer to classes: {}.'
+                        .format(', '.join('({}, {})'.format(k, v['label'])
+                                          for k, v in new_ds.labels.items())))
+
+            # adjust the classification layer to the number of classes of the
+            # current dataset
+            model.classifier = Conv2dSame(in_channels=filters[0],
+                                          out_channels=model.nclasses,
+                                          kernel_size=1)
+
+        return model, optimizer
+
+    @staticmethod
+    def load_checkpoint(loss_state):
+
+        # load the model loss and accuracy
+        checkpoint_state = torch.load(loss_state)
+
+        # get all non-zero elements, i.e. get number of epochs trained
+        # before the early stop
+        checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
+                            for k, v in checkpoint_state.items()}
+
+        # maximum accuracy on the validation set
+        max_accuracy = checkpoint_state['va'][:, -1].mean().item()
+
+        return checkpoint_state, max_accuracy


 @dataclasses.dataclass
-class TrainConfig(BaseConfig):
+class StateConfig(BaseConfig):

-    optim_name: str
-    loss_name: str
-    lr: float = 0.001
-    early_stop: bool = False
-    mode: str = 'max'
-    delta: float = 0
-    patience: int = 10
-    epochs: int = 50
-    nthreads: int = torch.get_num_threads()
-    save: bool = True
+    ds: ImageDataset
+    sc: SplitConfig
+    mc: ModelConfig

     def __post_init__(self):
         super().__post_init__()

-        # check whether the optimizer is currently supported
-        self.optim_class = item_in_enum(self.optim_name, SupportedOptimizers)
-
-        # check whether the loss function is currently supported
-        self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions)
-
-    def init_optimizer(self, model):
-
-        # initialize the optimizer for the specified model
-        optimizer = self.optim_class(model.parameters(), self.lr)
-
-        return optimizer
-
-    def init_loss_function(self):
-
-        loss_function = self.loss_class()
-
-        return loss_function
+    def init_state(self):
+
+        # file to save model state to:
+        # network_dataset_optim_split_splitparams_tilesize_batchsize_bands.pt
+
+        # model state filename
+        state_file = '{}_{}_{}_{}Split_{}_t{}_b{}_{}.pt'
+
+        # get the band numbers
+        bformat = ''.join(band[0] + str(self.ds.sensor.__members__[band].value)
+                          for band in self.ds.use_bands)
+
+        # check which split mode was used
+        if self.sc.split_mode == 'date':
+            # store the date that was used to split the dataset
+            state_file = state_file.format(self.mc.model_name,
+                                           self.ds.__class__.__name__,
+                                           self.mc.optim_name,
+                                           self.sc.split_mode.capitalize(),
+                                           self.sc.date,
+                                           self.ds.tile_size,
+                                           self.mc.batch_size,
+                                           bformat)
+        else:
+            # store the random split parameters
+            split_params = 's{}_t{}v{}'.format(
+                self.ds.seed, str(self.sc.ttratio).replace('.', ''),
+                str(self.sc.tvratio).replace('.', ''))
+
+            # model state filename
+            state_file = state_file.format(self.mc.model_name,
+                                           self.ds.__class__.__name__,
+                                           self.mc.optim_name,
+                                           self.sc.split_mode.capitalize(),
+                                           split_params,
+                                           self.ds.tile_size,
+                                           self.mc.batch_size,
+                                           bformat)
+
+        # check whether a pretrained model was used and change state filename
+        # accordingly
+        if self.mc.transfer:
+            # add the configuration of the pretrained model to the state name
+            state_file = (state_file.replace('.pt', '_') +
+                          'pretrained_' + self.mc.pretrained_model)
+
+        # path to model state
+        state = self.mc.state_path.joinpath(state_file)
+
+        # path to model loss/accuracy
+        loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt'))
+
+        return state, loss_state


 @dataclasses.dataclass
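For orientation, the filename template in StateConfig.init_state() resolves as below; every value is hypothetical and only illustrates the naming scheme (model, dataset, optimizer, split mode, split parameters, tile size, batch size, bands):

# hypothetical values, for illustration only
state_file = '{}_{}_{}_{}Split_{}_t{}_b{}_{}.pt'.format(
    'UNet',           # mc.model_name
    'SparcsDataset',  # ds.__class__.__name__
    'Adam',           # mc.optim_name
    'Random',         # sc.split_mode.capitalize()
    's0_t08v005',     # split_params: seed 0, ttratio 0.8, tvratio 0.05
    125,              # ds.tile_size
    64,               # mc.batch_size
    'b2g3r4')         # bformat: first letter of each band name + band number

# -> 'UNet_SparcsDataset_Adam_RandomSplit_s0_t08v005_t125_b64_b2g3r4.pt'
# the companion loss/accuracy file swaps the suffix:
# -> 'UNet_SparcsDataset_Adam_RandomSplit_s0_t08v005_t125_b64_b2g3r4_loss.pt'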
@@ -428,6 +444,7 @@ class EvalConfig(BaseConfig):
             raise TypeError('Expected "test" to be None, True or False, got '
                             '{}.'.format(self.test))

+
 @dataclasses.dataclass
 class NetworkTrainer(BaseConfig):

     model: Network
@@ -457,15 +474,16 @@ class NetworkTrainer(BaseConfig):
         # whether to use early stopping
         self.es = None
         if self.early_stop:
-            self.es = EarlyStopping(self.mode, self.delta, self.patience)
+            self.es = EarlyStopping(self.mode, self.max_accuracy, self.delta,
+                                    self.patience)

     def train(self):

-        print('------------------------- Training ---------------------------')
+        LOGGER.info(30 * '-' + ' Training ' + 30 * '-')

         # set the number of threads
-        print('Device: {}'.format(self.device))
-        print('Number of cpu threads: {}'.format(self.nthreads))
+        LOGGER.info('Device: {}'.format(self.device))
+        LOGGER.info('Number of cpu threads: {}'.format(self.nthreads))
         torch.set_num_threads(self.nthreads)

         # create dictionary of the observed losses and accuracies on the
@@ -485,7 +503,7 @@ class NetworkTrainer(BaseConfig):
         for epoch in range(self.epochs):

             # set the model to training mode
-            print('Setting model to training mode ...')
+            LOGGER.info('Setting model to training mode ...')
             self.model.train()

             # iterate over the dataloader object
@@ -521,13 +539,14 @@ class NetworkTrainer(BaseConfig):
                 training_state['ta'][batch, epoch] = observed_accuracy

                 # print progress
-                print('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, Loss: {:.2f}, '
-                      'Accuracy: {:.2f}'.format(epoch + 1,
-                                                self.epochs,
-                                                batch + 1,
-                                                len(self.train_dl),
-                                                observed_loss,
-                                                observed_accuracy))
+                LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
+                            'Loss: {:.2f}, Accuracy: {:.2f}'.format(
+                                epoch + 1,
+                                self.epochs,
+                                batch + 1,
+                                len(self.train_dl),
+                                observed_loss,
+                                observed_accuracy))

             # update the number of epochs trained
             self.model.epoch += 1
@@ -568,13 +587,13 @@ class NetworkTrainer(BaseConfig):
     def predict(self):

-        print('------------------------ Predicting --------------------------')
+        LOGGER.info(30 * '-' + ' Predicting ' + 30 * '-')

         # send the model to the gpu if available
         self.model = self.model.to(self.device)

         # set the model to evaluation mode
-        print('Setting model to evaluation mode ...')
+        LOGGER.info('Setting model to evaluation mode ...')
         self.model.eval()

         # create arrays of the observed losses and accuracies
@@ -582,7 +601,7 @@ class NetworkTrainer(BaseConfig):
         losses = np.zeros(shape=(len(self.valid_dl), 1))

         # iterate over the validation/test set
-        print('Calculating accuracy on the validation set ...')
+        LOGGER.info('Calculating accuracy on the validation set ...')
         for batch, (inputs, labels) in enumerate(self.valid_dl):

             # send the data to the gpu if available
@@ -605,12 +624,12 @@ class NetworkTrainer(BaseConfig):
             accuracies[batch, 0] = acc

             # print progress
-            print('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}'
-                  .format(batch + 1, len(self.valid_dl), acc))
+            LOGGER.info('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}'
+                        .format(batch + 1, len(self.valid_dl), acc))

         # calculate overall accuracy on the validation/test set
-        print('Epoch {:d}, Overall accuracy: {:.2f}%.'
-              .format(self.model.epoch, accuracies.mean() * 100))
+        LOGGER.info('Epoch {:d}, Overall accuracy: {:.2f}%.'
+                    .format(self.model.epoch, accuracies.mean() * 100))

         return accuracies, losses
@@ -649,7 +668,7 @@ class NetworkTrainer(BaseConfig):
         # dataset
         fs += '    (dataset):\n        '
         fs += ''.join(
             repr(self.train_dl.dataset.dataset)).replace('\n', '\n        ')

         # batch size
         fs += '\n    (batch):\n        '
@@ -684,7 +703,7 @@ class NetworkTrainer(BaseConfig):
 class EarlyStopping(object):

-    def __init__(self, mode='max', min_delta=0, patience=10):
+    def __init__(self, mode='max', best=0, min_delta=0, patience=10):

         # check if mode is correctly specified
         if mode not in ['min', 'max']:
@@ -707,7 +726,7 @@ class EarlyStopping(object):
         self.patience = patience

         # initialize best metric
-        self.best = None
+        self.best = best

         # initialize early stopping flag
         self.early_stop = False
@@ -717,25 +736,20 @@ class EarlyStopping(object):
     def stop(self, metric):

-        if self.best is not None:
-            # if the metric improved, reset the epochs counter, else, advance
-            if self.is_better(metric, self.best, self.min_delta):
-                self.counter = 0
-                self.best = metric
-            else:
-                self.counter += 1
-                print('Early stopping counter: {}/{}'.format(
-                    self.counter, self.patience))
-
-                # if the metric did not improve over the last patience epochs,
-                # the early stopping criterion is met
-                if self.counter >= self.patience:
-                    print('Early stopping criterion met, exiting training ...')
-                    self.early_stop = True
-
-        else:
-            self.best = metric
+        # if the metric improved, reset the epochs counter, else, advance
+        if self.is_better(metric, self.best, self.min_delta):
+            self.counter = 0
+            self.best = metric
+        else:
+            self.counter += 1
+            LOGGER.info('Early stopping counter: {}/{}'.format(
+                self.counter, self.patience))
+
+            # if the metric did not improve over the last patience epochs,
+            # the early stopping criterion is met
+            if self.counter >= self.patience:
+                LOGGER.info('Early stopping criterion met, stopping training.')
+                self.early_stop = True

         return self.early_stop
@@ -746,6 +760,7 @@ class EarlyStopping(object):
         return metric > best + min_delta

     def __repr__(self):
-        fs = (self.__class__.__name__ + '(mode={}, delta={}, patience={})'
-              .format(self.mode, self.min_delta, self.patience))
+        fs = self.__class__.__name__
+        fs += '(mode={}, best={}, delta={}, patience={})'.format(
+            self.mode, self.best, self.min_delta, self.patience)
         return fs
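The practical effect of the new best argument: a run resumed from a checkpoint starts the early-stopping baseline at the checkpoint's best validation accuracy (passed in as max_accuracy by NetworkTrainer) rather than at None, so a resumed run cannot count a worse metric as an improvement. A small sketch of the behavior, with made-up accuracy values:

# resuming with the checkpoint's best accuracy as the baseline
es = EarlyStopping(mode='max', best=0.87, min_delta=0, patience=2)

es.stop(0.80)  # 0.80 <= 0.87: no improvement -> counter = 1, returns False
es.stop(0.90)  # improvement -> counter reset, best = 0.90, returns False
es.stop(0.85)  # 0.85 <= 0.90: no improvement -> counter = 1, returns False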
pysegcnn/main/train.py  +30 −26
@@ -5,11 +5,14 @@ Created on Tue Jun 30 09:33:38 2020
 @author: Daniel
 """

+# builtins
+import logging
+
 # locals
 from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig,
-                                   TrainConfig, NetworkTrainer)
-from pysegcnn.main.config import (dataset_config, split_config,
-                                  model_config, train_config)
+                                   StateConfig, NetworkTrainer)
+from pysegcnn.core.logging import log_conf
+from pysegcnn.main.config import (dataset_config, split_config, model_config)


 if __name__ == '__main__':
@@ -20,35 +23,36 @@ if __name__ == '__main__':
     dc = DatasetConfig(**dataset_config)
     sc = SplitConfig(**split_config)
     mc = ModelConfig(**model_config)
-    tc = TrainConfig(**train_config)

     # (ii) instanciate the dataset
     ds = dc.init_dataset()
     ds

-    # (iii) instanciate the training, validation and test datasets and
+    # (iii) instanciate the model state
+    state = StateConfig(ds, sc, mc)
+    state_file, loss_state = state.init_state()
+
+    # initialize logging
+    log_file = str(state_file).replace('.pt', '_train.log')
+    logging.config.dictConfig(log_conf(log_file))
+
+    # (iv) instanciate the training, validation and test datasets and
     # dataloaders
     train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
-    train_dl, valid_dl, test_dl = sc.dataloaders(train_ds,
-                                                 valid_ds,
-                                                 test_ds,
-                                                 batch_size=mc.batch_size,
-                                                 shuffle=True,
-                                                 drop_last=False)
+    train_dl, valid_dl, test_dl = sc.dataloaders(
+        train_ds, valid_ds, test_ds, batch_size=mc.batch_size, shuffle=True,
+        drop_last=False)

-    # (iv) instanciate the model state files
-    state_file, loss_state = mc.init_state(ds, sc, tc)
-
-    # (v) instanciate the model
+    # (iv) instanciate the model
     model = mc.init_model(ds)

     # (vi) instanciate the optimizer and the loss function
-    optimizer = tc.init_optimizer(model)
-    loss_function = tc.init_loss_function()
+    optimizer = mc.init_optimizer(model)
+    loss_function = mc.init_loss_function()

     # (vii) resume training from an existing model checkpoint
-    checkpoint_state, max_accuracy = mc.load_checkpoint(state_file, loss_state,
-                                                        model, optimizer)
+    (model, optimizer, checkpoint_state, max_accuracy) = mc.from_checkpoint(
+        model, optimizer, state_file, loss_state)

     # (viii) initialize network trainer class for eays model training
     trainer = NetworkTrainer(model=model,
@@ -58,15 +62,15 @@ if __name__ == '__main__':
                              valid_dl=valid_dl,
                              state_file=state_file,
                              loss_state=loss_state,
-                             epochs=tc.epochs,
-                             nthreads=tc.nthreads,
-                             early_stop=tc.early_stop,
-                             mode=tc.mode,
-                             delta=tc.delta,
-                             patience=tc.patience,
+                             epochs=mc.epochs,
+                             nthreads=mc.nthreads,
+                             early_stop=mc.early_stop,
+                             mode=mc.mode,
+                             delta=mc.delta,
+                             patience=mc.patience,
                              max_accuracy=max_accuracy,
                              checkpoint_state=checkpoint_state,
-                             save=tc.save
+                             save=mc.save
                              )

     # (ix) train model
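Put together, the reworked initialization reads as a straight line from configuration to trainer; the sketch below condenses the new train.py flow (the configuration dictionaries come from pysegcnn.main.config):

import logging.config

from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig,
                                   StateConfig, NetworkTrainer)
from pysegcnn.core.logging import log_conf
from pysegcnn.main.config import dataset_config, split_config, model_config

# configurations
dc = DatasetConfig(**dataset_config)
sc = SplitConfig(**split_config)
mc = ModelConfig(**model_config)

# dataset and model state files
ds = dc.init_dataset()
state_file, loss_state = StateConfig(ds, sc, mc).init_state()

# logging is configured as soon as the state filename is known
logging.config.dictConfig(log_conf(str(state_file).replace('.pt', '_train.log')))

# splits and dataloaders
train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
train_dl, valid_dl, test_dl = sc.dataloaders(
    train_ds, valid_ds, test_ds, batch_size=mc.batch_size, shuffle=True,
    drop_last=False)

# model, optimizer, loss function, optional resume from checkpoint
model = mc.init_model(ds)
optimizer = mc.init_optimizer(model)
loss_function = mc.init_loss_function()
model, optimizer, checkpoint_state, max_accuracy = mc.from_checkpoint(
    model, optimizer, state_file, loss_state)

# hand everything to NetworkTrainer(...) and call trainer.train() as before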