Skip to content
Snippets Groups Projects
Commit 2b0eca74 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Major generalization of Network class; added support for AdamW and AMSgrad.

parent 91726ddd
No related branches found
No related tags found
No related merge requests found
......@@ -28,14 +28,14 @@ import torch.optim as optim
# locals
from pysegcnn.core.layers import (Encoder, Decoder, ConvBnReluMaxPool,
ConvBnReluMaxUnpool, Conv2dSame)
from pysegcnn.core.utils import check_filename_length
from pysegcnn.core.utils import check_filename_length, item_in_enum
# module level logger
LOGGER = logging.getLogger(__name__)
class Network(nn.Module):
"""Generic Network class.
"""Generic neural network class for image classification tasks.
The base class for each model. If you want to implement a new model,
inherit the :py:class:`pysegcnn.core.models.Network` class.
......@@ -44,16 +44,26 @@ class Network(nn.Module):
----------
state_file : `str` or `None` or :py:class:`pathlib.Path`
The model state file, where the model parameters are saved.
in_channels : `int`
Number of channels of the input images.
nclasses : `int`
Number of classes.
epoch : `int`
Number of epochs the network was trained.
"""
def __init__(self, state_file=None):
def __init__(self, state_file, in_channels, nclasses):
"""Initialize.
Parameters
----------
state_file : `str` or `None` or :py:class:`pathlib.Path`
The model state file, where the model parameters are saved.
in_channels : `int`
Number of channels of the input images.
nclasses : `int`
Number of classes.
"""
super().__init__()
......@@ -61,6 +71,12 @@ class Network(nn.Module):
# initialize state file
self.state_file = state_file
# number of spectral bands of the input images
self.in_channels = in_channels
# number of output classes
self.nclasses = nclasses
# number of epochs trained
self.epoch = 0
......@@ -111,10 +127,7 @@ class Network(nn.Module):
param.requires_grad = True
def save(self, state_file, optimizer, **kwargs):
"""Save the model state.
Saves the model and optimizer states together with the model
construction parameters, to easily re-instantiate the model.
"""Save the model and optimizer state.
Optional ``kwargs`` are also saved.
......@@ -142,27 +155,6 @@ class Network(nn.Module):
# initialize dictionary to store network parameters
model_state = {**kwargs}
# store the spectral bands the model is trained with
# model_state['bands'] = bands
# store model and optimizer class
# model_state['cls'] = self.__class__
# model_state['optim_cls'] = optimizer.__class__
# store construction parameters to instantiate the network
# model_state['params'] = {
# 'skip': self.skip,
# 'filters': self.filters[1:],
# 'nclasses': self.nclasses,
# 'in_channels': self.in_channels
# }
# store optimizer construction parameters
# model_state['optim_params'] = optimizer.defaults
# store optional keyword arguments
# model_state['params'] = {**model_state['params'], **self.kwargs}
# store model epoch
model_state['epoch'] = self.epoch
......@@ -177,22 +169,13 @@ class Network(nn.Module):
return model_state
@staticmethod
def load(model, optimizer, state_file):
"""Load a model state.
Returns the model in ``state_file`` with the pretrained model and
optimizer weights. Useful when resuming training an existing model.
def load(state_file):
"""Load a model state file.
Parameters
----------
model : :py:class:`pysegcnn.core.models.Network`
An instance of the model for which the pretrained weights are
stored in ``state_file``.
optimizer : :py:class:`torch.optim.Optimizer`
An instance of the optimizer used to train ``model``.
state_file : `str` or :py:class:`pathlib.Path`
The model state file containing the pretrained parameters for
``model`` and ``optimizer``.
The model state file containing the pretrained parameters.
Raises
------
......@@ -215,40 +198,106 @@ class Network(nn.Module):
# load the model state
model_state = torch.load(state_file)
# the model and optimizer class
# model_class = model_state['cls']
# optim_class = model_state['optim_cls']
return model_state
@staticmethod
def load_pretrained_model_weights(model, model_state):
"""Load the pretrained model weights from a state file.
# instantiate pretrained model architecture
# model = model_class(**model_state['params'])
Parameters
----------
model : :py:class:`pysegcnn.core.models.Network`
An instance of the model for which the pretrained weights are
stored in ``model_state``.
model_state : `dict`
A dictionary containing the model and optimizer state, as
constructed by :py:meth:`~pysegcnn.core.models.Network.save`.
# store state file as instance attribute
model.state_file = state_file
Returns
-------
model : :py:class:`pysegcnn.core.models.Network`
An instance of the pretrained model in ``model_state``.
"""
# load pretrained model weights
LOGGER.info('Loading model parameters ...')
model.load_state_dict(model_state['model_state_dict'])
# set model epoch
model.epoch = model_state['epoch']
LOGGER.info('Model epoch: {:d}'.format(model.epoch))
return model
@staticmethod
def load_pretrained_optimizer_weights(optimizer, model_state):
"""Load the pretrained optimizer weights from a state file.
Parameters
----------
optimizer : :py:class:`torch.optim.Optimizer`
An instance of the optimizer used to train ``model`` for which the
pretrained weights are stored in ``model_state``.
model_state : `dict`
A dictionary containing the model and optimizer state, as
constructed by :py:meth:`~pysegcnn.core.models.Network.save`.
Returns
-------
optimizer : :py:class:`torch.optim.Optimizer`
An instance of the pretrained optimizer in ``model_state``.
"""
# resume optimizer parameters
LOGGER.info('Loading optimizer parameters ...')
optimizer.load_state_dict(model_state['optim_state_dict'])
LOGGER.info('Model epoch: {:d}'.format(model.epoch))
return optimizer
return model_state
@staticmethod
def load_pretrained_model(state_file):
"""Load an instance of the pretrained model in ``state_file``.
@property
def state(self):
"""Return the model state file.
Parameters
----------
state_file : `str` or :py:class:`pathlib.Path`
The model state file containing the pretrained parameters.
Returns
-------
state_file : :py:class:`pathlib.Path` or `None`
The model state file.
model : :py:class:`pysegcnn.core.models.Network`
An instance of the pretrained model in ``state_file``.
optimizer : :py:class:`torch.optim.Optimizer`
An instance of the pretrained optimizer in ``state_file``.
"""
return self.state_file
# get the model class of the pretrained model
model_class = item_in_enum(str(state_file).split('_')[0],
SupportedModels)
# get the optimizer class of the pretrained model
optim_class = item_in_enum(str(state_file).split('_')[1],
SupportedOptimizers)
# load the pretrained model configuration
model_state = Network.load(state_file)
# instantiate the pretrained model architecture
model = model_class(state_file=state_file,
in_channels=len(model_state['bands']),
nclasses=model_state['nclasses'])
# instantiate the optimizer
optimizer = optim_class(model.parameters())
# load pretrained model weights
model = Network.load_pretrained_model_weights(model, model_state)
# load pretrained optimizer weights
optimizer = Network.load_pretrained_optimizer_weights(optimizer,
model_state)
return model, optimizer
class EncoderDecoderNetwork(Network):
......@@ -256,6 +305,8 @@ class EncoderDecoderNetwork(Network):
Attributes
----------
state_file : `str` or `None` or :py:class:`pathlib.Path`
The model state file, where the model parameters are saved.
in_channels : `int`
Number of channels of the input images.
nclasses : `int`
......@@ -278,12 +329,14 @@ class EncoderDecoderNetwork(Network):
"""
def __init__(self, in_channels, nclasses, encoder_block, decoder_block,
filters, skip, **kwargs):
def __init__(self, state_file, in_channels, nclasses, encoder_block,
decoder_block, filters, skip, **kwargs):
"""Initialize.
Parameters
----------
state_file : `str` or `None` or :py:class:`pathlib.Path`
The model state file, where the model parameters are saved.
in_channels : `int`
Number of channels of the input images.
nclasses : `int`
......@@ -305,13 +358,7 @@ class EncoderDecoderNetwork(Network):
:py:class:`pysegcnn.core.layers.Conv2dSame`.
"""
super().__init__()
# number of input channels
self.in_channels = in_channels
# number of classes
self.nclasses = nclasses
super().__init__(state_file, in_channels, nclasses)
# number of convolutional filters for each block
self.filters = np.hstack([np.array(in_channels), np.array(filters)])
......@@ -370,6 +417,8 @@ class SegNet(EncoderDecoderNetwork):
Attributes
----------
state_file : `str` or `None` or :py:class:`pathlib.Path`
The model state file, where the model parameters are saved.
in_channels : `int`
Number of channels of the input images.
nclasses : `int`
......@@ -392,25 +441,33 @@ class SegNet(EncoderDecoderNetwork):
"""
def __init__(self, in_channels, nclasses, filters, skip, **kwargs):
def __init__(self, state_file, in_channels, nclasses,
filters=[32, 64, 128, 256], skip=True,
kwargs={'kernel_size': 3, 'stride': 1, 'dilation': 1}):
"""Initialize.
Parameters
----------
state_file : `str` or `None` or :py:class:`pathlib.Path`
The model state file, where the model parameters are saved.
in_channels : `int`
Number of channels of the input images.
nclasses : `int`
Number of classes.
filters : `list` [`int`]
List of input channels to each convolutional block.
skip : `bool`
filters : `list` [`int`], optional
List of input channels to each convolutional block. The default is
`[32, 64, 128, 256]`.
skip : `bool`, optional
Whether to apply skip connections from the encoder to the decoder.
**kwargs: `dict` [`str`]
The default is `True`.
kwargs: `dict` [`str`: `int`]
Additional keyword arguments passed to
:py:class:`pysegcnn.core.layers.Conv2dSame`.
:py:class:`pysegcnn.core.layers.Conv2dSame`. The default is
`{'kernel_size': 3, 'stride': 1, 'dilation': 1}`.
"""
super().__init__(in_channels=in_channels,
super().__init__(state_file=state_file,
in_channels=in_channels,
nclasses=nclasses,
encoder_block=ConvBnReluMaxPool,
decoder_block=ConvBnReluMaxUnpool,
......@@ -429,6 +486,7 @@ class SupportedOptimizers(enum.Enum):
"""Names and corresponding classes of the tested optimizers."""
Adam = optim.Adam
AdamW = optim.AdamW
class SupportedLossFunctions(enum.Enum):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment