Source code for core.models

"""Neural networks for semantic image segmentation.

License
-------

    Copyright (c) 2020 Daniel Frisinghelli

    This source code is licensed under the GNU General Public License v3.

    See the LICENSE file in the repository's root directory.

"""

# !/usr/bin/env python
# -*- coding: utf-8 -*-

# builtins
import enum
import logging
import pathlib

# externals
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# locals
from pysegcnn.core.layers import (Encoder, Decoder, ConvBnReluMaxPool,
                                  ConvBnReluMaxUnpool, Conv2dSame)

# module level logger
LOGGER = logging.getLogger(__name__)


class Network(nn.Module):
    """Generic Network class.

    The base class for each model. If you want to implement a new model,
    inherit the ``~pysegcnn.core.models.Network`` class.

    ``save`` reads the instance attributes ``skip``, ``nfilters``,
    ``nclasses``, ``in_channels``, ``kwargs`` and ``epoch``. This base class
    does not define them, so a subclass has to set them before ``save`` is
    called.

    Returns
    -------
    None.

    """

    def __init__(self):
        super().__init__()

        # path of the file the model was loaded from, ``None`` until
        # ``load`` assigns it
        self.state_file = None

    def freeze(self):
        """Freeze the weights of a model.

        Disables gradient computation: useful when using a pretrained model
        for inference.

        Returns
        -------
        None.

        """
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """Unfreeze the weights of a model.

        Enables gradient computation: useful when adjusting a pretrained model
        to a new dataset.

        Returns
        -------
        None.

        """
        for param in self.parameters():
            param.requires_grad = True

    def save(self, state_file, optimizer, bands=None, **kwargs):
        """Save the model state.

        Saves the model and optimizer states together with the model
        construction parameters, to easily re-instantiate the model.

        Optional ``kwargs`` are also saved.

        Parameters
        ----------
        state_file : `str` or `pathlib.Path`
            Path to save the model state.
        optimizer : `torch.optim.Optimizer`
            The optimizer used to train the model.
        bands : `list` [`str`] or `None`, optional
            List of bands the model is trained with. The default is None.
        **kwargs
            Arbitrary keyword arguments. Each keyword argument will be saved
            as (key, value) pair in ``state_file``. Keys that collide with
            the entries written by this method (e.g. ``'cls'``, ``'params'``,
            ``'epoch'``) are overwritten.

        Returns
        -------
        model_state : `dict`
            A dictionary containing the model and optimizer state.

        """
        # make sure the output directory exists; ``exist_ok=True`` already
        # covers the case of an existing directory, no prior check required
        state_file = pathlib.Path(state_file)
        state_file.parent.mkdir(parents=True, exist_ok=True)

        # initialize dictionary to store network parameters
        model_state = {**kwargs}

        # store the spectral bands the model is trained with
        model_state['bands'] = bands

        # store model and optimizer class: the classes themselves are
        # pickled, which is what allows ``load`` to rebuild the objects
        model_state['cls'] = self.__class__
        model_state['optim_cls'] = optimizer.__class__

        # store construction parameters to instantiate the network
        model_state['params'] = {
            'skip': self.skip,
            'filters': self.nfilters,
            'nclasses': self.nclasses,
            'in_channels': self.in_channels
        }

        # store optimizer construction parameters
        model_state['optim_params'] = optimizer.defaults

        # store optional keyword arguments of the model constructor
        model_state['params'] = {**model_state['params'], **self.kwargs}

        # store model epoch
        model_state['epoch'] = self.epoch

        # store model and optimizer state
        model_state['model_state_dict'] = self.state_dict()
        model_state['optim_state_dict'] = optimizer.state_dict()

        # write the state dictionary to disk
        torch.save(model_state, state_file)
        LOGGER.info('Network parameters saved in {}'.format(state_file))

        return model_state

    @staticmethod
    def load(state_file, map_location=None):
        """Load a model state.

        Returns the model in ``state_file`` with the pretrained model and
        optimizer weights. Useful when resuming training an existing model.

        Parameters
        ----------
        state_file : `str` or `pathlib.Path`
            The model state file. Model state files are stored in
            pysegcnn/main/_models.
        map_location : optional
            Forwarded to `torch.load` to remap the storage locations of the
            saved tensors, e.g. ``'cpu'`` to restore a model trained on a GPU
            on a machine without one. The default is None, which keeps the
            default behaviour of `torch.load`.

        Raises
        ------
        FileNotFoundError
            Raised if ``state_file`` does not exist.

        Returns
        -------
        model : `pysegcnn.core.models.Network`
            The pretrained model.
        optimizer : `torch.optim.Optimizer`
            The optimizer used to train the model.
        model_state : `dict`
            A dictionary containing the model and optimizer state, as
            constructed by `~pysegcnn.core.Network.save`.

        """
        # check whether the state file exists before trying to read it
        state_file = pathlib.Path(state_file)
        if not state_file.exists():
            raise FileNotFoundError('{} does not exist.'.format(state_file))
        LOGGER.info('Loading pretrained weights from: {}'.format(state_file))

        # load the model state
        # SECURITY: the state file stores pickled Python classes, loading it
        # executes the pickle machinery; only load files from trusted sources
        # NOTE(review): recent PyTorch releases default ``torch.load`` to
        # ``weights_only=True``, which rejects pickled classes -- confirm
        # against the PyTorch version this project targets
        model_state = torch.load(state_file, map_location=map_location)

        # the model and optimizer class
        model_class = model_state['cls']
        optim_class = model_state['optim_cls']

        # instantiate pretrained model architecture
        model = model_class(**model_state['params'])

        # store state file as instance attribute
        model.state_file = state_file

        # load pretrained model weights
        LOGGER.info('Loading model parameters ...')
        model.load_state_dict(model_state['model_state_dict'])
        model.epoch = model_state['epoch']

        # resume optimizer parameters
        LOGGER.info('Loading optimizer parameters ...')
        optimizer = optim_class(model.parameters(),
                                **model_state['optim_params'])
        optimizer.load_state_dict(model_state['optim_state_dict'])
        LOGGER.info('Model epoch: {:d}'.format(model.epoch))

        return model, optimizer, model_state

    @property
    def state(self):
        """Return the model state file.

        Returns
        -------
        state_file : `pathlib.Path` or `None`
            The model state file.

        """
        return self.state_file
class UNet(Network):
    """A PyTorch implementation of `U-Net`_.

    Slightly modified version of U-Net:
        - each convolution is followed by a batch normalization layer
        - the upsampling is implemented by a 2x2 max unpooling operation

    .. _U-Net:
        https://arxiv.org/abs/1505.04597

    Parameters
    ----------
    in_channels : `int`
        Number of channels of the input images.
    nclasses : `int`
        Number of classes.
    filters : `list` [`int`]
        List of input channels to each convolutional block.
    skip : `bool`
        Whether to apply skip connections from the encoder to the decoder.
    **kwargs: `dict` [`str`]
        Additional keyword arguments passed to
        `pysegcnn.core.layers.Conv2dSame`.

    Returns
    -------
    None.

    """

    def __init__(self, in_channels, nclasses, filters, skip, **kwargs):
        super().__init__()

        # training progress: no epoch has been trained so far
        self.epoch = 0

        # construction parameters, read again by ``Network.save``
        self.in_channels = in_channels
        self.nclasses = nclasses
        self.nfilters = filters
        self.skip = skip
        self.kwargs = kwargs

        # channel configuration of the convolutional blocks: the number of
        # input channels followed by the user-defined filters
        self.filters = np.hstack([np.array(in_channels), np.array(filters)])

        # the sub-modules are registered in the order encoder, decoder,
        # classifier, which determines the order of the entries in the
        # model state dictionary
        self.encoder = Encoder(
            filters=self.filters,
            block=ConvBnReluMaxPool,
            **kwargs
        )
        self.decoder = Decoder(
            filters=self.filters,
            block=ConvBnReluMaxUnpool,
            skip=skip,
            **kwargs
        )

        # 1x1 convolution with one output channel per class
        self.classifier = Conv2dSame(
            in_channels=filters[0],
            out_channels=self.nclasses,
            kernel_size=1
        )

    def forward(self, x):
        """Forward propagation of U-Net.

        Parameters
        ----------
        x : `torch.Tensor`
            The input image, shape=(batch_size, channels, height, width).

        Returns
        -------
        y : `torch.Tensor`
            The output of the classifier layer, which is configured with
            ``nclasses`` output channels.

        """
        # encoder pass
        encoded = self.encoder(x)

        # decoder pass, additionally consuming the cache of the encoder
        decoded = self.decoder(encoded, self.encoder.cache)

        # the cached encoder outputs are no longer needed: drop them
        del self.encoder.cache

        # classification
        return self.classifier(decoded)
class SupportedModels(enum.Enum):
    """Enumeration of the implemented models.

    The name of each member is the name of the model, its value is the class
    implementing the model.

    """

    Unet = UNet
class SupportedOptimizers(enum.Enum):
    """Enumeration of the tested optimizers.

    The name of each member is the name of the optimizer, its value is the
    class implementing the optimizer.

    """

    Adam = optim.Adam
class SupportedLossFunctions(enum.Enum):
    """Enumeration of the tested loss functions.

    The name of each member is the name of the loss function, its value is
    the class implementing the loss function.

    """

    CrossEntropy = nn.CrossEntropyLoss