Skip to content
Snippets Groups Projects
Commit 32b0b0fd authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Implemented generic classes for an encoder-decoder architecture

parent 4d102fdc
No related branches found
No related tags found
No related merge requests found
......@@ -15,6 +15,7 @@ License
# -*- coding: utf-8 -*-
# externals
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -32,11 +33,11 @@ class Conv2dSame(nn.Conv2d):
*args: `list` [`str`]
positional arguments passed to `torch.nn.Conv2d`:
``'in_channels'``: `int`
Number of input channels
``'kernel_size'``: `int` or `tuple` [`int`]
Size of the convolving kernel
Number of input channels.
``'out_channels'``: `int`
Number of desired output channels
Number of output channels.
``'kernel_size'``: `int` or `tuple` [`int`]
Size of the convolving kernel.
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to `torch.nn.Conv2d`_.
......@@ -88,15 +89,15 @@ def conv_bn_relu(in_channels, out_channels, **kwargs):
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of desired output channels
Number of output channels.
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
Returns
-------
block : `torch.nn.Sequential`
An instance of `torch.nn.Sequential` containing the different layers.
block : `torch.nn.Sequential` [`torch.nn.Module`]
An instance of `torch.nn.Sequential` containing a sequence of
convolution, batch normalization and rectified linear unit layers.
"""
return nn.Sequential(
......@@ -109,18 +110,23 @@ def conv_bn_relu(in_channels, out_channels, **kwargs):
)
class Conv2dPool(nn.Module):
"""Block of convolution, batchnorm, relu and 2x2 max pool.
class Block(nn.Module):
"""Basic convolutional block.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of desired output channels
Number of output channels.
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
Raises
------
TypeError
Raised if `~pysegcnn.core.layers.Block.layers` method does not return
an instance of `torch.nn.Sequential`.
Returns
-------
......@@ -129,56 +135,67 @@ class Conv2dPool(nn.Module):
"""
def __init__(self, in_channels, out_channels, **kwargs):
# initialize nn.Module class
super().__init__()
# create the convolutional blocks for this module
self.conv = conv_bn_relu(in_channels, out_channels, **kwargs)
# number of input and output channels
self.in_channels = in_channels
self.out_channels = out_channels
# create the 2x2 max pooling layer
self.pool = nn.MaxPool2d(2, return_indices=True)
# keyword arguments configuring convolutions
self.kwargs = kwargs
# defines the forward pass
def forward(self, x):
"""Forward propagation through this block.
# the layers of the block
self.conv = self.layers()
if not isinstance(self.conv, nn.Sequential):
raise TypeError('{}.layers() should return an instance of {}.'
.format(self.__class__.__name__,
repr(nn.Sequential)))
Parameters
----------
x : `torch.Tensor`
Output of previous layer.
def layers(self):
"""Define the layers of the block.
Raises
------
NotImplementedError
Raised if `pysegcnn.core.layers.Block` is not inherited.
Returns
-------
y : `torch.Tensor`
Output of this block.
x : `torch.Tensor`
Output before max pooling. Stored for skip connections.
i : `torch.Tensor`
Indices of the max pooling operation. Used in unpooling operation.
layers : `torch.nn.Sequential` [`torch.nn.Module`]
Return an instance of `torch.nn.Sequential` containing a sequence
of layer (`torch.nn.Module` ) instances.
"""
# output of the convolutional block
x = self.conv(x)
raise NotImplementedError('Return an instance of {}.'
.format(repr(nn.Sequential)))
def forward(self):
"""Forward pass of the block.
Raises
------
NotImplementedError
Raised if `pysegcnn.core.layers.Block` is not inherited.
# output of the pooling layer
y, i = self.pool(x)
Returns
-------
None.
return (y, x, i)
"""
raise NotImplementedError('Implement the forward pass.')
class Conv2dUnpool(nn.Module):
"""Block of convolution, batchnorm, relu and 2x2 max unpool.
class EncoderBlock(Block):
"""Block of a convolutional encoder.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of desired output channels
Number of output channels.
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
Returns
-------
......@@ -187,62 +204,86 @@ class Conv2dUnpool(nn.Module):
"""
def __init__(self, in_channels, out_channels, **kwargs):
super().__init__()
# create the convolutional blocks for this module
self.conv = conv_bn_relu(in_channels, out_channels, **kwargs)
super().__init__(in_channels, out_channels, **kwargs)
# create the unpooling layer
self.upsample = nn.MaxUnpool2d(2)
# defines the forward pass
def forward(self, x, feature, indices, skip):
"""Forward propagation through this block.
def forward(self, x):
"""Forward pass of an encoder block.
Parameters
----------
x : `torch.Tensor`
Output of previous layer.
feature : `torch.Tensor`
Encoder feature used for the skip connection.
indices : `torch.Tensor`
Indices of the max pooling operation. Used in unpooling operation.
skip : `bool`
Whether to apply skip connetion.
Input tensor, e.g. output of the previous block/layer.
Returns
-------
y : `torch.Tensor`
Output of the encoder block.
x : `torch.Tensor`
Output of this block.
Intermediate output before applying downsampling. Useful to
implement skip connections.
indices : `torch.Tensor` or `None`
Optional indices of the downsampling method, e.g. indices of the
maxima when using `torch.nn.functional.max_pool2d`. Useful for
upsampling later. If no indices are required to upsample, simply
return ``indices`` = `None`.
"""
# upsampling with pooling indices
x = self.upsample(x, indices, output_size=feature.shape)
# the forward pass of the layers of the block
x = self.conv(x)
# check whether to apply the skip connection
# skip connection: concatenate the output of a layer in the encoder to
# the corresponding layer in the decoder (along the channel axis)
if skip:
x = torch.cat([x, feature], axis=1)
# the downsampling layer
y, indices = self.downsample(x)
# output of the convolutional layer
x = self.conv(x)
return (y, x, indices)
return x
def downsample(self, x):
"""Define the downsampling method.
The `~pysegcnn.core.layers.EncoderBlock.downsample` `method should
implement the spatial pooling operation.
class Conv2dUpsample(nn.Module):
"""Block of convolution, batchnorm, relu and nearest neighbor upsampling.
Use one of the following functions to downsample:
- `torch.nn.functional.max_pool2d`
- `torch.nn.functional.interpolate`
See `pysegcnn.core.layers.ConvBnReluMaxPool` for an example
implementation.
Parameters
----------
x : `torch.Tensor`
Input tensor, e.g. output of a convolutional block.
Raises
------
NotImplementedError
Raised if `pysegcnn.core.layers.EncoderBlock` is not inherited.
Returns
-------
x : `torch.Tensor`
The spatially downsampled tensor.
indices : `torch.Tensor` or `None`
Optional indices of the downsampling method, e.g. indices of the
maxima when using `torch.nn.functional.max_pool2d`. Useful for
upsampling later. If no indices are required to upsample, simply
return ``indices`` = `None`.
"""
raise NotImplementedError('Implement the downsampling function.')
class DecoderBlock(Block):
"""Block of a convolutional decoder.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of desired output channels
Number of output channels.
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
Returns
-------
......@@ -251,43 +292,37 @@ class Conv2dUpsample(nn.Module):
"""
def __init__(self, in_channels, out_channels, **kwargs):
super().__init__()
super().__init__(in_channels, out_channels, **kwargs)
# create the convolutional blocks for this module
self.conv = conv_bn_relu(in_channels, out_channels, **kwargs)
# create the upsampling layer
self.upsample = F.interpolate
# defines the forward pass
def forward(self, x, feature, indices, skip):
"""Forward propagation through this block.
"""Forward pass of a decoder block.
Parameters
----------
x : `torch.Tensor`
Output of previous layer.
feature : `torch.Tensor`
Encoder feature used for the skip connection.
indices : `torch.Tensor`
Indices of the max pooling operation. Used in unpooling operation.
Not used here, but passed to preserve generic interface. Useful in
`pysegcnn.core.layers.Decoder`.
Input tensor.
feature : `torch.Tensor`, shape=(batch, channel, height, width)
Intermediate output of a layer in the encoder.
If ``skip`` = True, ``feature`` is concatenated (along the channel
axis) to the output of the respective upsampling layer in the
decoder (skip connection).
indices : `torch.Tensor` or `None`
Indices of the encoder downsampling method.
skip : `bool`
Whether to apply skip connection.
Whether to apply the skip connection.
Returns
-------
x : `torch.Tensor`
Output of this block.
Output of the decoder block.
"""
# upsampling with pooling indices
x = self.upsample(x, size=feature.shape[2:], mode='nearest')
# upsample
x = self.upsample(x, feature, indices)
# check whether to apply the skip connection
# skip connection: concatenate the output of a layer in the encoder to
# the corresponding layer in the decoder (along the channel axis )
# the corresponding layer in the decoder (along the channel axis)
if skip:
x = torch.cat([x, feature], axis=1)
......@@ -296,20 +331,65 @@ class Conv2dUpsample(nn.Module):
return x
def upsample(self, x, feature, indices):
"""Define the upsampling method.
The `~pysegcnn.core.layers.DecoderBlock.upsample` `method should
implement the spatial upsampling operation.
Use one of the following functions to upsample:
- `torch.nn.functional.max_unpool2d`
- `torch.nn.functional.interpolate`
See `pysegcnn.core.layers.ConvBnReluMaxUnpool` or
`pysegcnn.core.layers.ConvBnReluUpsample` for an example
implementation.
Parameters
----------
x : `torch.Tensor`
Input tensor, e.g. output of a convolutional block.
feature : `torch.Tensor`, shape=(batch, channel, height, width)
Intermediate output of a layer in the encoder. Used to implement
skip connections.
indices : `torch.Tensor` or `None`
Indices of the encoder downsampling method.
Raises
------
NotImplementedError
Raised if `pysegcnn.core.layers.DecoderBlock` is not inherited.
Returns
-------
x : `torch.Tensor`
The spatially upsampled tensor.
"""
raise NotImplementedError('Implement the upsampling function')
class Encoder(nn.Module):
"""Generic convolutional encoder.
When instanciating an encoder-decoder architechure, ``filters`` should be
the same for `pysegcnn.core.layers.Encoder` and
`pysegcnn.core.layers.Decoder`.
See `pysegcnn.core.models.UNet` for an example implementation.
Parameters
----------
filters : `list` [`int`]
List of input channels to each convolutional block.
block : `torch.nn.Module`
The convolutional block. ``block`` should inherit from
`torch.nn.Module`, e.g. `pysegcnn.core.layers.Conv2dPool`.
List of input channels to each convolutional block. The length of
``filters`` determines the depth of the encoder. The first element of
``filters`` has to be the number of channels of the input images.
block : `pysegcnn.core.layers.EncoderBlock`
The convolutional block defining a layer in the encoder.
A subclass of `pysegcnn.core.layers.EncoderBlock`, e.g.
`pysegcnn.core.layers.ConvBnReluMaxPool`.
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
Returns
-------
......@@ -322,23 +402,34 @@ class Encoder(nn.Module):
# the number of filters for each block: the first element of filters
# has to be the number of input channels
self.features = filters
self.features = np.asarray(filters)
# the block of operations defining a layer in the encoder
if not issubclass(block, EncoderBlock):
raise TypeError('"block" expected to be a subclass of {}.'
.format(repr(EncoderBlock)))
self.block = block
# construct the encoder layers
self.layers = []
for i, (l, lp1) in enumerate(zip(self.features, self.features[1:])):
for lyr, lyrp1 in zip(self.features, self.features[1:]):
# append blocks to the encoder layers
self.layers.append(self.block(l, lp1, **kwargs))
self.layers.append(self.block(lyr, lyrp1, **kwargs))
# convert list of layers to ModuleList
self.layers = nn.ModuleList(*[self.layers])
# forward pass through the encoder
def forward(self, x):
"""Forward propagation through the encoder.
"""Forward pass of the encoder.
Stores intermediate outputs in a dictionary. The keys of the dictionary
are the number of the network layers and the values are dictionaries
with the following (key, value) pairs:
``"feature"``
The intermediate encoder outputs (`torch.Tensor`).
``"indices"``
The indices of the max pooling layer, if required
(`torch.Tensor`).
Parameters
----------
......@@ -354,7 +445,6 @@ class Encoder(nn.Module):
# initialize a dictionary that caches the intermediate outputs, i.e.
# features and pooling indices of each block in the encoder
self.cache = {}
for i, layer in enumerate(self.layers):
# apply current encoder layer forward pass
x, y, ind = layer.forward(x)
......@@ -368,18 +458,26 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
"""Generic convolutional decoder.
When instanciating an encoder-decoder architechure, ``filters`` should be
the same for `pysegcnn.core.layers.Encoder` and
`pysegcnn.core.layers.Decoder`.
See `pysegcnn.core.models.UNet` for an example implementation.
Parameters
----------
filters : `list` [`int`]
List of input channels to each convolutional block.
block : `torch.nn.Module`
The convolutional block. ``block`` should inherit from
`torch.nn.Module`, e.g. `pysegcnn.core.layers.Conv2dUnpool`.
List of input channels to each convolutional block. The length of
``filters`` determines the depth of the decoder. The first element of
``filters`` has to be the number of channels of the input images.
block : `pysegcnn.core.layers.DecoderBlock`
The convolutional block defining a layer in the decoder.
A subclass of `pysegcnn.core.layers.DecoderBlock`, e.g.
`pysegcnn.core.layers.ConvBnReluMaxUnpool`.
skip : `bool`
Whether to apply skip connections from the encoder to the decoder.
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
Returns
-------
......@@ -391,12 +489,15 @@ class Decoder(nn.Module):
super().__init__()
# the block of operations defining a layer in the decoder
if not issubclass(block, DecoderBlock):
raise TypeError('"block" expected to be a subclass of {}.'
.format(repr(DecoderBlock)))
self.block = block
# the number of filters for each block is symmetric to the encoder:
# the last two element of filters have to be equal in order to apply
# last skip connection
self.features = filters[::-1]
self.features = np.asarray(filters)[::-1]
self.features[-1] = self.features[-2]
# whether to apply skip connections
......@@ -404,9 +505,7 @@ class Decoder(nn.Module):
# in case of skip connections, the number of input channels to
# each block of the decoder is doubled
n_in = 1
if self.skip:
n_in *= 2
n_in = 2 if self.skip else 1
# construct decoder layers
self.layers = []
......@@ -416,20 +515,21 @@ class Decoder(nn.Module):
# convert list of layers to ModuleList
self.layers = nn.ModuleList(*[self.layers])
# forward pass through decoder
def forward(self, x, enc_cache):
"""Forward propagation through the decoder.
"""Forward pass of the decoder.
Parameters
----------
x : `torch.Tensor`
Output of the encoder.
enc_cache : `dict`
Cache dictionary with keys:
``'feature'``
Encoder features used for the skip connection.
``'indices'``
The indices of the max pooling operations.
enc_cache : `dict` [`dict`]
Cache dictionary. The keys of the dictionary are the number of the
network layers and the values are dictionaries with the following
(key, value) pairs:
``"feature"``
The intermediate encoder outputs (`torch.Tensor`).
``"indices"``
The indices of the max pooling layer (`torch.Tensor`).
Returns
-------
......@@ -449,3 +549,175 @@ class Decoder(nn.Module):
x = layer.forward(x, feature, indices, self.skip)
return x
class ConvBnReluMaxPool(EncoderBlock):
"""Block of convolution, batchnorm, relu and 2x2 max pool.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of output channels.
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Returns
-------
None.
"""
def __init__(self, in_channels, out_channels, **kwargs):
super().__init__(in_channels, out_channels, **kwargs)
def layers(self):
"""Sequence of convolution, batchnorm and relu layers.
Returns
-------
layers : `torch.nn.Sequential` [`torch.nn.Module`]
An instance of `torch.nn.Sequential` containing the sequence
of convolution, batchnorm and relu layer (`torch.nn.Module`)
instances.
"""
return conv_bn_relu(self.in_channels, self.out_channels, **self.kwargs)
def downsample(self, x):
"""2x2 max pooling layer, `torch.nn.functional.max_pool2d`.
Parameters
----------
x : `torch.Tensor`
Input tensor.
Returns
-------
x : `torch.Tensor`
The 2x2 max pooled tensor.
indices : `torch.Tensor` or `None`
The indices of the maxima. Useful for upsampling with
`torch.nn.functional.max_unpool2d`.
"""
x, indices = F.max_pool2d(x, kernel_size=2, return_indices=True)
return x, indices
class ConvBnReluMaxUnpool(DecoderBlock):
"""Block of convolution, batchnorm, relu and 2x2 max unpool.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of output channels
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Returns
-------
None.
"""
def __init__(self, in_channels, out_channels, **kwargs):
super().__init__(in_channels, out_channels, **kwargs)
def layers(self):
"""Sequence of convolution, batchnorm and relu layers.
Returns
-------
layers : `torch.nn.Sequential` [`torch.nn.Module`]
An instance of `torch.nn.Sequential` containing the sequence
of convolution, batchnorm and relu layer (`torch.nn.Module`)
instances.
"""
return conv_bn_relu(self.in_channels, self.out_channels, **self.kwargs)
def upsample(self, x, feature, indices):
"""2x2 max unpooling layer.
Parameters
----------
x : `torch.Tensor`
Input tensor.
feature : `torch.Tensor`, shape=(batch, channel, height, width)
Intermediate output of a layer in the encoder. Used to determine
the output shape of the upsampling operation.
indices : `torch.Tensor`
The indices of the maxima of the max pooling operation
(as returned by `torch.nn.functional.max_pool2d`).
Returns
-------
x : `torch.Tensor`
The 2x2 max unpooled tensor.
"""
return F.max_unpool2d(x, indices, kernel_size=2,
output_size=feature.shape[2:])
class ConvBnReluUpsample(DecoderBlock):
"""Block of convolution, batchnorm, relu and nearest neighbor upsampling.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of output channels
**kwargs: 'dict' [`str`]
Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
Returns
-------
None.
"""
def __init__(self, in_channels, out_channels, **kwargs):
super().__init__(in_channels, out_channels, **kwargs)
def layers(self):
"""Sequence of convolution, batchnorm and relu layers.
Returns
-------
layers : `torch.nn.Sequential` [`torch.nn.Module`]
An instance of `torch.nn.Sequential` containing the sequence
of convolution, batchnorm and relu layer (`torch.nn.Module`)
instances.
"""
return conv_bn_relu(self.in_channels, self.out_channels, **self.kwargs)
def upsample(self, x, feature, indices=None):
"""Nearest neighbor upsampling.
Parameters
----------
x : `torch.Tensor`
Input tensor.
feature : `torch.Tensor`, shape=(batch, channel, height, width)
Intermediate output of a layer in the encoder. Used to determine
the output shape of the upsampling operation.
indices : `None`, optional
The indices of the maxima of the max pooling operation
(as returned by `torch.nn.functional.max_pool2d`). Not required by
this upsampling method.
Returns
-------
x : `torch.Tensor`
The 2x2 max unpooled tensor.
"""
return F.interpolate(x, size=feature.shape[2:], mode='nearest')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment