earth_observation_public / PySegCNN / Commits

Commit f7112679, authored 4 years ago by Frisinghelli Daniel
Parent: a2f71786

    Major refactor: Increased modularity

No related branches, tags, or merge requests found.

Showing 1 changed file: pysegcnn/core/trainer.py, with 147 additions and 121 deletions.
@@ -8,6 +8,7 @@ Created on Wed Aug 12 10:24:34 2020
 import dataclasses
 import pathlib
 import logging
+import datetime

 # externals
 import numpy as np
@@ -199,7 +200,6 @@ class ModelConfig(BaseConfig):
     skip_connection: bool = True
     kwargs: dict = dataclasses.field(
         default_factory=lambda: {'kernel_size': 3, 'stride': 1,
                                  'dilation': 1})
-    state_path: pathlib.Path = pathlib.Path(HERE).joinpath('_models/')
     batch_size: int = 64
     checkpoint: bool = False
     transfer: bool = False
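The kwargs field is declared with dataclasses.field(default_factory=...) because a mutable default would be shared across instances and is in fact rejected by the dataclass machinery. A minimal sketch of the pattern, with a hypothetical Config class:

    import dataclasses

    @dataclasses.dataclass
    class Config:
        # a bare `kwargs: dict = {}` raises "ValueError: mutable default";
        # default_factory builds a fresh dict per instance instead
        kwargs: dict = dataclasses.field(
            default_factory=lambda: {'kernel_size': 3, 'stride': 1,
                                     'dilation': 1})

    a, b = Config(), Config()
    a.kwargs['stride'] = 2
    print(b.kwargs['stride'])  # 1 -> instances do not share state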
@@ -226,11 +226,16 @@ class ModelConfig(BaseConfig):
         # check whether the loss function is currently supported
         self.loss_class = item_in_enum(self.loss_name,
                                        SupportedLossFunctions)

+        # path to model states
+        self.state_path = pathlib.Path(HERE).joinpath('_models/')
+
         # path to pretrained model
         self.pretrained_path = self.state_path.joinpath(self.pretrained_model)

     def init_optimizer(self, model):
+        LOGGER.info('Optimizer: {}.'.format(repr(self.optim_class)))

         # initialize the optimizer for the specified model
         optimizer = self.optim_class(model.parameters(), self.lr)
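init_optimizer instantiates whatever optimizer class the configuration resolved, passing the model parameters and the learning rate. A minimal sketch of that call pattern, with hypothetical values:

    import torch

    model = torch.nn.Linear(4, 2)   # hypothetical model
    optim_class = torch.optim.Adam  # e.g. resolved from a supported-optimizers enum
    lr = 0.001

    # any torch.optim class accepts the parameter iterable plus hyperparameters
    optimizer = optim_class(model.parameters(), lr)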
@@ -238,17 +243,28 @@ class ModelConfig(BaseConfig):
     def init_loss_function(self):
         LOGGER.info('Loss function: {}.'.format(repr(self.loss_class)))

         # instantiate the loss function
         loss_function = self.loss_class()
         return loss_function

-    def init_model(self, ds):
+    def init_model(self, ds, state_file):
+
+        # write an initialization string to the log file
+        # now = datetime.datetime.strftime(datetime.datetime.now(),
+        #                                  '%Y-%m-%dT%H:%M:%S')
+        # LOGGER.info(80 * '-')
+        # LOGGER.info('{}: Initializing model run. '.format(now) + 35 * '-')
+        # LOGGER.info(80 * '-')

         # case (1): build a new model
         if not self.transfer:
             # set the random seed for reproducibility
             torch.manual_seed(self.torch_seed)
+            LOGGER.info('Initializing model: {}'.format(state_file.name))

             # instantiate the model
             model = self.model_class(
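torch.manual_seed(self.torch_seed) makes the subsequent random weight initialization deterministic: re-running the script re-creates the identical starting model. A minimal sketch:

    import torch

    torch.manual_seed(42)
    a = torch.nn.Linear(2, 2).weight.detach().clone()

    torch.manual_seed(42)
    b = torch.nn.Linear(2, 2).weight.detach().clone()

    assert torch.equal(a, b)  # same seed -> identical initial weights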
@@ -261,104 +277,86 @@ class ModelConfig(BaseConfig):
         # case (2): load a pretrained model for transfer learning
         else:
-            # load pretrained model
-            model, _ = self.load_pretrained(self.pretrained_path, new_ds=ds)
+            LOGGER.info('Loading pretrained model for transfer learning '
+                        'from: {}'.format(self.pretrained_path))
+            model = self.transfer_model(self.pretrained_path, ds)

         return model

     def from_checkpoint(self, model, optimizer, state_file, loss_state):
-        # initialize the optimizer
-        optimizer = self.init_optimizer(model)

         # whether to resume training from an existing model checkpoint
         checkpoint_state = {}
+        max_accuracy = 0
         if self.checkpoint:
-            model, optimizer, checkpoint_state = self.load_checkpoint(
-                model, optimizer, state_file)
+            # check whether the checkpoint exists
+            if state_file.exists() and loss_state.exists():
+                # load model checkpoint
+                model, optimizer = self.load_pretrained(state_file, optimizer,
+                                                        new_ds=None)
+                (checkpoint_state, max_accuracy) = self.load_checkpoint(
+                    loss_state)
+            else:
+                LOGGER.info('Checkpoint for model {} does not exist. '
+                            'Initializing new model.'.format(state_file.name))

-        return model, optimizer, checkpoint_state
+        return model, optimizer, checkpoint_state, max_accuracy
     @staticmethod
     def load_pretrained(state_file, optimizer=None, new_ds=None):

         # load the pretrained model
         if not state_file.exists():
             raise FileNotFoundError('Pretrained model {} does not exist.'
                                     .format(state_file))
         LOGGER.info('Loading pretrained model: {}'.format(state_file.name))

         # load the model state
         model_state = torch.load(state_file)

         # the model class
         model_class = model_state['cls']

         # instantiate pretrained model architecture
         model = model_class(**model_state['params'], **model_state['kwargs'])

         # load pretrained model weights
         _ = model.load(state_file.name, optimizer=optimizer,
                        inpath=str(state_file.parent))
         LOGGER.info('Model epoch: {:d}'.format(model.epoch))

-        # check whether to apply pretrained model on a new dataset
-        if new_ds is not None:
-            LOGGER.info('Configuring model for new dataset: {}.'
-                        .format(new_ds.__class__.__name__))
-
-            # the bands the model was trained with
-            bands = model_state['bands']
-
-            # check whether the current dataset uses the correct spectral
-            # bands
-            if new_ds.use_bands != bands:
-                raise ValueError('The pretrained network was trained with '
-                                 'the bands {}, not with: {}'
-                                 .format(bands, new_ds.use_bands))
-
-            # get the number of convolutional filters
-            filters = model_state['params']['filters']
-
-            # reset model epoch to 0, since the model is trained on a
-            # different dataset
-            model.epoch = 0
-
-            # adjust the number of classes in the model
-            model.nclasses = len(new_ds.labels)
-            LOGGER.info('Replacing classification layer to classes: {}.'
-                        .format(', '.join('({}, {})'.format(k, v['label'])
-                                          for k, v in new_ds.labels.items())))
-
-            # adjust the classification layer to the number of classes of
-            # the current dataset
-            model.classifier = Conv2dSame(in_channels=filters[0],
-                                          out_channels=model.nclasses,
-                                          kernel_size=1)

         return model, optimizer

-    def load_checkpoint(model, optimizer, state_file):
-
-        # whether to resume training from an existing model checkpoint
-        checkpoint_state = {}
-
-        # if no checkpoint exists, file a warning and continue with a model
-        # initialized from scratch
-        if not state_file.exists():
-            LOGGER.warning('Checkpoint for model {} does not exist. '
-                           'Initializing new model.'.format(state_file.name))
-        else:
-            # load model checkpoint
-            model, optimizer, model_state = Network.load(state_file,
-                                                         optimizer)
-
-            # load model loss and accuracy
-            # get all non-zero elements, i.e. get number of epochs trained
-            # before the early stop
-            checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
-                                for k, v in model_state['state'].items()}
-
-        return model, optimizer, checkpoint_state
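load_pretrained relies on the saved state storing the model class under 'cls' and its constructor arguments under 'params'/'kwargs', so the architecture can be rebuilt before the weights are restored. A minimal sketch of that convention (the file name and model are hypothetical):

    import torch
    import torch.nn as nn

    # save: record how to rebuild the model, not only its weights
    state = {'cls': nn.Linear,
             'params': {'in_features': 4, 'out_features': 2}}
    torch.save(state, 'model_state.pt')

    # load: re-instantiate the architecture from the stored class and
    # arguments (recent PyTorch may need torch.load(..., weights_only=False))
    model_state = torch.load('model_state.pt')
    model = model_state['cls'](**model_state['params'])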
+    @staticmethod
+    def transfer_model(state_file, ds):
+
+        # check input type
+        if not isinstance(ds, ImageDataset):
+            raise TypeError('Expected "ds" to be {}.'
+                            .format('.'.join([ImageDataset.__module__,
+                                              ImageDataset.__name__])))
+
+        # load the pretrained model
+        model, _, model_state = Network.load(state_file)
+        LOGGER.info('Configuring model for new dataset: {}.'
+                    .format(ds.__class__.__name__))
+
+        # check whether the current dataset uses the correct spectral bands
+        if ds.use_bands != model_state['bands']:
+            raise ValueError('The pretrained network was trained with '
+                             'bands {}, not with bands {}.'
+                             .format(model_state['bands'], ds.use_bands))
+
+        # get the number of convolutional filters
+        filters = model_state['params']['filters']
+
+        # reset model epoch to 0, since the model is trained on a different
+        # dataset
+        model.epoch = 0
+
+        # adjust the number of classes in the model
+        model.nclasses = len(ds.labels)
+        LOGGER.info('Replacing classification layer to classes: {}.'
+                    .format(', '.join('({}, {})'.format(k, v['label'])
+                                      for k, v in ds.labels.items())))
+
+        # adjust the classification layer to the number of classes of the
+        # current dataset
+        model.classifier = Conv2dSame(in_channels=filters[0],
+                                      out_channels=model.nclasses,
+                                      kernel_size=1)
+
+        return model
+
+    @staticmethod
+    def load_checkpoint(loss_state):
+
+        # load the model loss and accuracy
+        checkpoint_state = torch.load(loss_state)
+
+        # get all non-zero elements, i.e. get number of epochs trained
+        # before the early stop
+        checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
+                            for k, v in checkpoint_state.items()}
+
+        # maximum accuracy on the validation set
+        max_accuracy = checkpoint_state['va'][:, -1].mean().item()
+
+        return checkpoint_state, max_accuracy
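transfer_model keeps the pretrained feature-extraction weights and swaps only the 1x1 classification convolution so the output channels match the new dataset's classes. A minimal sketch with a hypothetical network, where plain nn.Conv2d stands in for the project's Conv2dSame:

    import torch
    import torch.nn as nn

    class TinySegNet(nn.Module):
        """Hypothetical stand-in for the package's segmentation models."""

        def __init__(self, nclasses):
            super().__init__()
            self.features = nn.Conv2d(3, 32, kernel_size=3, padding=1)
            self.classifier = nn.Conv2d(32, nclasses, kernel_size=1)

        def forward(self, x):
            return self.classifier(torch.relu(self.features(x)))

    model = TinySegNet(nclasses=6)                      # pretrained on 6 classes
    model.classifier = nn.Conv2d(32, 4, kernel_size=1)  # re-head for 4 classes
    print(model(torch.rand(1, 3, 16, 16)).shape)        # (1, 4, 16, 16)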
@@ -420,14 +418,12 @@ class StateConfig(BaseConfig):
         # path to model state
         state = self.mc.state_path.joinpath(state_file)

-        return state
+        # path to model loss/accuracy
+        loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt'))
+
+        return state, loss_state
 @dataclasses.dataclass
 class EvalConfig(BaseConfig):

     state_file: pathlib.Path
     test: object
     predict_scene: bool = False
     plot_samples: bool = False

@@ -446,6 +442,32 @@ class EvalConfig(BaseConfig):
             raise TypeError('Expected "test" to be None, True or False, '
                             'got {}.'.format(self.test))

+        # the output paths for the different graphics
+        self.base_path = pathlib.Path(HERE)
+        self.sample_path = self.base_path.joinpath('_samples')
+        self.scenes_path = self.base_path.joinpath('_scenes')
+        self.models_path = self.base_path.joinpath('_graphics')
+
+        # write initialization string to log file
+        # LOGGER.info(80 * '-')
+        # LOGGER.info('{}')
+        # LOGGER.info(80 * '-')
+
+
+@dataclasses.dataclass
+class LogConfig(BaseConfig):
+
+    state_file: pathlib.Path
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        # the path to store model logs
+        self.log_path = pathlib.Path(HERE).joinpath('_logs')
+
+        # the log file of the current model
+        self.log_file = self.log_path.joinpath(
+            self.state_file.name.replace('.pt', '.log'))

 @dataclasses.dataclass
 class NetworkTrainer(BaseConfig):
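The new LogConfig derives each model's log file from its state file name. A minimal sketch of the path arithmetic (file names are hypothetical):

    import pathlib

    state_file = pathlib.Path('_models/UNet_Sparcs.pt')  # hypothetical state file
    log_path = pathlib.Path('_logs')

    log_file = log_path.joinpath(state_file.name.replace('.pt', '.log'))
    print(log_file)  # _logs/UNet_Sparcs.log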
@@ -454,15 +476,14 @@ class NetworkTrainer(BaseConfig):
     loss_function: nn.Module
     train_dl: DataLoader
     valid_dl: DataLoader
     test_dl: DataLoader
     state_file: pathlib.Path
+    loss_state: pathlib.Path
     epochs: int = 1
     nthreads: int = torch.get_num_threads()
     early_stop: bool = False
     mode: str = 'max'
     delta: float = 0
     patience: int = 10
     max_accuracy: float = 0
+    checkpoint_state: dict = dataclasses.field(default_factory=dict)
     save: bool = True
@@ -473,16 +494,21 @@ class NetworkTrainer(BaseConfig):
         self.device = torch.device("cuda:0" if torch.cuda.is_available()
                                    else "cpu")

+        # maximum accuracy on the validation dataset
+        self.max_accuracy = 0
+        if self.checkpoint_state:
+            self.max_accuracy = self.checkpoint_state['va'].mean(
+                axis=0).max().item()
+
         # whether to use early stopping
         self.es = None
         if self.early_stop:
             self.es = EarlyStopping(self.mode, self.max_accuracy, self.delta,
                                     self.patience)

     def train(self):

-        LOGGER.info(30 * '-' + ' Training ' + 30 * '-')
+        LOGGER.info(35 * '-' + ' Training ' + 35 * '-')

         # set the number of threads
         LOGGER.info('Device: {}'.format(self.device))
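Two different summaries of the validation-accuracy array 'va' (shape: batches x epochs) appear in this commit: __post_init__ takes the best epoch mean, while ModelConfig.load_checkpoint takes the mean of the last logged epoch. With hypothetical numbers:

    import numpy as np

    # 3 validation batches x 4 epochs
    va = np.array([[0.60, 0.70, 0.80, 0.78],
                   [0.62, 0.71, 0.79, 0.80],
                   [0.58, 0.69, 0.81, 0.79]])

    print(va.mean(axis=0).max())  # 0.8  -> best epoch mean
    print(va[:, -1].mean())       # 0.79 -> mean of the last epoch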
@@ -493,11 +519,11 @@ class NetworkTrainer(BaseConfig):
         # training and validation dataset
         tshape = (len(self.train_dl), self.epochs)
         vshape = (len(self.valid_dl), self.epochs)
-        training_state = {'tl': np.zeros(shape=tshape),
-                          'ta': np.zeros(shape=tshape),
-                          'vl': np.zeros(shape=vshape),
-                          'va': np.zeros(shape=vshape)
-                          }
+        self.training_state = {'tl': np.zeros(shape=tshape),
+                               'ta': np.zeros(shape=tshape),
+                               'vl': np.zeros(shape=vshape),
+                               'va': np.zeros(shape=vshape)
+                               }

         # send the model to the gpu if available
         self.model = self.model.to(self.device)
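The metric arrays are preallocated with one column per planned epoch, so an early-stopped run leaves trailing zero columns; load_checkpoint trims them with np.nonzero before resuming. A sketch with hypothetical losses (this assumes no logged value is exactly zero):

    import numpy as np

    tl = np.zeros((2, 5))            # 2 batches x 5 planned epochs
    tl[:, :3] = [[0.9, 0.7, 0.6],    # training early-stopped after epoch 3
                 [0.8, 0.7, 0.5]]

    # keep only the epochs that were actually trained
    trimmed = tl[np.nonzero(tl)].reshape(tl.shape[0], -1)
    print(trimmed.shape)             # (2, 3)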
@@ -525,7 +551,7 @@ class NetworkTrainer(BaseConfig):
                 # compute loss
                 loss = self.loss_function(outputs, labels.long())
                 observed_loss = loss.detach().numpy().item()
-                training_state['tl'][batch, epoch] = observed_loss
+                self.training_state['tl'][batch, epoch] = observed_loss

                 # compute the gradients of the loss function w.r.t.
                 # the network weights
@@ -539,7 +565,7 @@ class NetworkTrainer(BaseConfig):
                 # calculate accuracy on current batch
                 observed_accuracy = accuracy_function(ypred, labels)
-                training_state['ta'][batch, epoch] = observed_accuracy
+                self.training_state['ta'][batch, epoch] = observed_accuracy

                 # print progress
                 LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
@@ -562,8 +588,8 @@ class NetworkTrainer(BaseConfig):
             vacc, vloss = self.predict()

             # append observed accuracy and loss to arrays
-            training_state['va'][:, epoch] = vacc.squeeze()
-            training_state['vl'][:, epoch] = vloss.squeeze()
+            self.training_state['va'][:, epoch] = vacc.squeeze()
+            self.training_state['vl'][:, epoch] = vloss.squeeze()

             # metric to assess model performance on the validation set
             epoch_acc = vacc.squeeze().mean()
@@ -574,7 +600,7 @@ class NetworkTrainer(BaseConfig):
                     # save model state if the model improved with
                     # respect to the previous epoch
-                    self.save_state(training_state)
+                    self.save_state()

                 # whether the early stopping criterion is met
                 if self.es.stop(epoch_acc):
@@ -583,15 +609,13 @@ class NetworkTrainer(BaseConfig):
             else:
                 # if no early stopping is required, the model state is
                 # saved after each epoch
-                self.save_state(training_state)
+                self.save_state()

-        return training_state
+        return self.training_state

     def predict(self):

         LOGGER.info(30 * '-' + ' Predicting ' + 30 * '-')

         # send the model to the gpu if available
         self.model = self.model.to(self.device)
@@ -631,37 +655,38 @@ class NetworkTrainer(BaseConfig):
                         .format(batch + 1, len(self.valid_dl), acc))

         # calculate overall accuracy on the validation/test set
-        LOGGER.info('Epoch {:d}, Overall accuracy: {:.2f}%.'
+        LOGGER.info('Epoch: {:d}, Mean accuracy: {:.2f}%.'
                     .format(self.model.epoch, accuracies.mean() * 100))

         return accuracies, losses

-    def save_state(self, training_state):
+    def save_state(self):

         # whether to save the model state
         if self.save:
-            # save model state
-            state = self.model.save(self.state_file.name,
-                                    self.optimizer,
-                                    self.train_dl.dataset.dataset.use_bands,
-                                    self.state_file.parent)
-
-            # save losses and accuracy
-            self._save_loss(training_state)
-
-    def _save_loss(self, training_state):
-
-        # save losses and accuracy
-        state = training_state
-
-        # append the model performance before the checkpoint to the model
-        # state, if a checkpoint is passed
-        if self.checkpoint_state:
-            # append values from checkpoint to current training state
-            state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
-                     zip(self.checkpoint_state.items(),
-                         training_state.items()) if k1 == k2}
+            # append the model performance before the checkpoint to the
+            # current training state, if a checkpoint is passed
+            if self.checkpoint_state:
+                # append values from checkpoint to current training state
+                state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
+                         zip(self.checkpoint_state.items(),
+                             self.training_state.items()) if k1 == k2}
+            else:
+                state = self.training_state
+
+            # save model state
+            _ = self.model.save(self.state_file,
+                                self.optimizer,
+                                bands=self.train_dl.dataset.dataset.use_bands,
+                                train_ds=self.train_dl.dataset,
+                                valid_ds=self.valid_dl.dataset,
+                                test_ds=self.test_dl.dataset,
+                                state=state)
+
+            # save the model loss and accuracies to file
+            torch.save(state, self.loss_state)

     def __repr__(self):
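When training resumed from a checkpoint, save_state concatenates the pre-checkpoint metrics with the current run column-wise, so the saved arrays describe one continuous history. With hypothetical accuracies:

    import numpy as np

    before = np.array([[0.60, 0.70, 0.75],   # 2 batches x 3 checkpointed epochs
                       [0.62, 0.69, 0.74]])
    after = np.array([[0.78, 0.80],          # 2 batches x 2 resumed epochs
                      [0.77, 0.81]])

    merged = np.hstack([before, after])
    print(merged.shape)  # (2, 5) -> one continuous training history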
@@ -687,6 +712,7 @@ class NetworkTrainer(BaseConfig):
         fs += '\n    (split):'
         fs += '\n' + repr(self.train_dl.dataset)
+        fs += '\n' + repr(self.valid_dl.dataset)
         fs += '\n' + repr(self.test_dl.dataset)

         # model
         fs += '\n    (model):\n'
@@ -764,6 +790,6 @@ class EarlyStopping(object):
     def __repr__(self):
         fs = self.__class__.__name__
-        fs += '(mode={}, best={}, delta={}, patience={})'.format(
+        fs += '(mode={}, best={:.2f}, delta={}, patience={})'.format(
             self.mode, self.best, self.min_delta, self.patience)
         return fs
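The added {:.2f} format specification rounds the reported best accuracy to two decimals, e.g.:

    print('(mode={}, best={:.2f}, delta={}, patience={})'
          .format('max', 0.8374, 0, 10))
    # (mode=max, best=0.84, delta=0, patience=10)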