From 32b0b0fdb0a7f5aa1348b0a86c1dffbe31a584ee Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Fri, 21 Aug 2020 17:21:42 +0200 Subject: [PATCH] Implemented generic classes for an encoder-decoder architecture --- pysegcnn/core/layers.py | 526 ++++++++++++++++++++++++++++++---------- 1 file changed, 399 insertions(+), 127 deletions(-) diff --git a/pysegcnn/core/layers.py b/pysegcnn/core/layers.py index d958555..da7d85d 100644 --- a/pysegcnn/core/layers.py +++ b/pysegcnn/core/layers.py @@ -15,6 +15,7 @@ License # -*- coding: utf-8 -*- # externals +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -32,11 +33,11 @@ class Conv2dSame(nn.Conv2d): *args: `list` [`str`] positional arguments passed to `torch.nn.Conv2d`: ``'in_channels'``: `int` - Number of input channels - ``'kernel_size'``: `int` or `tuple` [`int`] - Size of the convolving kernel + Number of input channels. ``'out_channels'``: `int` - Number of desired output channels + Number of output channels. + ``'kernel_size'``: `int` or `tuple` [`int`] + Size of the convolving kernel. **kwargs: 'dict' [`str`] Additional keyword arguments passed to `torch.nn.Conv2d`_. @@ -88,15 +89,15 @@ def conv_bn_relu(in_channels, out_channels, **kwargs): in_channels : `int` Number of input channels. out_channels : `int` - Number of desired output channels + Number of output channels. **kwargs: 'dict' [`str`] - Additional keyword arguments passed to - `pysegcnn.core.layers.Conv2dSame`. + Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`. Returns ------- - block : `torch.nn.Sequential` - An instance of `torch.nn.Sequential` containing the different layers. + block : `torch.nn.Sequential` [`torch.nn.Module`] + An instance of `torch.nn.Sequential` containing a sequence of + convolution, batch normalization and rectified linear unit layers. 
""" return nn.Sequential( @@ -109,18 +110,23 @@ def conv_bn_relu(in_channels, out_channels, **kwargs): ) -class Conv2dPool(nn.Module): - """Block of convolution, batchnorm, relu and 2x2 max pool. +class Block(nn.Module): + """Basic convolutional block. Parameters ---------- in_channels : `int` Number of input channels. out_channels : `int` - Number of desired output channels + Number of output channels. **kwargs: 'dict' [`str`] - Additional keyword arguments passed to - `pysegcnn.core.layers.Conv2dSame`. + Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`. + + Raises + ------ + TypeError + Raised if `~pysegcnn.core.layers.Block.layers` method does not return + an instance of `torch.nn.Sequential`. Returns ------- @@ -129,56 +135,67 @@ class Conv2dPool(nn.Module): """ def __init__(self, in_channels, out_channels, **kwargs): - - # initialize nn.Module class super().__init__() - # create the convolutional blocks for this module - self.conv = conv_bn_relu(in_channels, out_channels, **kwargs) + # number of input and output channels + self.in_channels = in_channels + self.out_channels = out_channels - # create the 2x2 max pooling layer - self.pool = nn.MaxPool2d(2, return_indices=True) + # keyword arguments configuring convolutions + self.kwargs = kwargs - # defines the forward pass - def forward(self, x): - """Forward propagation through this block. + # the layers of the block + self.conv = self.layers() + if not isinstance(self.conv, nn.Sequential): + raise TypeError('{}.layers() should return an instance of {}.' + .format(self.__class__.__name__, + repr(nn.Sequential))) - Parameters - ---------- - x : `torch.Tensor` - Output of previous layer. + def layers(self): + """Define the layers of the block. + + Raises + ------ + NotImplementedError + Raised if `pysegcnn.core.layers.Block` is not inherited. Returns ------- - y : `torch.Tensor` - Output of this block. - x : `torch.Tensor` - Output before max pooling. Stored for skip connections. 
- i : `torch.Tensor` - Indices of the max pooling operation. Used in unpooling operation. + layers : `torch.nn.Sequential` [`torch.nn.Module`] + Return an instance of `torch.nn.Sequential` containing a sequence + of layer (`torch.nn.Module` ) instances. """ - # output of the convolutional block - x = self.conv(x) + raise NotImplementedError('Return an instance of {}.' + .format(repr(nn.Sequential))) + + def forward(self): + """Forward pass of the block. + + Raises + ------ + NotImplementedError + Raised if `pysegcnn.core.layers.Block` is not inherited. - # output of the pooling layer - y, i = self.pool(x) + Returns + ------- + None. - return (y, x, i) + """ + raise NotImplementedError('Implement the forward pass.') -class Conv2dUnpool(nn.Module): - """Block of convolution, batchnorm, relu and 2x2 max unpool. +class EncoderBlock(Block): + """Block of a convolutional encoder. Parameters ---------- in_channels : `int` Number of input channels. out_channels : `int` - Number of desired output channels + Number of output channels. **kwargs: 'dict' [`str`] - Additional keyword arguments passed to - `pysegcnn.core.layers.Conv2dSame`. + Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`. Returns ------- @@ -187,62 +204,86 @@ class Conv2dUnpool(nn.Module): """ def __init__(self, in_channels, out_channels, **kwargs): - super().__init__() - - # create the convolutional blocks for this module - self.conv = conv_bn_relu(in_channels, out_channels, **kwargs) + super().__init__(in_channels, out_channels, **kwargs) - # create the unpooling layer - self.upsample = nn.MaxUnpool2d(2) - - # defines the forward pass - def forward(self, x, feature, indices, skip): - """Forward propagation through this block. + def forward(self, x): + """Forward pass of an encoder block. Parameters ---------- x : `torch.Tensor` - Output of previous layer. - feature : `torch.Tensor` - Encoder feature used for the skip connection. 
- indices : `torch.Tensor` - Indices of the max pooling operation. Used in unpooling operation. - skip : `bool` - Whether to apply skip connetion. + Input tensor, e.g. output of the previous block/layer. Returns ------- + y : `torch.Tensor` + Output of the encoder block. x : `torch.Tensor` - Output of this block. + Intermediate output before applying downsampling. Useful to + implement skip connections. + indices : `torch.Tensor` or `None` + Optional indices of the downsampling method, e.g. indices of the + maxima when using `torch.nn.functional.max_pool2d`. Useful for + upsampling later. If no indices are required to upsample, simply + return ``indices`` = `None`. """ - # upsampling with pooling indices - x = self.upsample(x, indices, output_size=feature.shape) + # the forward pass of the layers of the block + x = self.conv(x) - # check whether to apply the skip connection - # skip connection: concatenate the output of a layer in the encoder to - # the corresponding layer in the decoder (along the channel axis) - if skip: - x = torch.cat([x, feature], axis=1) + # the downsampling layer + y, indices = self.downsample(x) - # output of the convolutional layer - x = self.conv(x) + return (y, x, indices) - return x + def downsample(self, x): + """Define the downsampling method. + The `~pysegcnn.core.layers.EncoderBlock.downsample` `method should + implement the spatial pooling operation. -class Conv2dUpsample(nn.Module): - """Block of convolution, batchnorm, relu and nearest neighbor upsampling. + Use one of the following functions to downsample: + - `torch.nn.functional.max_pool2d` + - `torch.nn.functional.interpolate` + + See `pysegcnn.core.layers.ConvBnReluMaxPool` for an example + implementation. + + Parameters + ---------- + x : `torch.Tensor` + Input tensor, e.g. output of a convolutional block. + + Raises + ------ + NotImplementedError + Raised if `pysegcnn.core.layers.EncoderBlock` is not inherited. 
+ + Returns + ------- + x : `torch.Tensor` + The spatially downsampled tensor. + indices : `torch.Tensor` or `None` + Optional indices of the downsampling method, e.g. indices of the + maxima when using `torch.nn.functional.max_pool2d`. Useful for + upsampling later. If no indices are required to upsample, simply + return ``indices`` = `None`. + + """ + raise NotImplementedError('Implement the downsampling function.') + + +class DecoderBlock(Block): + """Block of a convolutional decoder. Parameters ---------- in_channels : `int` Number of input channels. out_channels : `int` - Number of desired output channels + Number of output channels. **kwargs: 'dict' [`str`] - Additional keyword arguments passed to - `pysegcnn.core.layers.Conv2dSame`. + Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`. Returns ------- @@ -251,43 +292,37 @@ class Conv2dUpsample(nn.Module): """ def __init__(self, in_channels, out_channels, **kwargs): - super().__init__() + super().__init__(in_channels, out_channels, **kwargs) - # create the convolutional blocks for this module - self.conv = conv_bn_relu(in_channels, out_channels, **kwargs) - - # create the upsampling layer - self.upsample = F.interpolate - - # defines the forward pass def forward(self, x, feature, indices, skip): - """Forward propagation through this block. + """Forward pass of a decoder block. Parameters ---------- x : `torch.Tensor` - Output of previous layer. - feature : `torch.Tensor` - Encoder feature used for the skip connection. - indices : `torch.Tensor` - Indices of the max pooling operation. Used in unpooling operation. - Not used here, but passed to preserve generic interface. Useful in - `pysegcnn.core.layers.Decoder`. + Input tensor. + feature : `torch.Tensor`, shape=(batch, channel, height, width) + Intermediate output of a layer in the encoder. 
+ If ``skip`` = True, ``feature`` is concatenated (along the channel + axis) to the output of the respective upsampling layer in the + decoder (skip connection). + indices : `torch.Tensor` or `None` + Indices of the encoder downsampling method. skip : `bool` - Whether to apply skip connection. + Whether to apply the skip connection. Returns ------- x : `torch.Tensor` - Output of this block. + Output of the decoder block. """ - # upsampling with pooling indices - x = self.upsample(x, size=feature.shape[2:], mode='nearest') + # upsample + x = self.upsample(x, feature, indices) # check whether to apply the skip connection # skip connection: concatenate the output of a layer in the encoder to - # the corresponding layer in the decoder (along the channel axis ) + # the corresponding layer in the decoder (along the channel axis) if skip: x = torch.cat([x, feature], axis=1) @@ -296,20 +331,65 @@ class Conv2dUpsample(nn.Module): return x + def upsample(self, x, feature, indices): + """Define the upsampling method. + + The `~pysegcnn.core.layers.DecoderBlock.upsample` `method should + implement the spatial upsampling operation. + + Use one of the following functions to upsample: + - `torch.nn.functional.max_unpool2d` + - `torch.nn.functional.interpolate` + + See `pysegcnn.core.layers.ConvBnReluMaxUnpool` or + `pysegcnn.core.layers.ConvBnReluUpsample` for an example + implementation. + + Parameters + ---------- + x : `torch.Tensor` + Input tensor, e.g. output of a convolutional block. + feature : `torch.Tensor`, shape=(batch, channel, height, width) + Intermediate output of a layer in the encoder. Used to implement + skip connections. + indices : `torch.Tensor` or `None` + Indices of the encoder downsampling method. + + Raises + ------ + NotImplementedError + Raised if `pysegcnn.core.layers.DecoderBlock` is not inherited. + + Returns + ------- + x : `torch.Tensor` + The spatially upsampled tensor. 
+
+        """
+        raise NotImplementedError('Implement the upsampling function')
+
 
 class Encoder(nn.Module):
     """Generic convolutional encoder.
 
+    When instantiating an encoder-decoder architecture, ``filters`` should be
+    the same for `pysegcnn.core.layers.Encoder` and
+    `pysegcnn.core.layers.Decoder`.
+
+    See `pysegcnn.core.models.UNet` for an example implementation.
+
     Parameters
     ----------
     filters : `list` [`int`]
-        List of input channels to each convolutional block.
-    block : `torch.nn.Module`
-        The convolutional block. ``block`` should inherit from
-        `torch.nn.Module`, e.g. `pysegcnn.core.layers.Conv2dPool`.
+        List of input channels to each convolutional block. The length of
+        ``filters`` determines the depth of the encoder. The first element of
+        ``filters`` has to be the number of channels of the input images.
+    block : `pysegcnn.core.layers.EncoderBlock`
+        The convolutional block defining a layer in the encoder.
+        A subclass of `pysegcnn.core.layers.EncoderBlock`, e.g.
+        `pysegcnn.core.layers.ConvBnReluMaxPool`.
     **kwargs: 'dict' [`str`]
-        Additional keyword arguments passed to
-        `pysegcnn.core.layers.Conv2dSame`.
+        Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
 
     Returns
     -------
@@ -322,23 +402,34 @@ class Encoder(nn.Module):
 
         # the number of filters for each block: the first element of filters
         # has to be the number of input channels
-        self.features = filters
+        self.features = np.asarray(filters)
 
         # the block of operations defining a layer in the encoder
+        if not issubclass(block, EncoderBlock):
+            raise TypeError('"block" expected to be a subclass of {}.'
+ .format(repr(EncoderBlock))) self.block = block # construct the encoder layers self.layers = [] - for i, (l, lp1) in enumerate(zip(self.features, self.features[1:])): + for lyr, lyrp1 in zip(self.features, self.features[1:]): # append blocks to the encoder layers - self.layers.append(self.block(l, lp1, **kwargs)) + self.layers.append(self.block(lyr, lyrp1, **kwargs)) # convert list of layers to ModuleList self.layers = nn.ModuleList(*[self.layers]) - # forward pass through the encoder def forward(self, x): - """Forward propagation through the encoder. + """Forward pass of the encoder. + + Stores intermediate outputs in a dictionary. The keys of the dictionary + are the number of the network layers and the values are dictionaries + with the following (key, value) pairs: + ``"feature"`` + The intermediate encoder outputs (`torch.Tensor`). + ``"indices"`` + The indices of the max pooling layer, if required + (`torch.Tensor`). Parameters ---------- @@ -354,7 +445,6 @@ class Encoder(nn.Module): # initialize a dictionary that caches the intermediate outputs, i.e. # features and pooling indices of each block in the encoder self.cache = {} - for i, layer in enumerate(self.layers): # apply current encoder layer forward pass x, y, ind = layer.forward(x) @@ -368,18 +458,26 @@ class Encoder(nn.Module): class Decoder(nn.Module): """Generic convolutional decoder. + When instanciating an encoder-decoder architechure, ``filters`` should be + the same for `pysegcnn.core.layers.Encoder` and + `pysegcnn.core.layers.Decoder`. + + See `pysegcnn.core.models.UNet` for an example implementation. + Parameters ---------- filters : `list` [`int`] - List of input channels to each convolutional block. - block : `torch.nn.Module` - The convolutional block. ``block`` should inherit from - `torch.nn.Module`, e.g. `pysegcnn.core.layers.Conv2dUnpool`. + List of input channels to each convolutional block. The length of + ``filters`` determines the depth of the decoder. 
The first element of + ``filters`` has to be the number of channels of the input images. + block : `pysegcnn.core.layers.DecoderBlock` + The convolutional block defining a layer in the decoder. + A subclass of `pysegcnn.core.layers.DecoderBlock`, e.g. + `pysegcnn.core.layers.ConvBnReluMaxUnpool`. skip : `bool` Whether to apply skip connections from the encoder to the decoder. **kwargs: 'dict' [`str`] - Additional keyword arguments passed to - `pysegcnn.core.layers.Conv2dSame`. + Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`. Returns ------- @@ -391,12 +489,15 @@ class Decoder(nn.Module): super().__init__() # the block of operations defining a layer in the decoder + if not issubclass(block, DecoderBlock): + raise TypeError('"block" expected to be a subclass of {}.' + .format(repr(DecoderBlock))) self.block = block # the number of filters for each block is symmetric to the encoder: # the last two element of filters have to be equal in order to apply # last skip connection - self.features = filters[::-1] + self.features = np.asarray(filters)[::-1] self.features[-1] = self.features[-2] # whether to apply skip connections @@ -404,9 +505,7 @@ class Decoder(nn.Module): # in case of skip connections, the number of input channels to # each block of the decoder is doubled - n_in = 1 - if self.skip: - n_in *= 2 + n_in = 2 if self.skip else 1 # construct decoder layers self.layers = [] @@ -416,20 +515,21 @@ class Decoder(nn.Module): # convert list of layers to ModuleList self.layers = nn.ModuleList(*[self.layers]) - # forward pass through decoder def forward(self, x, enc_cache): - """Forward propagation through the decoder. + """Forward pass of the decoder. Parameters ---------- x : `torch.Tensor` Output of the encoder. - enc_cache : `dict` - Cache dictionary with keys: - ``'feature'`` - Encoder features used for the skip connection. - ``'indices'`` - The indices of the max pooling operations. + enc_cache : `dict` [`dict`] + Cache dictionary. 
The keys of the dictionary are the number of the + network layers and the values are dictionaries with the following + (key, value) pairs: + ``"feature"`` + The intermediate encoder outputs (`torch.Tensor`). + ``"indices"`` + The indices of the max pooling layer (`torch.Tensor`). Returns ------- @@ -449,3 +549,175 @@ class Decoder(nn.Module): x = layer.forward(x, feature, indices, self.skip) return x + + +class ConvBnReluMaxPool(EncoderBlock): + """Block of convolution, batchnorm, relu and 2x2 max pool. + + Parameters + ---------- + in_channels : `int` + Number of input channels. + out_channels : `int` + Number of output channels. + **kwargs: 'dict' [`str`] + Additional keyword arguments passed to + `pysegcnn.core.layers.Conv2dSame`. + + Returns + ------- + None. + + """ + + def __init__(self, in_channels, out_channels, **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + + def layers(self): + """Sequence of convolution, batchnorm and relu layers. + + Returns + ------- + layers : `torch.nn.Sequential` [`torch.nn.Module`] + An instance of `torch.nn.Sequential` containing the sequence + of convolution, batchnorm and relu layer (`torch.nn.Module`) + instances. + + """ + return conv_bn_relu(self.in_channels, self.out_channels, **self.kwargs) + + def downsample(self, x): + """2x2 max pooling layer, `torch.nn.functional.max_pool2d`. + + Parameters + ---------- + x : `torch.Tensor` + Input tensor. + + Returns + ------- + x : `torch.Tensor` + The 2x2 max pooled tensor. + indices : `torch.Tensor` or `None` + The indices of the maxima. Useful for upsampling with + `torch.nn.functional.max_unpool2d`. + """ + x, indices = F.max_pool2d(x, kernel_size=2, return_indices=True) + return x, indices + + +class ConvBnReluMaxUnpool(DecoderBlock): + """Block of convolution, batchnorm, relu and 2x2 max unpool. + + Parameters + ---------- + in_channels : `int` + Number of input channels. 
+ out_channels : `int` + Number of output channels + **kwargs: 'dict' [`str`] + Additional keyword arguments passed to + `pysegcnn.core.layers.Conv2dSame`. + + Returns + ------- + None. + + """ + + def __init__(self, in_channels, out_channels, **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + + def layers(self): + """Sequence of convolution, batchnorm and relu layers. + + Returns + ------- + layers : `torch.nn.Sequential` [`torch.nn.Module`] + An instance of `torch.nn.Sequential` containing the sequence + of convolution, batchnorm and relu layer (`torch.nn.Module`) + instances. + + """ + return conv_bn_relu(self.in_channels, self.out_channels, **self.kwargs) + + def upsample(self, x, feature, indices): + """2x2 max unpooling layer. + + Parameters + ---------- + x : `torch.Tensor` + Input tensor. + feature : `torch.Tensor`, shape=(batch, channel, height, width) + Intermediate output of a layer in the encoder. Used to determine + the output shape of the upsampling operation. + indices : `torch.Tensor` + The indices of the maxima of the max pooling operation + (as returned by `torch.nn.functional.max_pool2d`). + + Returns + ------- + x : `torch.Tensor` + The 2x2 max unpooled tensor. + + """ + return F.max_unpool2d(x, indices, kernel_size=2, + output_size=feature.shape[2:]) + + +class ConvBnReluUpsample(DecoderBlock): + """Block of convolution, batchnorm, relu and nearest neighbor upsampling. + + Parameters + ---------- + in_channels : `int` + Number of input channels. + out_channels : `int` + Number of output channels + **kwargs: 'dict' [`str`] + Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`. + + Returns + ------- + None. + + """ + + def __init__(self, in_channels, out_channels, **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + + def layers(self): + """Sequence of convolution, batchnorm and relu layers. 
+
+        Returns
+        -------
+        layers : `torch.nn.Sequential` [`torch.nn.Module`]
+            An instance of `torch.nn.Sequential` containing the sequence
+            of convolution, batchnorm and relu layer (`torch.nn.Module`)
+            instances.
+
+        """
+        return conv_bn_relu(self.in_channels, self.out_channels, **self.kwargs)
+
+    def upsample(self, x, feature, indices=None):
+        """Nearest neighbor upsampling.
+
+        Parameters
+        ----------
+        x : `torch.Tensor`
+            Input tensor.
+        feature : `torch.Tensor`, shape=(batch, channel, height, width)
+            Intermediate output of a layer in the encoder. Used to determine
+            the output shape of the upsampling operation.
+        indices : `None`, optional
+            The indices of the maxima of the max pooling operation
+            (as returned by `torch.nn.functional.max_pool2d`). Not required by
+            this upsampling method.
+
+        Returns
+        -------
+        x : `torch.Tensor`
+            The tensor upsampled by nearest neighbor interpolation.
+
+        """
+        return F.interpolate(x, size=feature.shape[2:], mode='nearest')
-- 
GitLab