From 32b0b0fdb0a7f5aa1348b0a86c1dffbe31a584ee Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Fri, 21 Aug 2020 17:21:42 +0200
Subject: [PATCH] Implemented generic classes for an encoder-decoder
 architecture

---
 pysegcnn/core/layers.py | 526 ++++++++++++++++++++++++++++++----------
 1 file changed, 399 insertions(+), 127 deletions(-)

diff --git a/pysegcnn/core/layers.py b/pysegcnn/core/layers.py
index d958555..da7d85d 100644
--- a/pysegcnn/core/layers.py
+++ b/pysegcnn/core/layers.py
@@ -15,6 +15,7 @@ License
 # -*- coding: utf-8 -*-
 
 # externals
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -32,11 +33,11 @@ class Conv2dSame(nn.Conv2d):
     *args: `list` [`str`]
         positional arguments passed to `torch.nn.Conv2d`:
             ``'in_channels'``: `int`
-                Number of input channels
-            ``'kernel_size'``: `int` or `tuple` [`int`]
-                Size of the convolving kernel
+                Number of input channels.
             ``'out_channels'``: `int`
-                Number of desired output channels
+                Number of output channels.
+            ``'kernel_size'``: `int` or `tuple` [`int`]
+                Size of the convolving kernel.
     **kwargs: 'dict' [`str`]
         Additional keyword arguments passed to `torch.nn.Conv2d`_.
 
@@ -88,15 +89,15 @@ def conv_bn_relu(in_channels, out_channels, **kwargs):
     in_channels : `int`
         Number of input channels.
     out_channels : `int`
-        Number of desired output channels
+        Number of output channels.
     **kwargs: 'dict' [`str`]
-        Additional keyword arguments passed to
-        `pysegcnn.core.layers.Conv2dSame`.
+        Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
 
     Returns
     -------
-    block : `torch.nn.Sequential`
-        An instance of `torch.nn.Sequential` containing the different layers.
+    block : `torch.nn.Sequential` [`torch.nn.Module`]
+        An instance of `torch.nn.Sequential` containing a sequence of
+        convolution, batch normalization and rectified linear unit layers.
 
     """
     return nn.Sequential(
@@ -109,18 +110,23 @@ def conv_bn_relu(in_channels, out_channels, **kwargs):
             )
 
 
-class Conv2dPool(nn.Module):
-    """Block of convolution, batchnorm, relu and 2x2 max pool.
+class Block(nn.Module):
+    """Basic convolutional block.
 
     Parameters
     ----------
     in_channels : `int`
         Number of input channels.
     out_channels : `int`
-        Number of desired output channels
+        Number of output channels.
     **kwargs: 'dict' [`str`]
-        Additional keyword arguments passed to
-        `pysegcnn.core.layers.Conv2dSame`.
+         Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
+
+    Raises
+    ------
+    TypeError
+        Raised if `~pysegcnn.core.layers.Block.layers` method does not return
+        an instance of `torch.nn.Sequential`.
 
     Returns
     -------
@@ -129,56 +135,67 @@ class Conv2dPool(nn.Module):
     """
 
     def __init__(self, in_channels, out_channels, **kwargs):
-
-        # initialize nn.Module class
         super().__init__()
 
-        # create the convolutional blocks for this module
-        self.conv = conv_bn_relu(in_channels, out_channels, **kwargs)
+        # number of input and output channels
+        self.in_channels = in_channels
+        self.out_channels = out_channels
 
-        # create the 2x2 max pooling layer
-        self.pool = nn.MaxPool2d(2, return_indices=True)
+        # keyword arguments configuring convolutions
+        self.kwargs = kwargs
 
-    # defines the forward pass
-    def forward(self, x):
-        """Forward propagation through this block.
+        # the layers of the block
+        self.conv = self.layers()
+        if not isinstance(self.conv, nn.Sequential):
+            raise TypeError('{}.layers() should return an instance of {}.'
+                            .format(self.__class__.__name__,
+                                    repr(nn.Sequential)))
 
-        Parameters
-        ----------
-        x : `torch.Tensor`
-            Output of previous layer.
+    def layers(self):
+        """Define the layers of the block.
+
+        Raises
+        ------
+        NotImplementedError
+            Raised if `pysegcnn.core.layers.Block` is not inherited.
 
         Returns
         -------
-        y : `torch.Tensor`
-            Output of this block.
-        x : `torch.Tensor`
-            Output before max pooling. Stored for skip connections.
-        i : `torch.Tensor`
-            Indices of the max pooling operation. Used in unpooling operation.
+        layers : `torch.nn.Sequential` [`torch.nn.Module`]
+            Return an instance of `torch.nn.Sequential` containing a sequence
+            of layer (`torch.nn.Module` ) instances.
 
         """
-        # output of the convolutional block
-        x = self.conv(x)
+        raise NotImplementedError('Return an instance of {}.'
+                                  .format(repr(nn.Sequential)))
+
+    def forward(self):
+        """Forward pass of the block.
+
+        Raises
+        ------
+        NotImplementedError
+            Raised if `pysegcnn.core.layers.Block` is not inherited.
 
-        # output of the pooling layer
-        y, i = self.pool(x)
+        Returns
+        -------
+        None.
 
-        return (y, x, i)
+        """
+        raise NotImplementedError('Implement the forward pass.')
 
 
-class Conv2dUnpool(nn.Module):
-    """Block of convolution, batchnorm, relu and 2x2 max unpool.
+class EncoderBlock(Block):
+    """Block of a convolutional encoder.
 
     Parameters
     ----------
     in_channels : `int`
         Number of input channels.
     out_channels : `int`
-        Number of desired output channels
+        Number of output channels.
     **kwargs: 'dict' [`str`]
-        Additional keyword arguments passed to
-        `pysegcnn.core.layers.Conv2dSame`.
+         Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
 
     Returns
     -------
@@ -187,62 +204,86 @@ class Conv2dUnpool(nn.Module):
     """
 
     def __init__(self, in_channels, out_channels, **kwargs):
-        super().__init__()
-
-        # create the convolutional blocks for this module
-        self.conv = conv_bn_relu(in_channels, out_channels, **kwargs)
+        super().__init__(in_channels, out_channels, **kwargs)
 
-        # create the unpooling layer
-        self.upsample = nn.MaxUnpool2d(2)
-
-    # defines the forward pass
-    def forward(self, x, feature, indices, skip):
-        """Forward propagation through this block.
+    def forward(self, x):
+        """Forward pass of an encoder block.
 
         Parameters
         ----------
         x : `torch.Tensor`
-            Output of previous layer.
-        feature : `torch.Tensor`
-            Encoder feature used for the skip connection.
-        indices : `torch.Tensor`
-            Indices of the max pooling operation. Used in unpooling operation.
-        skip : `bool`
-            Whether to apply skip connetion.
+            Input tensor, e.g. output of the previous block/layer.
 
         Returns
         -------
+        y : `torch.Tensor`
+            Output of the encoder block.
         x : `torch.Tensor`
-            Output of this block.
+            Intermediate output before applying downsampling. Useful to
+            implement skip connections.
+        indices : `torch.Tensor` or `None`
+            Optional indices of the downsampling method, e.g. indices of the
+            maxima when using `torch.nn.functional.max_pool2d`. Useful for
+            upsampling later. If no indices are required to upsample, simply
+            return ``indices`` = `None`.
 
         """
-        # upsampling with pooling indices
-        x = self.upsample(x, indices, output_size=feature.shape)
+        # the forward pass of the layers of the block
+        x = self.conv(x)
 
-        # check whether to apply the skip connection
-        # skip connection: concatenate the output of a layer in the encoder to
-        # the corresponding layer in the decoder (along the channel axis)
-        if skip:
-            x = torch.cat([x, feature], axis=1)
+        # the downsampling layer
+        y, indices = self.downsample(x)
 
-        # output of the convolutional layer
-        x = self.conv(x)
+        return (y, x, indices)
 
-        return x
+    def downsample(self, x):
+        """Define the downsampling method.
 
+        The `~pysegcnn.core.layers.EncoderBlock.downsample` `method should
+        implement the spatial pooling operation.
 
-class Conv2dUpsample(nn.Module):
-    """Block of convolution, batchnorm, relu and nearest neighbor upsampling.
+        Use one of the following functions to downsample:
+            - `torch.nn.functional.max_pool2d`
+            - `torch.nn.functional.interpolate`
+
+        See `pysegcnn.core.layers.ConvBnReluMaxPool` for an example
+        implementation.
+
+        Parameters
+        ----------
+        x : `torch.Tensor`
+            Input tensor, e.g. output of a convolutional block.
+
+        Raises
+        ------
+        NotImplementedError
+            Raised if `pysegcnn.core.layers.EncoderBlock` is not inherited.
+
+        Returns
+        -------
+        x : `torch.Tensor`
+            The spatially downsampled tensor.
+        indices : `torch.Tensor` or `None`
+            Optional indices of the downsampling method, e.g. indices of the
+            maxima when using `torch.nn.functional.max_pool2d`. Useful for
+            upsampling later. If no indices are required to upsample, simply
+            return ``indices`` = `None`.
+
+        """
+        raise NotImplementedError('Implement the downsampling function.')
+
+
+class DecoderBlock(Block):
+    """Block of a convolutional decoder.
 
     Parameters
     ----------
     in_channels : `int`
         Number of input channels.
     out_channels : `int`
-        Number of desired output channels
+        Number of output channels.
     **kwargs: 'dict' [`str`]
-        Additional keyword arguments passed to
-        `pysegcnn.core.layers.Conv2dSame`.
+         Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
 
     Returns
     -------
@@ -251,43 +292,37 @@ class Conv2dUpsample(nn.Module):
     """
 
     def __init__(self, in_channels, out_channels, **kwargs):
-        super().__init__()
+        super().__init__(in_channels, out_channels, **kwargs)
 
-        # create the convolutional blocks for this module
-        self.conv = conv_bn_relu(in_channels, out_channels, **kwargs)
-
-        # create the upsampling layer
-        self.upsample = F.interpolate
-
-    # defines the forward pass
     def forward(self, x, feature, indices, skip):
-        """Forward propagation through this block.
+        """Forward pass of a decoder block.
 
         Parameters
         ----------
         x : `torch.Tensor`
-            Output of previous layer.
-        feature : `torch.Tensor`
-            Encoder feature used for the skip connection.
-        indices : `torch.Tensor`
-            Indices of the max pooling operation. Used in unpooling operation.
-            Not used here, but passed to preserve generic interface. Useful in
-            `pysegcnn.core.layers.Decoder`.
+            Input tensor.
+        feature : `torch.Tensor`, shape=(batch, channel, height, width)
+            Intermediate output of a layer in the encoder.
+            If ``skip`` = True, ``feature`` is concatenated (along the channel
+            axis) to the output of the respective upsampling layer in the
+            decoder (skip connection).
+        indices : `torch.Tensor` or `None`
+            Indices of the encoder downsampling method.
         skip : `bool`
-            Whether to apply skip connection.
+            Whether to apply the skip connection.
 
         Returns
         -------
         x : `torch.Tensor`
-            Output of this block.
+            Output of the decoder block.
 
         """
-        # upsampling with pooling indices
-        x = self.upsample(x, size=feature.shape[2:], mode='nearest')
+        # upsample
+        x = self.upsample(x, feature, indices)
 
         # check whether to apply the skip connection
         # skip connection: concatenate the output of a layer in the encoder to
-        # the corresponding layer in the decoder (along the channel axis )
+        # the corresponding layer in the decoder (along the channel axis)
         if skip:
             x = torch.cat([x, feature], axis=1)
 
@@ -296,20 +331,65 @@ class Conv2dUpsample(nn.Module):
 
         return x
 
+    def upsample(self, x, feature, indices):
+        """Define the upsampling method.
+
+        The `~pysegcnn.core.layers.DecoderBlock.upsample` `method should
+        implement the spatial upsampling operation.
+
+        Use one of the following functions to upsample:
+            - `torch.nn.functional.max_unpool2d`
+            - `torch.nn.functional.interpolate`
+
+        See `pysegcnn.core.layers.ConvBnReluMaxUnpool` or
+        `pysegcnn.core.layers.ConvBnReluUpsample` for an example
+        implementation.
+
+        Parameters
+        ----------
+        x : `torch.Tensor`
+            Input tensor, e.g. output of a convolutional block.
+        feature : `torch.Tensor`, shape=(batch, channel, height, width)
+            Intermediate output of a layer in the encoder. Used to implement
+            skip connections.
+        indices : `torch.Tensor` or `None`
+            Indices of the encoder downsampling method.
+
+        Raises
+        ------
+        NotImplementedError
+            Raised if `pysegcnn.core.layers.DecoderBlock` is not inherited.
+
+        Returns
+        -------
+        x : `torch.Tensor`
+            The spatially upsampled tensor.
+
+        """
+        raise NotImplementedError('Implement the upsampling function')
+
 
 class Encoder(nn.Module):
     """Generic convolutional encoder.
 
+    When instanciating an encoder-decoder architechure, ``filters`` should be
+    the same for `pysegcnn.core.layers.Encoder` and
+    `pysegcnn.core.layers.Decoder`.
+
+    See `pysegcnn.core.models.UNet` for an example implementation.
+
     Parameters
     ----------
     filters : `list` [`int`]
-        List of input channels to each convolutional block.
-    block : `torch.nn.Module`
-        The convolutional block. ``block`` should inherit from
-        `torch.nn.Module`, e.g. `pysegcnn.core.layers.Conv2dPool`.
+        List of input channels to each convolutional block. The length of
+        ``filters`` determines the depth of the encoder. The first element of
+        ``filters`` has to be the number of channels of the input images.
+    block : `pysegcnn.core.layers.EncoderBlock`
+        The convolutional block defining a layer in the encoder.
+        A subclass of `pysegcnn.core.layers.EncoderBlock`, e.g.
+        `pysegcnn.core.layers.ConvBnReluMaxPool`.
     **kwargs: 'dict' [`str`]
-        Additional keyword arguments passed to
-        `pysegcnn.core.layers.Conv2dSame`.
+        Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
 
     Returns
     -------
@@ -322,23 +402,34 @@ class Encoder(nn.Module):
 
         # the number of filters for each block: the first element of filters
         # has to be the number of input channels
-        self.features = filters
+        self.features = np.asarray(filters)
 
         # the block of operations defining a layer in the encoder
+        if not issubclass(block, EncoderBlock):
+            raise TypeError('"block" expected to be a subclass of {}.'
+                            .format(repr(EncoderBlock)))
         self.block = block
 
         # construct the encoder layers
         self.layers = []
-        for i, (l, lp1) in enumerate(zip(self.features, self.features[1:])):
+        for lyr, lyrp1 in zip(self.features, self.features[1:]):
             # append blocks to the encoder layers
-            self.layers.append(self.block(l, lp1, **kwargs))
+            self.layers.append(self.block(lyr, lyrp1, **kwargs))
 
         # convert list of layers to ModuleList
         self.layers = nn.ModuleList(*[self.layers])
 
-    # forward pass through the encoder
     def forward(self, x):
-        """Forward propagation through the encoder.
+        """Forward pass of the encoder.
+
+        Stores intermediate outputs in a dictionary. The keys of the dictionary
+        are the number of the network layers and the values are dictionaries
+        with the following (key, value) pairs:
+            ``"feature"``
+                The intermediate encoder outputs (`torch.Tensor`).
+            ``"indices"``
+                The indices of the max pooling layer, if required
+                (`torch.Tensor`).
 
         Parameters
         ----------
@@ -354,7 +445,6 @@ class Encoder(nn.Module):
         # initialize a dictionary that caches the intermediate outputs, i.e.
         # features and pooling indices of each block in the encoder
         self.cache = {}
-
         for i, layer in enumerate(self.layers):
             # apply current encoder layer forward pass
             x, y, ind = layer.forward(x)
@@ -368,18 +458,26 @@ class Encoder(nn.Module):
 class Decoder(nn.Module):
     """Generic convolutional decoder.
 
+    When instanciating an encoder-decoder architechure, ``filters`` should be
+    the same for `pysegcnn.core.layers.Encoder` and
+    `pysegcnn.core.layers.Decoder`.
+
+    See `pysegcnn.core.models.UNet` for an example implementation.
+
     Parameters
     ----------
     filters : `list` [`int`]
-        List of input channels to each convolutional block.
-    block : `torch.nn.Module`
-        The convolutional block. ``block`` should inherit from
-        `torch.nn.Module`, e.g. `pysegcnn.core.layers.Conv2dUnpool`.
+        List of input channels to each convolutional block. The length of
+        ``filters`` determines the depth of the decoder. The first element of
+        ``filters`` has to be the number of channels of the input images.
+    block : `pysegcnn.core.layers.DecoderBlock`
+        The convolutional block defining a layer in the decoder.
+        A subclass of `pysegcnn.core.layers.DecoderBlock`, e.g.
+        `pysegcnn.core.layers.ConvBnReluMaxUnpool`.
     skip : `bool`
         Whether to apply skip connections from the encoder to the decoder.
     **kwargs: 'dict' [`str`]
-        Additional keyword arguments passed to
-        `pysegcnn.core.layers.Conv2dSame`.
+        Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
 
     Returns
     -------
@@ -391,12 +489,15 @@ class Decoder(nn.Module):
         super().__init__()
 
         # the block of operations defining a layer in the decoder
+        if not issubclass(block, DecoderBlock):
+            raise TypeError('"block" expected to be a subclass of {}.'
+                            .format(repr(DecoderBlock)))
         self.block = block
 
         # the number of filters for each block is symmetric to the encoder:
         # the last two element of filters have to be equal in order to apply
         # last skip connection
-        self.features = filters[::-1]
+        self.features = np.asarray(filters)[::-1]
         self.features[-1] = self.features[-2]
 
         # whether to apply skip connections
@@ -404,9 +505,7 @@ class Decoder(nn.Module):
 
         # in case of skip connections, the number of input channels to
         # each block of the decoder is doubled
-        n_in = 1
-        if self.skip:
-            n_in *= 2
+        n_in = 2 if self.skip else 1
 
         # construct decoder layers
         self.layers = []
@@ -416,20 +515,21 @@ class Decoder(nn.Module):
         # convert list of layers to ModuleList
         self.layers = nn.ModuleList(*[self.layers])
 
-    # forward pass through decoder
     def forward(self, x, enc_cache):
-        """Forward propagation through the decoder.
+        """Forward pass of the decoder.
 
         Parameters
         ----------
         x : `torch.Tensor`
             Output of the encoder.
-        enc_cache : `dict`
-            Cache dictionary with keys:
-                ``'feature'``
-                    Encoder features used for the skip connection.
-                ``'indices'``
-                    The indices of the max pooling operations.
+        enc_cache : `dict` [`dict`]
+            Cache dictionary. The keys of the dictionary are the number of the
+            network layers and the values are dictionaries with the following
+            (key, value) pairs:
+                ``"feature"``
+                    The intermediate encoder outputs (`torch.Tensor`).
+                ``"indices"``
+                    The indices of the max pooling layer (`torch.Tensor`).
 
         Returns
         -------
@@ -449,3 +549,175 @@ class Decoder(nn.Module):
             x = layer.forward(x, feature, indices, self.skip)
 
         return x
+
+
+class ConvBnReluMaxPool(EncoderBlock):
+    """Block of convolution, batchnorm, relu and 2x2 max pool.
+
+    Parameters
+    ----------
+    in_channels : `int`
+        Number of input channels.
+    out_channels : `int`
+        Number of output channels.
+    **kwargs: 'dict' [`str`]
+        Additional keyword arguments passed to
+        `pysegcnn.core.layers.Conv2dSame`.
+
+    Returns
+    -------
+    None.
+
+    """
+
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super().__init__(in_channels, out_channels, **kwargs)
+
+    def layers(self):
+        """Sequence of convolution, batchnorm and relu layers.
+
+        Returns
+        -------
+        layers : `torch.nn.Sequential` [`torch.nn.Module`]
+            An instance of `torch.nn.Sequential` containing the sequence
+            of convolution, batchnorm and relu layer (`torch.nn.Module`)
+            instances.
+
+        """
+        return conv_bn_relu(self.in_channels, self.out_channels, **self.kwargs)
+
+    def downsample(self, x):
+        """2x2 max pooling layer, `torch.nn.functional.max_pool2d`.
+
+        Parameters
+        ----------
+        x : `torch.Tensor`
+            Input tensor.
+
+        Returns
+        -------
+        x : `torch.Tensor`
+            The 2x2 max pooled tensor.
+        indices : `torch.Tensor` or `None`
+            The indices of the maxima. Useful for upsampling with
+            `torch.nn.functional.max_unpool2d`.
+        """
+        x, indices = F.max_pool2d(x, kernel_size=2, return_indices=True)
+        return x, indices
+
+
+class ConvBnReluMaxUnpool(DecoderBlock):
+    """Block of convolution, batchnorm, relu and 2x2 max unpool.
+
+    Parameters
+    ----------
+    in_channels : `int`
+        Number of input channels.
+    out_channels : `int`
+        Number of output channels
+    **kwargs: 'dict' [`str`]
+        Additional keyword arguments passed to
+        `pysegcnn.core.layers.Conv2dSame`.
+
+    Returns
+    -------
+    None.
+
+    """
+
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super().__init__(in_channels, out_channels, **kwargs)
+
+    def layers(self):
+        """Sequence of convolution, batchnorm and relu layers.
+
+        Returns
+        -------
+        layers : `torch.nn.Sequential` [`torch.nn.Module`]
+            An instance of `torch.nn.Sequential` containing the sequence
+            of convolution, batchnorm and relu layer (`torch.nn.Module`)
+            instances.
+
+        """
+        return conv_bn_relu(self.in_channels, self.out_channels, **self.kwargs)
+
+    def upsample(self, x, feature, indices):
+        """2x2 max unpooling layer.
+
+        Parameters
+        ----------
+        x : `torch.Tensor`
+            Input tensor.
+        feature : `torch.Tensor`, shape=(batch, channel, height, width)
+            Intermediate output of a layer in the encoder. Used to determine
+            the output shape of the upsampling operation.
+        indices : `torch.Tensor`
+            The indices of the maxima of the max pooling operation
+            (as returned by `torch.nn.functional.max_pool2d`).
+
+        Returns
+        -------
+        x : `torch.Tensor`
+            The 2x2 max unpooled tensor.
+
+        """
+        return F.max_unpool2d(x, indices, kernel_size=2,
+                              output_size=feature.shape[2:])
+
+
+class ConvBnReluUpsample(DecoderBlock):
+    """Block of convolution, batchnorm, relu and nearest neighbor upsampling.
+
+    Parameters
+    ----------
+    in_channels : `int`
+        Number of input channels.
+    out_channels : `int`
+        Number of output channels
+    **kwargs: 'dict' [`str`]
+        Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
+
+    Returns
+    -------
+    None.
+
+    """
+
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super().__init__(in_channels, out_channels, **kwargs)
+
+    def layers(self):
+        """Sequence of convolution, batchnorm and relu layers.
+
+        Returns
+        -------
+        layers : `torch.nn.Sequential` [`torch.nn.Module`]
+            An instance of `torch.nn.Sequential` containing the sequence
+            of convolution, batchnorm and relu layer (`torch.nn.Module`)
+            instances.
+
+        """
+        return conv_bn_relu(self.in_channels, self.out_channels, **self.kwargs)
+
+    def upsample(self, x, feature, indices=None):
+        """Nearest neighbor upsampling.
+
+        Parameters
+        ----------
+        x : `torch.Tensor`
+            Input tensor.
+        feature : `torch.Tensor`, shape=(batch, channel, height, width)
+            Intermediate output of a layer in the encoder. Used to determine
+            the output shape of the upsampling operation.
+        indices : `None`, optional
+            The indices of the maxima of the max pooling operation
+            (as returned by `torch.nn.functional.max_pool2d`). Not required by
+            this upsampling method.
+
+        Returns
+        -------
+        x : `torch.Tensor`
+            The 2x2 max unpooled tensor.
+
+        """
+        return F.interpolate(x, size=feature.shape[2:], mode='nearest')
-- 
GitLab