Commit 8173421c authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Added a strided convolutional downsampling layer.

parent 99fe1458
......@@ -664,6 +664,52 @@ class ConvBnReluMaxPool(EncoderBlock):
'dilation=1, ceil_mode=False)')
class ConvBnReluConvS2(EncoderBlock):
"""Block of convolution, batchnorm, relu and convolution with stride=2."""
def __init__(self, in_channels, out_channels, **kwargs):
super().__init__(in_channels, out_channels, **kwargs)
# downsampling layer: convolution with stride=2
self.pool = nn.Sequential(
Conv2dSame(in_channels, out_channels, stride=2, **kwargs),
nn.BatchNorm2d(out_channels),
nn.ReLU())
def layers(self):
"""Sequence of convolution, batchnorm and relu layers.
Returns
-------
layers : :py:class:`torch.nn.Sequential` [:py:class:`torch.nn.Module`]
An instance of :py:class:`torch.nn.Sequential` containing the
sequence of convolution, batchnorm and relu layer
(:py:class:`torch.nn.Module`) instances.
"""
return conv_bn_relu(self.in_channels, self.out_channels, **self.kwargs)
def downsample(self, x):
"""Convolutional layer with stride=2.
Parameters
----------
x : :py:class:`torch.Tensor`, shape=(b, c, h, w)
Input tensor.
Returns
-------
x : :py:class:`torch.Tensor`, shape=(b, c, h // 2, w // 2)
The 2x2 strided tensor.
indices : :py:class:`torch.Tensor` or `None`
The indices of the maxima. Useful for upsampling with
:py:func:`torch.nn.functional.max_unpool2d`.
"""
return self.pool(x), None
class ConvBnReluMaxUnpool(DecoderBlock):
"""Block of convolution, batchnorm, relu and 2x2 max unpool."""
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment