Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
PySegCNN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
earth_observation_public
PySegCNN
Commits
6e330d11
Commit
6e330d11
authored
4 years ago
by
Frisinghelli Daniel
Browse files
Options
Downloads
Patches
Plain Diff
Improved initialization of dataset and model
parent
7cce87b3
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
pysegcnn/core/trainer.py
+100
-66
100 additions, 66 deletions
pysegcnn/core/trainer.py
with
100 additions
and
66 deletions
pysegcnn/core/trainer.py
+
100
−
66
View file @
6e330d11
...
...
@@ -22,7 +22,6 @@ from pysegcnn.core.split import (RandomTileSplit, RandomSceneSplit, DateSplit,
VALID_SPLIT_MODES
)
class
NetworkTrainer
(
object
):
def
__init__
(
self
,
config
):
...
...
@@ -38,12 +37,11 @@ class NetworkTrainer(object):
# initialize the dataset to train the model on
self
.
_init_dataset
()
# initialize the model
self
.
_init_model
()
# initialize the model state files
self
.
_init_state
()
# initialize the model
self
.
_init_model
()
def
from_pretrained
(
self
):
...
...
@@ -69,7 +67,7 @@ class NetworkTrainer(object):
.
format
(
self
.
bands
))
# instantiate pretrained model architecture
model
=
self
.
net
(
**
model_state
[
'
params
'
],
**
model_state
[
'
kwargs
'
])
model
=
self
.
model
(
**
model_state
[
'
params
'
],
**
model_state
[
'
kwargs
'
])
# load pretrained model weights
model
.
load
(
self
.
pretrained_model
,
inpath
=
self
.
state_path
)
...
...
@@ -78,39 +76,41 @@ class NetworkTrainer(object):
# dataset
model
.
epoch
=
0
# adjust the number of classes in the model
model
.
nclasses
=
len
(
self
.
dataset
.
labels
)
# adjust the classification layer to the number of classes of the
# current dataset
model
.
classifier
=
Conv2dSame
(
in_channels
=
filters
[
0
],
out_channels
=
len
(
self
.
dataset
.
labels
)
,
out_channels
=
model
.
nclasses
,
kernel_size
=
1
)
# adjust the number of classes in the model
model
.
nclasses
=
len
(
self
.
dataset
.
labels
)
return
model
def
from_checkpoint
(
self
):
# whether to resume training from an existing model
checkpoint_state
=
None
max_accuracy
=
0
if
os
.
path
.
exists
(
self
.
state
)
and
self
.
checkpoint
:
# load the model state
state
=
self
.
model
.
load
(
self
.
state_file
,
self
.
optimizer
,
self
.
state_path
)
print
(
'
Resuming training from {} ...
'
.
format
(
state
))
print
(
'
Model epoch: {:d}
'
.
format
(
self
.
model
.
epoch
))
if
not
os
.
path
.
exists
(
self
.
state
):
raise
FileNotFoundError
(
'
Model checkpoint {} does not exist.
'
.
format
(
self
.
state
))
# load the model state
state
=
self
.
model
.
load
(
self
.
state_file
,
self
.
optimizer
,
self
.
state_path
)
print
(
'
Resuming training from {} ...
'
.
format
(
state
))
print
(
'
Model epoch: {:d}
'
.
format
(
self
.
model
.
epoch
))
# load the model loss and accuracy
checkpoint_state
=
torch
.
load
(
self
.
loss_state
)
# load the model loss and accuracy
checkpoint_state
=
torch
.
load
(
self
.
loss_state
)
# get all non-zero elements, i.e. get number of epochs trained
# before the early stop
checkpoint_state
=
{
k
:
v
[
np
.
nonzero
(
v
)].
reshape
(
v
.
shape
[
0
],
-
1
)
for
k
,
v
in
checkpoint_state
.
items
()}
# get all non-zero elements, i.e. get number of epochs trained
# before the early stop
checkpoint_state
=
{
k
:
v
[
np
.
nonzero
(
v
)].
reshape
(
v
.
shape
[
0
],
-
1
)
for
k
,
v
in
checkpoint_state
.
items
()}
# maximum accuracy on the validation set
max_accuracy
=
checkpoint_state
[
'
va
'
][:,
-
1
].
mean
().
item
()
# maximum accuracy on the validation set
max_accuracy
=
checkpoint_state
[
'
va
'
][:,
-
1
].
mean
().
item
()
return
checkpoint_state
,
max_accuracy
...
...
@@ -128,9 +128,6 @@ class NetworkTrainer(object):
print
(
'
mode = {}, delta = {}, patience = {} epochs ...
'
.
format
(
self
.
mode
,
self
.
delta
,
self
.
patience
))
# initial accuracy on the validation set
max_accuracy
=
0
# create dictionary of the observed losses and accuracies on the
# training and validation dataset
tshape
=
(
len
(
self
.
train_dl
),
self
.
epochs
)
...
...
@@ -141,9 +138,6 @@ class NetworkTrainer(object):
'
va
'
:
np
.
zeros
(
shape
=
vshape
)
}
# whether to resume training from an existing model
checkpoint_state
,
max_accuracy
=
self
.
from_checkpoint
()
# send the model to the gpu if available
self
.
model
=
self
.
model
.
to
(
self
.
device
)
...
...
@@ -188,7 +182,8 @@ class NetworkTrainer(object):
# print progress
print
(
'
Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, Loss: {:.2f},
'
'
Accuracy: {:.2f}
'
.
format
(
epoch
+
1
,
self
.
epochs
,
'
Accuracy: {:.2f}
'
.
format
(
epoch
+
1
,
self
.
epochs
,
batch
+
1
,
len
(
self
.
train_dl
),
observed_loss
,
...
...
@@ -212,8 +207,8 @@ class NetworkTrainer(object):
epoch_acc
=
vacc
.
squeeze
().
mean
()
# whether the model improved with respect to the previous epoch
if
es
.
increased
(
epoch_acc
,
max_accuracy
,
self
.
delta
):
max_accuracy
=
epoch_acc
if
es
.
increased
(
epoch_acc
,
self
.
max_accuracy
,
self
.
delta
):
self
.
max_accuracy
=
epoch_acc
# save model state if the model improved with
# respect to the previous epoch
_
=
self
.
model
.
save
(
self
.
state_file
,
...
...
@@ -224,7 +219,7 @@ class NetworkTrainer(object):
# save losses and accuracy
self
.
_save_loss
(
training_state
,
self
.
checkpoint
,
checkpoint_state
)
self
.
checkpoint_state
)
# whether the early stopping criterion is met
if
es
.
stop
(
epoch_acc
):
...
...
@@ -241,7 +236,7 @@ class NetworkTrainer(object):
# save losses and accuracy after each epoch
self
.
_save_loss
(
training_state
,
self
.
checkpoint
,
checkpoint_state
)
self
.
checkpoint_state
)
return
training_state
...
...
@@ -300,7 +295,7 @@ class NetworkTrainer(object):
# format: networkname_datasetname_t(tilesize)_b(batchsize)_bands.pt
bformat
=
''
.
join
([
b
[
0
]
for
b
in
self
.
bands
])
if
self
.
bands
else
'
all
'
self
.
state_file
=
(
'
{}_{}_t{}_b{}_{}.pt
'
.
format
(
self
.
model
.
__
class__
.
__
name__
,
.
format
(
self
.
model
.
__name__
,
self
.
dataset
.
__class__
.
__name__
,
self
.
tile_size
,
self
.
batch_size
,
...
...
@@ -321,34 +316,38 @@ class NetworkTrainer(object):
def
_init_dataset
(
self
):
# check whether the dataset is currently supported
self
.
dataset
=
None
for
dataset
in
SupportedDatasets
:
if
self
.
dataset_name
==
dataset
.
name
:
self
.
dataset
=
dataset
.
value
[
'
class
'
](
self
.
dataset_path
,
use_bands
=
self
.
bands
,
tile_size
=
self
.
tile_size
,
sort
=
self
.
sort
,
transforms
=
self
.
transforms
,
pad
=
self
.
pad
,
cval
=
self
.
cval
,
gt_pattern
=
self
.
gt_pattern
)
# the dataset name
self
.
dataset_name
=
os
.
path
.
basename
(
self
.
root_dir
)
if
self
.
dataset
is
None
:
# check whether the dataset is currently supported
if
self
.
dataset_name
not
in
SupportedDatasets
.
__members__
:
raise
ValueError
(
'
{} is not a valid dataset.
'
.
format
(
self
.
dataset_name
)
+
'
Available datasets are:
\n
'
+
'
\n
'
.
join
(
name
for
name
,
_
in
SupportedDatasets
.
__members__
.
items
()))
else
:
self
.
dataset_class
=
SupportedDatasets
.
__members__
[
self
.
dataset_name
].
value
# instantiate the dataset
self
.
dataset
=
self
.
dataset_class
(
self
.
root_dir
,
use_bands
=
self
.
bands
,
tile_size
=
self
.
tile_size
,
sort
=
self
.
sort
,
transforms
=
self
.
transforms
,
pad
=
self
.
pad
,
cval
=
self
.
cval
,
gt_pattern
=
self
.
gt_pattern
)
# the mode to split
if
self
.
split_mode
not
in
VALID_SPLIT_MODES
:
raise
ValueError
(
'
{} is not supported. Valid modes are {}, see
'
'
pysegcnn.main.config.py for a description of
'
'
each mode.
'
.
format
(
self
.
split_mode
,
VALID_SPLIT_MODES
))
'
pysegcnn.main.config.py for a description of
'
'
each mode.
'
.
format
(
self
.
split_mode
,
VALID_SPLIT_MODES
))
if
self
.
split_mode
==
'
random
'
:
self
.
subset
=
RandomTileSplit
(
self
.
dataset
,
self
.
ttratio
,
...
...
@@ -364,12 +363,12 @@ class NetworkTrainer(object):
self
.
date
,
self
.
dateformat
)
# the training, validation and dataset
# the training, validation and
test
dataset
self
.
train_ds
,
self
.
valid_ds
,
self
.
test_ds
=
self
.
subset
.
split
()
# whether to drop training samples with a fraction of pixels equal to
# the constant padding value self.cval >= self.drop
if
self
.
pad
:
if
self
.
pad
and
self
.
drop
:
self
.
_drop
(
self
.
train_ds
)
# the shape of a single batch
...
...
@@ -400,18 +399,53 @@ class NetworkTrainer(object):
def
_init_model
(
self
):
# instantiate the segmentation network
if
self
.
pretrained
:
# initial accuracy on the validation set
self
.
max_accuracy
=
0
# set the model checkpoint to None, overwritten when resuming
# training from an existing model checkpoint
self
.
checkpoint_state
=
None
# case (1): build a model for the specified dataset
if
not
self
.
pretrained
and
not
self
.
checkpoint
:
# instantiate the model
self
.
model
=
self
.
model
(
in_channels
=
len
(
self
.
dataset
.
use_bands
),
nclasses
=
len
(
self
.
dataset
.
labels
),
filters
=
self
.
filters
,
skip
=
self
.
skip_connection
,
**
self
.
kwargs
)
# the optimizer used to update the model weights
self
.
optimizer
=
self
.
optimizer
(
self
.
model
.
parameters
(),
self
.
lr
)
# case (2): using a pretrained model without an existing checkpoint on
# a new dataset, i.e. transfer learning
if
self
.
pretrained
and
not
self
.
checkpoint
:
# load pretrained model
self
.
model
=
self
.
from_pretrained
()
else
:
self
.
model
=
self
.
net
(
in_channels
=
len
(
self
.
dataset
.
use_bands
),
nclasses
=
len
(
self
.
dataset
.
labels
),
filters
=
self
.
filters
,
skip
=
self
.
skip_connection
,
**
self
.
kwargs
)
# the optimizer used to update the model weights
self
.
optimizer
=
self
.
optimizer
(
self
.
model
.
parameters
(),
self
.
lr
)
# the optimizer used to update the model weights
self
.
optimizer
=
self
.
optimizer
(
self
.
model
.
parameters
(),
self
.
lr
)
# case (3): using a pretrained model with existing checkpoint on the
# same dataset the pretrained model was trained on
elif
self
.
checkpoint
:
# instantiate the model
self
.
model
=
self
.
model
(
in_channels
=
len
(
self
.
dataset
.
use_bands
),
nclasses
=
len
(
self
.
dataset
.
labels
),
filters
=
self
.
filters
,
skip
=
self
.
skip_connection
,
**
self
.
kwargs
)
# the optimizer used to update the model weights
self
.
optimizer
=
self
.
optimizer
(
self
.
model
.
parameters
(),
self
.
lr
)
# whether to resume training from an existing model checkpoint
if
self
.
checkpoint
:
(
self
.
checkpoint_state
,
self
.
max_accuracy
)
=
self
.
from_checkpoint
()
# function to drop samples with a fraction of pixels equal to the constant
# padding value self.cval >= self.drop
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment