Commit 2b0eca74 authored 4 years ago by Frisinghelli Daniel
Major generalization of Network class; added support for AdamW and AMSgrad.
parent 91726ddd
Showing 1 changed file with 129 additions and 71 deletions: pysegcnn/core/models.py
@@ -28,14 +28,14 @@ import torch.optim as optim

 # locals
 from pysegcnn.core.layers import (Encoder, Decoder, ConvBnReluMaxPool,
                                   ConvBnReluMaxUnpool, Conv2dSame)
-from pysegcnn.core.utils import check_filename_length
+from pysegcnn.core.utils import check_filename_length, item_in_enum

 # module level logger
 LOGGER = logging.getLogger(__name__)


 class Network(nn.Module):
-    """Generic Network class.
+    """Generic neural network class for image classification tasks.
+
+    The base class for each model. If you want to implement a new model,
+    inherit the :py:class:`pysegcnn.core.models.Network` class.
@@ -44,16 +44,26 @@ class Network(nn.Module):
     ----------
     state_file : `str` or `None` or :py:class:`pathlib.Path`
         The model state file, where the model parameters are saved.
+    in_channels : `int`
+        Number of channels of the input images.
+    nclasses : `int`
+        Number of classes.
     epoch : `int`
         Number of epochs the network was trained.

     """

-    def __init__(self, state_file=None):
+    def __init__(self, state_file, in_channels, nclasses):
         """Initialize.

         Parameters
         ----------
         state_file : `str` or `None` or :py:class:`pathlib.Path`
             The model state file, where the model parameters are saved.
+        in_channels : `int`
+            Number of channels of the input images.
+        nclasses : `int`
+            Number of classes.

         """
         super().__init__()
@@ -61,6 +71,12 @@ class Network(nn.Module):
         # initialize state file
         self.state_file = state_file

+        # number of spectral bands of the input images
+        self.in_channels = in_channels
+
+        # number of output classes
+        self.nclasses = nclasses
+
         # number of epochs trained
         self.epoch = 0
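With this change, every model hands its state file, the number of input channels, and the number of classes up to the base class. A minimal sketch of a custom model under the new constructor (MyNet and its single classification layer are purely illustrative, not part of the package):

    import torch.nn as nn
    from pysegcnn.core.models import Network

    class MyNet(Network):
        """Hypothetical model inheriting the generalized Network class."""

        def __init__(self, state_file, in_channels, nclasses):
            # the base class now stores state_file, in_channels and nclasses
            super().__init__(state_file, in_channels, nclasses)

            # a single 1x1 convolution as classifier, for illustration only
            self.classifier = nn.Conv2d(in_channels, nclasses, kernel_size=1)

        def forward(self, x):
            return self.classifier(x)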
@@ -111,10 +127,7 @@ class Network(nn.Module):
             param.requires_grad = True

     def save(self, state_file, optimizer, **kwargs):
-        """Save the model state.
+        """Save the model and optimizer state.

-        Saves the model and optimizer states together with the model
-        construction parameters, to easily re-instanciate the model.
+        Optional ``kwargs`` are also saved.
@@ -142,27 +155,6 @@ class Network(nn.Module):
         # initialize dictionary to store network parameters
         model_state = {**kwargs}

-        # store the spectral bands the model is trained with
-        # model_state['bands'] = bands
-
-        # store model and optimizer class
-        # model_state['cls'] = self.__class__
-        # model_state['optim_cls'] = optimizer.__class__
-
-        # store construction parameters to instanciate the network
-        # model_state['params'] = {
-        #     'skip': self.skip,
-        #     'filters': self.filters[1:],
-        #     'nclasses': self.nclasses,
-        #     'in_channels': self.in_channels
-        #     }
-
-        # store optimizer construction parameters
-        # model_state['optim_params'] = optimizer.defaults
-
-        # store optional keyword arguments
-        # model_state['params'] = {**model_state['params'], **self.kwargs}
-
         # store model epoch
         model_state['epoch'] = self.epoch
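Anything passed to ``save`` via ``**kwargs`` therefore ends up in the saved dictionary alongside the epoch and the model and optimizer state. A hedged usage sketch, assuming the caller passes the band names and the number of classes (the keys that ``load_pretrained_model`` below reads back); the file name and bands are hypothetical:

    import torch.optim as optim
    from pysegcnn.core.models import SegNet

    # hypothetical band names and state file name
    bands = ['B02', 'B03', 'B04', 'B08']
    state_file = 'SegNet_Adam_example.pt'

    model = SegNet(state_file=state_file, in_channels=len(bands), nclasses=4)
    optimizer = optim.Adam(model.parameters())

    # presumably stores {'bands': ..., 'nclasses': ..., 'epoch': ...,
    # 'model_state_dict': ..., 'optim_state_dict': ...}
    model.save(state_file, optimizer, bands=bands, nclasses=4)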
@@ -177,22 +169,13 @@ class Network(nn.Module):
         return model_state

     @staticmethod
-    def load(model, optimizer, state_file):
-        """Load a model state.
-
-        Returns the model in ``state_file`` with the pretrained model and
-        optimizer weights. Useful when resuming training an existing model.
+    def load(state_file):
+        """Load a model state file.

         Parameters
         ----------
-        model : :py:class:`pysegcnn.core.models.Network`
-            An instance of the model for which the pretrained weights are
-            stored in ``state_file``.
-        optimizer : :py:class:`torch.optim.Optimizer`
-            An instance of the optimizer used to train ``model``.
         state_file : `str` or :py:class:`pathlib.Path`
-            The model state file containing the pretrained parameters for
-            ``model`` and ``optimizer``.
+            The model state file containing the pretrained parameters.

         Raises
         ------
@@ -215,40 +198,106 @@ class Network(nn.Module):
         # load the model state
         model_state = torch.load(state_file)

-        # the model and optimizer class
-        # model_class = model_state['cls']
-        # optim_class = model_state['optim_cls']
+        return model_state

+    @staticmethod
+    def load_pretrained_model_weights(model, model_state):
+        """Load the pretrained model weights from a state file.

-        # instanciate pretrained model architecture
-        # model = model_class(**model_state['params'])
+        Parameters
+        ----------
+        model : :py:class:`pysegcnn.core.models.Network`
+            An instance of the model for which the pretrained weights are
+            stored in ``model_state``.
+        model_state : `dict`
+            A dictionary containing the model and optimizer state, as
+            constructed by :py:meth:`~pysegcnn.core.Network.save`.

-        # store state file as instance attribute
-        model.state_file = state_file
+        Returns
+        -------
+        model : :py:class:`pysegcnn.core.models.Network`
+            An instance of the pretrained model in ``model_state``.

+        """
         # load pretrained model weights
         LOGGER.info('Loading model parameters ...')
         model.load_state_dict(model_state['model_state_dict'])

         # set model epoch
         model.epoch = model_state['epoch']
         LOGGER.info('Model epoch: {:d}'.format(model.epoch))

+        return model
+
+    @staticmethod
+    def load_pretrained_optimizer_weights(optimizer, model_state):
+        """Load the pretrained optimizer weights from a state file.
+
+        Parameters
+        ----------
+        optimizer : :py:class:`torch.optim.Optimizer`
+            An instance of the optimizer used to train ``model`` for which the
+            pretrained weights are stored in ``model_state``.
+        model_state : `dict`
+            A dictionary containing the model and optimizer state, as
+            constructed by :py:meth:`~pysegcnn.core.Network.save`.
+
+        Returns
+        -------
+        optimizer : :py:class:`torch.optim.Optimizer`
+            An instance of the pretrained optimizer in ``model_state``.
+
+        """
         # resume optimizer parameters
         LOGGER.info('Loading optimizer parameters ...')
         optimizer.load_state_dict(model_state['optim_state_dict'])
-        LOGGER.info('Model epoch: {:d}'.format(model.epoch))

+        return optimizer
-        return model_state

+    @staticmethod
+    def load_pretrained_model(state_file):
+        """Load an instance of the pretrained model in ``state_file``.

-    @property
-    def state(self):
-        """Return the model state file.
+        Parameters
+        ----------
+        state_file : `str` or :py:class:`pathlib.Path`
+            The model state file containing the pretrained parameters.

         Returns
         -------
-        state_file : :py:class:`pathlib.Path` or `None`
-            The model state file.
+        model : :py:class:`pysegcnn.core.models.Network`
+            An instance of the pretrained model in ``state_file``.
+        optimizer : :py:class:`torch.optim.Optimizer`
+            An instance of the pretrained optimizer in ``state_file``.

         """
-        return self.state_file
+        # get the model class of the pretrained model
+        model_class = item_in_enum(str(state_file).split('_')[0],
+                                   SupportedModels)
+
+        # get the optimizer class of the pretrained model
+        optim_class = item_in_enum(str(state_file).split('_')[1],
+                                   SupportedOptimizers)
+
+        # load the pretrained model configuration
+        model_state = Network.load(state_file)
+
+        # instanciate the pretrained model architecture
+        model = model_class(state_file=state_file,
+                            in_channels=len(model_state['bands']),
+                            nclasses=model_state['nclasses'])
+
+        # instanciate the optimizer
+        optimizer = optim_class(model.parameters())
+
+        # load pretrained model weights
+        model = Network.load_pretrained_model_weights(model, model_state)
+
+        # load pretrained optimizer weights
+        optimizer = Network.load_pretrained_optimizer_weights(optimizer,
+                                                              model_state)
+
+        return model, optimizer


 class EncoderDecoderNetwork(Network):
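Together, these static methods replace the old ``load(model, optimizer, state_file)``: the model and optimizer classes are now parsed from the state file name rather than stored inside the file. A hedged usage sketch, assuming a file named after the ``<Model>_<Optimizer>_...`` pattern (the file name itself is hypothetical):

    import pathlib
    from pysegcnn.core.models import Network

    # the first two underscore-separated tokens of the file name are looked up
    # in SupportedModels and SupportedOptimizers
    state_file = pathlib.Path('SegNet_Adam_example.pt')

    model, optimizer = Network.load_pretrained_model(state_file)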
@@ -256,6 +305,8 @@ class EncoderDecoderNetwork(Network):
     Attributes
     ----------
+    state_file : `str` or `None` or :py:class:`pathlib.Path`
+        The model state file, where the model parameters are saved.
     in_channels : `int`
         Number of channels of the input images.
     nclasses : `int`
@@ -278,12 +329,14 @@ class EncoderDecoderNetwork(Network):
     """

-    def __init__(self, in_channels, nclasses, encoder_block, decoder_block,
-                 filters, skip, **kwargs):
+    def __init__(self, state_file, in_channels, nclasses, encoder_block,
+                 decoder_block, filters, skip, **kwargs):
         """Initialize.

         Parameters
         ----------
+        state_file : `str` or `None` or :py:class:`pathlib.Path`
+            The model state file, where the model parameters are saved.
         in_channels : `int`
             Number of channels of the input images.
         nclasses : `int`
@@ -305,13 +358,7 @@ class EncoderDecoderNetwork(Network):
            :py:class:`pysegcnn.core.layers.Conv2dSame`.

         """
-        super().__init__()
-
-        # number of input channels
-        self.in_channels = in_channels
-
-        # number of classes
-        self.nclasses = nclasses
+        super().__init__(state_file, in_channels, nclasses)

         # number of convolutional filters for each block
         self.filters = np.hstack([np.array(in_channels), np.array(filters)])
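The ``filters`` attribute simply prepends the number of input channels to the list of block filters. A short worked example with a 4-band input:

    import numpy as np

    in_channels, filters = 4, [32, 64, 128, 256]
    print(np.hstack([np.array(in_channels), np.array(filters)]))
    # [  4  32  64 128 256]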
@@ -370,6 +417,8 @@ class SegNet(EncoderDecoderNetwork):
     Attributes
     ----------
+    state_file : `str` or `None` or :py:class:`pathlib.Path`
+        The model state file, where the model parameters are saved.
     in_channels : `int`
         Number of channels of the input images.
     nclasses : `int`
@@ -392,25 +441,33 @@ class SegNet(EncoderDecoderNetwork):
     """

-    def __init__(self, in_channels, nclasses, filters, skip, **kwargs):
+    def __init__(self, state_file, in_channels, nclasses,
+                 filters=[32, 64, 128, 256], skip=True,
+                 kwargs={'kernel_size': 3, 'stride': 1, 'dilation': 1}):
         """Initialize.

         Parameters
         ----------
+        state_file : `str` or `None` or :py:class:`pathlib.Path`
+            The model state file, where the model parameters are saved.
         in_channels : `int`
             Number of channels of the input images.
         nclasses : `int`
             Number of classes.
-        filters : `list` [`int`]
-            List of input channels to each convolutional block.
-        skip : `bool`
+        filters : `list` [`int`], optional
+            List of input channels to each convolutional block. The default is
+            `[32, 64, 128, 256]`.
+        skip : `bool`, optional
             Whether to apply skip connections from the encoder to the decoder.
-        **kwargs: `dict` [`str`]
+            The default is `True`.
+        kwargs: `dict` [`str`: `int`]
             Additional keyword arguments passed to
-            :py:class:`pysegcnn.core.layers.Conv2dSame`.
+            :py:class:`pysegcnn.core.layers.Conv2dSame`. The default is
+            `{'kernel_size': 3, 'stride': 1, 'dilation': 1}`.

         """
-        super().__init__(in_channels=in_channels,
+        super().__init__(state_file=state_file,
+                         in_channels=in_channels,
                          nclasses=nclasses,
                          encoder_block=ConvBnReluMaxPool,
                          decoder_block=ConvBnReluMaxUnpool,
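With the new defaults, constructing a SegNet only requires the state file, the number of input bands, and the number of classes; ``filters``, ``skip`` and the ``Conv2dSame`` keyword arguments fall back to the defaults shown above. A short sketch (the file name and band/class counts are hypothetical):

    from pysegcnn.core.models import SegNet

    model = SegNet(state_file='SegNet_Adam_example.pt',
                   in_channels=4, nclasses=4)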
@@ -429,6 +486,7 @@ class SupportedOptimizers(enum.Enum):
     """Names and corresponding classes of the tested optimizers."""

     Adam = optim.Adam
+    AdamW = optim.AdamW


 class SupportedLossFunctions(enum.Enum):
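The commit message also mentions AMSGrad. Only AdamW appears as a new enum entry in this file, so AMSGrad is presumably switched on through the standard ``amsgrad`` keyword of :py:class:`torch.optim.Adam`/:py:class:`torch.optim.AdamW` rather than as a separate optimizer class. A hedged sketch resolving the optimizer by name via ``item_in_enum`` (model construction values are hypothetical):

    from pysegcnn.core.models import SegNet, SupportedOptimizers
    from pysegcnn.core.utils import item_in_enum

    # hypothetical model instance
    model = SegNet(state_file='SegNet_AdamW_example.pt', in_channels=4,
                   nclasses=4)

    # resolve the optimizer class by name and enable AMSGrad via the standard
    # torch.optim keyword argument
    optim_class = item_in_enum('AdamW', SupportedOptimizers)
    optimizer = optim_class(model.parameters(), lr=1e-3, amsgrad=True)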