Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
PySegCNN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
earth_observation_public
PySegCNN
Commits
1e145d4d
Commit
1e145d4d
authored
4 years ago
by
Frisinghelli Daniel
Browse files
Options
Downloads
Patches
Plain Diff
Divided init method into smaller submethods; included support for data augmentations
parent
5e28a1c8
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
pytorch/trainer.py
+126
-107
126 additions, 107 deletions
pytorch/trainer.py
with
126 additions
and
107 deletions
pytorch/trainer.py
+
126
−
107
View file @
1e145d4d
...
...
@@ -18,6 +18,7 @@ from torch.utils.data import random_split, DataLoader
# local modules
from
pytorch.dataset
import
SparcsDataset
,
Cloud95Dataset
from
pytorch.layers
import
Conv2dSame
from
pytorch.constants
import
SupportedDatasets
class
NetworkTrainer
(
object
):
...
...
@@ -28,105 +29,26 @@ class NetworkTrainer(object):
for
k
,
v
in
config
.
items
():
setattr
(
self
,
k
,
v
)
# check which dataset the model is trained on
if
self
.
dataset_name
==
'
Sparcs
'
:
# instanciate the SparcsDataset
self
.
dataset
=
SparcsDataset
(
self
.
dataset_path
,
use_bands
=
self
.
bands
,
tile_size
=
self
.
tile_size
)
elif
self
.
dataset_name
==
'
Cloud95
'
:
# instanciate the Cloud95Dataset
self
.
dataset
=
Cloud95Dataset
(
self
.
dataset_path
,
use_bands
=
self
.
bands
,
tile_size
=
self
.
tile_size
,
exclude
=
self
.
patches
)
else
:
raise
ValueError
(
'
{} is not a valid dataset. Available datasets
'
'
are
"
Sparcs
"
and
"
Cloud95
"
.
'
.
format
(
self
.
dataset_name
))
# print the bands used for the segmentation
print
(
'
------------------------ Input bands -------------------------
'
)
print
(
*
[
'
Band {}: {}
'
.
format
(
i
,
b
)
for
i
,
b
in
enumerate
(
self
.
dataset
.
use_bands
)],
sep
=
'
\n
'
)
print
(
'
--------------------------------------------------------------
'
)
# print the classes of interest
print
(
'
-------------------------- Classes ---------------------------
'
)
print
(
*
[
'
Class {}: {}
'
.
format
(
k
,
v
[
'
label
'
])
for
k
,
v
in
self
.
dataset
.
labels
.
items
()],
sep
=
'
\n
'
)
print
(
'
--------------------------------------------------------------
'
)
# instanciate the segmentation network
print
(
'
------------------- Network architecture ---------------------
'
)
if
self
.
pretrained
:
self
.
model
=
self
.
from_pretrained
()
else
:
self
.
model
=
self
.
net
(
in_channels
=
len
(
self
.
dataset
.
use_bands
),
nclasses
=
len
(
self
.
dataset
.
labels
),
filters
=
self
.
filters
,
skip
=
self
.
skip_connection
,
**
self
.
kwargs
)
print
(
self
.
model
)
print
(
'
--------------------------------------------------------------
'
)
# the training and validation dataset
print
(
'
------------------------ Dataset split -----------------------
'
)
self
.
train_ds
,
self
.
valid_ds
,
self
.
test_ds
=
self
.
train_val_test_split
(
self
.
dataset
,
self
.
tvratio
,
self
.
ttratio
,
self
.
seed
)
print
(
'
--------------------------------------------------------------
'
)
# number of batches in the validation set
self
.
nvbatches
=
int
(
len
(
self
.
valid_ds
)
/
self
.
batch_size
)
# number of batches in the training set
self
.
nbatches
=
int
(
len
(
self
.
train_ds
)
/
self
.
batch_size
)
# the training and validation dataloaders
self
.
train_dl
=
DataLoader
(
self
.
train_ds
,
self
.
batch_size
,
shuffle
=
True
,
drop_last
=
True
)
self
.
valid_dl
=
DataLoader
(
self
.
valid_ds
,
self
.
batch_size
,
shuffle
=
True
,
drop_last
=
True
)
# the optimizer used to update the model weights
self
.
optimizer
=
self
.
optimizer
(
self
.
model
.
parameters
(),
self
.
lr
)
# whether to use the gpu
self
.
device
=
torch
.
device
(
"
cuda:0
"
if
torch
.
cuda
.
is_available
()
else
"
cpu
"
)
# file to save model state to
# format: networkname_datasetname_t(tilesize)_b(batchsize)_bands.pt
bformat
=
''
.
join
([
b
[
0
]
for
b
in
self
.
bands
])
if
self
.
bands
else
'
all
'
self
.
state_file
=
(
'
{}_{}_t{}_b{}_{}.pt
'
.
format
(
self
.
model
.
__class__
.
__name__
,
self
.
dataset
.
__class__
.
__name__
,
self
.
tile_size
,
self
.
batch_size
,
bformat
))
# check whether a pretrained model was used and change state filename
# accordingly
if
self
.
pretrained
:
# add the configuration of the pretrained model to the state name
self
.
state_file
=
(
self
.
state_file
.
replace
(
'
.pt
'
,
'
_
'
)
+
'
pretrained_
'
+
self
.
pretrained_model
)
# path to model state
self
.
state
=
os
.
path
.
join
(
self
.
state_path
,
self
.
state_file
)
# initialize the dataset to train the model on
self
.
_init_dataset
()
#
path to model loss/accuracy
self
.
loss_state
=
self
.
state
.
replace
(
'
.pt
'
,
'
_loss.pt
'
)
#
initialize the model
self
.
_init_model
(
)
def
from_pretrained
(
self
):
# load the pretrained model
model_state
=
torch
.
load
(
os
.
path
.
join
(
self
.
state_path
,
self
.
pretrained_model
))
model_state
=
os
.
path
.
join
(
self
.
state_path
,
self
.
pretrained_model
)
if
not
os
.
path
.
exists
(
model_state
):
raise
FileNotFoundError
(
'
Pretrained model {} does not exist.
'
.
format
(
model_state
))
# load the model state
model_state
=
torch
.
load
(
model_state
)
# get the input bands of the pretrained model
bands
=
model_state
[
'
bands
'
]
...
...
@@ -152,8 +74,10 @@ class NetworkTrainer(object):
out_channels
=
len
(
self
.
dataset
.
labels
),
kernel_size
=
1
)
return
model
# adjust the number of classes in the model
model
.
nclasses
=
len
(
self
.
dataset
.
labels
)
return
model
def ds_len(self, ds, ratio):
    """Return the number of samples that *ratio* of dataset *ds* covers.

    Uses numpy's round-half-to-even before truncating to ``int``.
    """
    scaled = len(ds) * ratio
    return int(np.round(scaled))
...
...
@@ -193,22 +117,6 @@ class NetworkTrainer(object):
def accuracy_function(self, outputs, labels):
    """Compute the fraction of *outputs* that exactly match *labels*."""
    hits = (outputs == labels).float()
    return hits.mean()
def
_save_loss
(
self
,
training_state
,
checkpoint
=
False
,
checkpoint_state
=
None
):
# save losses and accuracy
if
checkpoint
and
checkpoint_state
is
not
None
:
# append values from checkpoint to current training
# state
torch
.
save
({
k1
:
np
.
hstack
([
v1
,
v2
])
for
(
k1
,
v1
),
(
k2
,
v2
)
in
zip
(
checkpoint_state
.
items
(),
training_state
.
items
())
if
k1
==
k2
},
self
.
loss_state
)
else
:
torch
.
save
(
training_state
,
self
.
loss_state
)
def
train
(
self
):
# set the number of threads
...
...
@@ -418,6 +326,117 @@ class NetworkTrainer(object):
return
cm
,
accuracies
,
losses
def _init_dataset(self):
    """Instantiate the dataset to train the model on.

    Looks up ``self.dataset_name`` in ``SupportedDatasets``, builds the
    dataset, performs the training/validation/test split and creates the
    corresponding dataloaders.

    Raises
    ------
    ValueError
        If ``self.dataset_name`` is not a supported dataset.
    """
    # check whether the dataset is currently supported
    self.dataset = None
    for dataset in SupportedDatasets:
        if self.dataset_name == dataset.name:
            self.dataset = dataset.value['class'](
                self.dataset_path, self.bands, self.tile_size, self.sort,
                self.transforms)
            break  # enum names are unique, no need to keep scanning

    if self.dataset is None:
        # BUGFIX: this used to be print('...').format(...), which calls
        # .format on print()'s return value (None) and raises
        # AttributeError instead of showing the intended message
        print('{} is not a valid dataset.'.format(self.dataset_name))
        print('Available datasets are:')
        for name in SupportedDatasets.__members__:
            print(name)
        raise ValueError('Dataset not supported.')

    # print the bands used for the segmentation
    print('------------------------ Input bands -------------------------')
    print(*['Band {}: {}'.format(i, b) for i, b in
            enumerate(self.dataset.use_bands)], sep='\n')
    print('--------------------------------------------------------------')

    # print the classes of interest
    print('-------------------------- Classes ---------------------------')
    print(*['Class {}: {}'.format(k, v['label']) for k, v in
            self.dataset.labels.items()], sep='\n')
    print('--------------------------------------------------------------')

    # the training and validation dataset
    print('------------------------ Dataset split -----------------------')
    self.train_ds, self.valid_ds, self.test_ds = self.train_val_test_split(
        self.dataset, self.tvratio, self.ttratio, self.seed)

    # number of batches in the validation set
    self.nvbatches = int(len(self.valid_ds) / self.batch_size)

    # number of batches in the training set
    self.nbatches = int(len(self.train_ds) / self.batch_size)
    print('Number of training batches: {}'.format(self.nbatches))
    print('Number of validation batches: {}'.format(self.nvbatches))
    print('--------------------------------------------------------------')

    # the training and validation dataloaders; incomplete final batches
    # are dropped so every batch has a fixed size
    self.train_dl = DataLoader(self.train_ds, self.batch_size,
                               shuffle=True, drop_last=True)
    self.valid_dl = DataLoader(self.valid_ds, self.batch_size,
                               shuffle=True, drop_last=True)
def
_init_model
(
self
):
# instanciate the segmentation network
print
(
'
------------------- Network architecture ---------------------
'
)
if
self
.
pretrained
:
self
.
model
=
self
.
from_pretrained
()
else
:
self
.
model
=
self
.
net
(
in_channels
=
len
(
self
.
dataset
.
use_bands
),
nclasses
=
len
(
self
.
dataset
.
labels
),
filters
=
self
.
filters
,
skip
=
self
.
skip_connection
,
**
self
.
kwargs
)
print
(
self
.
model
)
print
(
'
--------------------------------------------------------------
'
)
# the optimizer used to update the model weights
self
.
optimizer
=
self
.
optimizer
(
self
.
model
.
parameters
(),
self
.
lr
)
# file to save model state to
# format: networkname_datasetname_t(tilesize)_b(batchsize)_bands.pt
bformat
=
''
.
join
([
b
[
0
]
for
b
in
self
.
bands
])
if
self
.
bands
else
'
all
'
self
.
state_file
=
(
'
{}_{}_t{}_b{}_{}.pt
'
.
format
(
self
.
model
.
__class__
.
__name__
,
self
.
dataset
.
__class__
.
__name__
,
self
.
tile_size
,
self
.
batch_size
,
bformat
))
# check whether a pretrained model was used and change state filename
# accordingly
if
self
.
pretrained
:
# add the configuration of the pretrained model to the state name
self
.
state_file
=
(
self
.
state_file
.
replace
(
'
.pt
'
,
'
_
'
)
+
'
pretrained_
'
+
self
.
pretrained_model
)
# path to model state
self
.
state
=
os
.
path
.
join
(
self
.
state_path
,
self
.
state_file
)
# path to model loss/accuracy
self
.
loss_state
=
self
.
state
.
replace
(
'
.pt
'
,
'
_loss.pt
'
)
def
_save_loss
(
self
,
training_state
,
checkpoint
=
False
,
checkpoint_state
=
None
):
# save losses and accuracy
if
checkpoint
and
checkpoint_state
is
not
None
:
# append values from checkpoint to current training
# state
torch
.
save
({
k1
:
np
.
hstack
([
v1
,
v2
])
for
(
k1
,
v1
),
(
k2
,
v2
)
in
zip
(
checkpoint_state
.
items
(),
training_state
.
items
())
if
k1
==
k2
},
self
.
loss_state
)
else
:
torch
.
save
(
training_state
,
self
.
loss_state
)
class
EarlyStopping
(
object
):
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment