Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
PySegCNN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
earth_observation_public
PySegCNN
Commits
d5a08ef8
Commit
d5a08ef8
authored
4 years ago
by
Frisinghelli Daniel
Browse files
Options
Downloads
Patches
Plain Diff
Added support for using a model pretrained on a different dataset for transfer learning tasks
parent
bac6bc17
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
pytorch/trainer.py
+102
-77
102 additions, 77 deletions
pytorch/trainer.py
with
102 additions
and
77 deletions
pytorch/trainer.py
+
102
−
77
View file @
d5a08ef8
...
...
@@ -20,7 +20,7 @@ sys.path.append('..')
# local modules
from
pytorch.dataset
import
SparcsDataset
,
Cloud95Dataset
from
pytorch.
constant
s
import
SparcsLabels
,
Cloud95Labels
from
pytorch.
layer
s
import
Conv2dSame
class
NetworkTrainer
(
object
):
...
...
@@ -31,8 +31,6 @@ class NetworkTrainer(object):
for
k
,
v
in
config
.
items
():
setattr
(
self
,
k
,
v
)
def
initialize
(
self
):
# check which dataset the model is trained on
if
self
.
dataset_name
==
'
Sparcs
'
:
# instanciate the SparcsDataset
...
...
@@ -81,6 +79,12 @@ class NetworkTrainer(object):
self
.
dataset
,
self
.
tvratio
,
self
.
ttratio
,
self
.
seed
)
print
(
'
--------------------------------------------------------------
'
)
# number of batches in the validation set
self
.
nvbatches
=
int
(
len
(
self
.
valid_ds
)
/
self
.
batch_size
)
# number of batches in the training set
self
.
nbatches
=
int
(
len
(
self
.
train_ds
)
/
self
.
batch_size
)
# the training and validation dataloaders
self
.
train_dl
=
DataLoader
(
self
.
train_ds
,
self
.
batch_size
,
...
...
@@ -108,48 +112,49 @@ class NetworkTrainer(object):
self
.
batch_size
,
bformat
))
# check whether a pretrained model was used and change state filename
# accordingly
if
self
.
pretrained
:
# add the configuration of the pretrained model to the state name
self
.
state_file
=
(
self
.
state_file
.
replace
(
'
.pt
'
,
'
_
'
)
+
'
pretrained_
'
+
self
.
pretrained_model
)
# path to model state
self
.
state
=
os
.
path
.
join
(
self
.
state_path
,
self
.
state_file
)
# path to model loss/accuracy
self
.
loss_state
=
self
.
state
.
replace
(
'
.pt
'
,
'
_loss.pt
'
)
def
from_pretrained
(
self
):
# name of the dataset the pretrained model was trained on
dataset_name
=
self
.
pretrained_model
.
split
(
'
_
'
)[
1
]
# load the pretrained model
model_state
=
torch
.
load
(
os
.
path
.
join
(
self
.
state_path
,
self
.
pretrained_model
))
# input bands of the pretrained model
bands
=
self
.
pretrained_model
.
split
(
'
_
'
)[
-
1
].
split
(
'
.
'
)[
0
]
#
get the
input bands of the pretrained model
bands
=
model_state
[
'
bands
'
]
if
dataset_name
==
SparcsDataset
.
__name__
:
# get the number of convolutional filters
filters
=
model_state
[
'
params
'
][
'
filters
'
]
# number of input channels
in_channels
=
len
(
bands
)
if
bands
!=
'
all
'
else
10
# check whether the current dataset uses the correct spectral bands
if
self
.
bands
!=
bands
:
raise
ValueError
(
'
The bands of the pretrained network do not
'
'
match the specified bands: {}
'
.
format
(
self
.
bands
))
# instanciate pretrained model architecture
model
=
self
.
net
(
in_channels
=
in_channels
,
nclasses
=
len
(
SparcsLabels
),
filters
=
self
.
filters
,
skip
=
self
.
skip_connection
,
**
self
.
kwargs
)
if
dataset_name
==
Cloud95Dataset
.
__name__
:
# number of input channels
in_channels
=
len
(
bands
)
if
bands
!=
'
all
'
else
4
# instanciate pretrained model architecture
model
=
self
.
net
(
in_channels
=
in_channels
,
nclasses
=
len
(
Cloud95Labels
),
filters
=
self
.
filters
,
skip
=
self
.
skip_connection
,
**
self
.
kwargs
)
# instanciate pretrained model architecture
model
=
self
.
net
(
**
model_state
[
'
params
'
],
**
model_state
[
'
kwargs
'
])
# load pretrained model weights
model
.
load
(
self
.
pretrained_model
,
inpath
=
self
.
state_path
)
# adjust the classification layer to the number of classes of the
# current dataset
model
.
classifier
=
Conv2dSame
(
in_channels
=
filters
[
0
],
out_channels
=
len
(
self
.
dataset
.
labels
),
kernel_size
=
1
)
return
model
...
...
@@ -191,6 +196,22 @@ class NetworkTrainer(object):
def
accuracy_function
(
self
,
outputs
,
labels
):
return
(
outputs
==
labels
).
float
().
mean
()
def
_save_loss
(
self
,
training_state
,
checkpoint
=
False
,
checkpoint_state
=
None
):
# save losses and accuracy
if
checkpoint
and
checkpoint_state
is
not
None
:
# append values from checkpoint to current training
# state
torch
.
save
({
k1
:
np
.
hstack
([
v1
,
v2
])
for
(
k1
,
v1
),
(
k2
,
v2
)
in
zip
(
checkpoint_state
.
items
(),
training_state
.
items
())
if
k1
==
k2
},
self
.
loss_state
)
else
:
torch
.
save
(
training_state
,
self
.
loss_state
)
def
train
(
self
):
# set the number of threads
...
...
@@ -206,31 +227,36 @@ class NetworkTrainer(object):
# initial accuracy on the validation set
max_accuracy
=
0
# create dictionary of the observed losses and accuracies on the
# training and validation dataset
training_state
=
{
'
tl
'
:
np
.
zeros
(
shape
=
(
self
.
nbatches
,
self
.
epochs
)),
'
ta
'
:
np
.
zeros
(
shape
=
(
self
.
nbatches
,
self
.
epochs
)),
'
vl
'
:
np
.
zeros
(
shape
=
(
self
.
nvbatches
,
self
.
epochs
)),
'
va
'
:
np
.
zeros
(
shape
=
(
self
.
nvbatches
,
self
.
epochs
))
}
# whether to resume training from an existing model
checkpoint_state
=
None
if
os
.
path
.
exists
(
self
.
state
)
and
self
.
checkpoint
:
# load the model state
state
=
self
.
model
.
load
(
self
.
state_file
,
self
.
optimizer
,
self
.
state_path
)
print
(
'
Resuming training from {} ...
'
.
format
(
state
))
print
(
'
Model epoch: {:d}
'
.
format
(
self
.
model
.
epoch
))
# sen
d the model
to the gpu if available
self
.
model
=
self
.
model
.
to
(
self
.
devic
e
)
# loa
d the model
loss and accuracy
checkpoint_state
=
torch
.
load
(
self
.
loss_stat
e
)
# number of batches in the validation set
nvbatches
=
int
(
len
(
self
.
valid_ds
)
/
self
.
batch_size
)
# number of batches in the training set
nbatches
=
int
(
len
(
self
.
train_ds
)
/
self
.
batch_size
)
# get all non-zero elements, i.e. get number of epochs trained
# before the early stop
checkpoint_state
=
{
k
:
v
[
np
.
nonzero
(
v
)].
reshape
(
v
.
shape
[
0
],
-
1
)
for
k
,
v
in
checkpoint_state
.
items
()}
# create arrays of the observed losses and accuracies on the
# training set
losses
=
np
.
zeros
(
shape
=
(
nbatches
,
self
.
epochs
))
accuracies
=
np
.
zeros
(
shape
=
(
nbatches
,
self
.
epochs
))
# maximum accuracy on the validation set
max_accuracy
=
checkpoint_state
[
'
va
'
][:,
-
1
].
mean
().
item
()
# create arrays of the observed losses and accuracies on the
# validation set
vlosses
=
np
.
zeros
(
shape
=
(
nvbatches
,
self
.
epochs
))
vaccuracies
=
np
.
zeros
(
shape
=
(
nvbatches
,
self
.
epochs
))
# send the model to the gpu if available
self
.
model
=
self
.
model
.
to
(
self
.
device
)
# initialize the training: iterate over the entire training data set
for
epoch
in
range
(
self
.
epochs
):
...
...
@@ -254,7 +280,8 @@ class NetworkTrainer(object):
# compute loss
loss
=
self
.
loss_function
(
outputs
,
labels
.
long
())
losses
[
batch
,
epoch
]
=
loss
.
detach
().
numpy
().
item
()
observed_loss
=
loss
.
detach
().
numpy
().
item
()
training_state
[
'
tl
'
][
batch
,
epoch
]
=
observed_loss
# compute the gradients of the loss function w.r.t.
# the network weights
...
...
@@ -267,14 +294,17 @@ class NetworkTrainer(object):
ypred
=
F
.
softmax
(
outputs
,
dim
=
1
).
argmax
(
dim
=
1
)
# calculate accuracy on current batch
acc
=
self
.
accuracy_function
(
ypred
,
labels
)
accuracies
[
batch
,
epoch
]
=
acc
observed_accuracy
=
self
.
accuracy_function
(
ypred
,
labels
)
training_state
[
'
ta
'
][
batch
,
epoch
]
=
observed_accuracy
# print progress
print
(
'
Epoch: {:d}/{:d}, Batch: {:d}/{:d}, Loss: {:.2f},
'
'
Accuracy: {:.2f}
'
.
format
(
epoch
,
self
.
epochs
,
batch
,
nbatches
,
losses
[
batch
,
epoch
],
acc
))
'
Accuracy: {:.2f}
'
.
format
(
epoch
,
self
.
epochs
,
batch
,
self
.
nbatches
,
observed_loss
,
observed_accuracy
))
# update the number of epochs trained
self
.
model
.
epoch
+=
1
...
...
@@ -287,8 +317,8 @@ class NetworkTrainer(object):
_
,
vacc
,
vloss
=
self
.
predict
()
# append observed accuracy and loss to arrays
vaccuracies
[:,
epoch
]
=
vacc
.
squeeze
()
vlosses
[:,
epoch
]
=
vloss
.
squeeze
()
training_state
[
'
va
'
]
[:,
epoch
]
=
vacc
.
squeeze
()
training_state
[
'
vl
'
]
[:,
epoch
]
=
vloss
.
squeeze
()
# metric to assess model performance on the validation set
epoch_acc
=
vacc
.
squeeze
().
mean
()
...
...
@@ -298,37 +328,34 @@ class NetworkTrainer(object):
max_accuracy
=
epoch_acc
# save model state if the model improved with
# respect to the previous epoch
_
=
self
.
model
.
save
(
self
.
optimizer
,
self
.
state_file
,
_
=
self
.
model
.
save
(
self
.
state_file
,
self
.
optimizer
,
self
.
bands
,
self
.
state_path
)
# save losses and accuracy
self
.
_save_loss
(
training_state
,
self
.
checkpoint
,
checkpoint_state
)
# whether the early stopping criterion is met
if
es
.
stop
(
epoch_acc
):
# save losses and accuracy before exiting training
torch
.
save
({
'
epoch
'
:
epoch
,
'
training_loss
'
:
losses
,
'
training_accuracy
'
:
accuracies
,
'
validation_loss
'
:
vlosses
,
'
validation_accuracy
'
:
vaccuracies
},
self
.
loss_state
)
break
else
:
# if no early stopping is required, the model state is saved
# after each epoch
_
=
self
.
model
.
save
(
self
.
optimizer
,
self
.
state_file
,
_
=
self
.
model
.
save
(
self
.
state_file
,
self
.
optimizer
,
self
.
bands
,
self
.
state_path
)
# save losses and accuracy after each epoch to file
torch
.
save
({
'
epoch
'
:
epoch
,
'
training_loss
'
:
losses
,
'
training_accuracy
'
:
accuracies
,
'
validation_loss
'
:
vlosses
,
'
validation_accuracy
'
:
vaccuracies
},
self
.
loss_state
)
# save losses and accuracy after each epoch
self
.
_save_loss
(
training_state
,
self
.
checkpoint
,
checkpoint_state
)
return
losses
,
accuracies
,
vlosses
,
vaccuracies
return
training_state
def
predict
(
self
,
pretrained
=
False
,
confusion
=
False
):
...
...
@@ -347,12 +374,9 @@ class NetworkTrainer(object):
# initialize confusion matrix
cm
=
torch
.
zeros
(
self
.
model
.
nclasses
,
self
.
model
.
nclasses
)
# number of batches in the validation set
nbatches
=
int
(
len
(
self
.
valid_ds
)
/
self
.
batch_size
)
# create arrays of the observed losses and accuracies
accuracies
=
np
.
zeros
(
shape
=
(
n
batches
,
1
))
losses
=
np
.
zeros
(
shape
=
(
n
batches
,
1
))
accuracies
=
np
.
zeros
(
shape
=
(
self
.
nv
batches
,
1
))
losses
=
np
.
zeros
(
shape
=
(
self
.
nv
batches
,
1
))
# iterate over the validation/test set
print
(
'
Calculating accuracy on validation set ...
'
)
...
...
@@ -379,7 +403,8 @@ class NetworkTrainer(object):
# print progress
print
(
'
Batch: {:d}/{:d}, Accuracy: {:.2f}
'
.
format
(
batch
,
nbatches
,
acc
))
self
.
nvbatches
,
acc
))
# update confusion matrix
if
confusion
:
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment