Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
PySegCNN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
earth_observation_public
PySegCNN
Commits
32b0b0fd
Commit
32b0b0fd
authored
4 years ago
by
Frisinghelli Daniel
Browse files
Options
Downloads
Patches
Plain Diff
Implemented generic classes for an encoder-decoder architecture
parent
4d102fdc
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
pysegcnn/core/layers.py
+399
-127
399 additions, 127 deletions
pysegcnn/core/layers.py
with
399 additions
and
127 deletions
pysegcnn/core/layers.py
+
399
−
127
View file @
32b0b0fd
...
...
@@ -15,6 +15,7 @@ License
# -*- coding: utf-8 -*-
# externals
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
...
...
@@ -32,11 +33,11 @@ class Conv2dSame(nn.Conv2d):
*args: `list` [`str`]
positional arguments passed to `torch.nn.Conv2d`:
``
'
in_channels
'
``: `int`
Number of input channels
``
'
kernel_size
'
``: `int` or `tuple` [`int`]
Size of the convolving kernel
Number of input channels.
``
'
out_channels
'
``: `int`
Number of desired output channels
Number of output channels.
``
'
kernel_size
'
``: `int` or `tuple` [`int`]
Size of the convolving kernel.
**kwargs:
'
dict
'
[`str`]
Additional keyword arguments passed to `torch.nn.Conv2d`_.
...
...
@@ -88,15 +89,15 @@ def conv_bn_relu(in_channels, out_channels, **kwargs):
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of
desired
output channels
Number of output channels
.
**kwargs:
'
dict
'
[`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
Returns
-------
block : `torch.nn.Sequential`
An instance of `torch.nn.Sequential` containing the different layers.
block : `torch.nn.Sequential` [`torch.nn.Module`]
An instance of `torch.nn.Sequential` containing a sequence of
convolution, batch normalization and rectified linear unit layers.
"""
return
nn
.
Sequential
(
...
...
@@ -109,18 +110,23 @@ def conv_bn_relu(in_channels, out_channels, **kwargs):
)
class
Conv2dPool
(
nn
.
Module
):
"""
B
lock of
convolution
, batchnorm, relu and 2x2 max pool
.
class
Block
(
nn
.
Module
):
"""
B
asic
convolution
al block
.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of
desired
output channels
Number of output channels
.
**kwargs:
'
dict
'
[`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
Raises
------
TypeError
Raised if `~pysegcnn.core.layers.Block.layers` method does not return
an instance of `torch.nn.Sequential`.
Returns
-------
...
...
@@ -129,56 +135,67 @@ class Conv2dPool(nn.Module):
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
# initialize nn.Module class
super
().
__init__
()
# create the convolutional blocks for this module
self
.
conv
=
conv_bn_relu
(
in_channels
,
out_channels
,
**
kwargs
)
# number of input and output channels
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
#
create the 2x2 max pooling layer
self
.
pool
=
nn
.
MaxPool2d
(
2
,
return_indices
=
True
)
#
keyword arguments configuring convolutions
self
.
kwargs
=
kwargs
# defines the forward pass
def
forward
(
self
,
x
):
"""
Forward propagation through this block.
# the layers of the block
self
.
conv
=
self
.
layers
()
if
not
isinstance
(
self
.
conv
,
nn
.
Sequential
):
raise
TypeError
(
'
{}.layers() should return an instance of {}.
'
.
format
(
self
.
__class__
.
__name__
,
repr
(
nn
.
Sequential
)))
Parameters
----------
x : `torch.Tensor`
Output of previous layer.
def
layers
(
self
):
"""
Define the layers of the block.
Raises
------
NotImplementedError
Raised if `pysegcnn.core.layers.Block` is not inherited.
Returns
-------
y : `torch.Tensor`
Output of this block.
x : `torch.Tensor`
Output before max pooling. Stored for skip connections.
i : `torch.Tensor`
Indices of the max pooling operation. Used in unpooling operation.
layers : `torch.nn.Sequential` [`torch.nn.Module`]
Return an instance of `torch.nn.Sequential` containing a sequence
of layer (`torch.nn.Module` ) instances.
"""
# output of the convolutional block
x
=
self
.
conv
(
x
)
raise
NotImplementedError
(
'
Return an instance of {}.
'
.
format
(
repr
(
nn
.
Sequential
)))
def
forward
(
self
):
"""
Forward pass of the block.
Raises
------
NotImplementedError
Raised if `pysegcnn.core.layers.Block` is not inherited.
# output of the pooling layer
y
,
i
=
self
.
pool
(
x
)
Returns
-------
None.
return
(
y
,
x
,
i
)
"""
raise
NotImplementedError
(
'
Implement the forward pass.
'
)
class
Conv2dUnpool
(
nn
.
Module
):
"""
Block of convolution
, batchnorm, relu and 2x2 max unpool
.
class
EncoderBlock
(
Block
):
"""
Block of
a
convolution
al encoder
.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of
desired
output channels
Number of output channels
.
**kwargs:
'
dict
'
[`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
Returns
-------
...
...
@@ -187,62 +204,86 @@ class Conv2dUnpool(nn.Module):
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
super
().
__init__
()
# create the convolutional blocks for this module
self
.
conv
=
conv_bn_relu
(
in_channels
,
out_channels
,
**
kwargs
)
super
().
__init__
(
in_channels
,
out_channels
,
**
kwargs
)
# create the unpooling layer
self
.
upsample
=
nn
.
MaxUnpool2d
(
2
)
# defines the forward pass
def
forward
(
self
,
x
,
feature
,
indices
,
skip
):
"""
Forward propagation through this block.
def
forward
(
self
,
x
):
"""
Forward pass of an encoder block.
Parameters
----------
x : `torch.Tensor`
Output of previous layer.
feature : `torch.Tensor`
Encoder feature used for the skip connection.
indices : `torch.Tensor`
Indices of the max pooling operation. Used in unpooling operation.
skip : `bool`
Whether to apply skip connetion.
Input tensor, e.g. output of the previous block/layer.
Returns
-------
y : `torch.Tensor`
Output of the encoder block.
x : `torch.Tensor`
Output of this block.
Intermediate output before applying downsampling. Useful to
implement skip connections.
indices : `torch.Tensor` or `None`
Optional indices of the downsampling method, e.g. indices of the
maxima when using `torch.nn.functional.max_pool2d`. Useful for
upsampling later. If no indices are required to upsample, simply
return ``indices`` = `None`.
"""
#
upsampling with pooling indices
x
=
self
.
upsample
(
x
,
indices
,
output_size
=
feature
.
shape
)
#
the forward pass of the layers of the block
x
=
self
.
conv
(
x
)
# check whether to apply the skip connection
# skip connection: concatenate the output of a layer in the encoder to
# the corresponding layer in the decoder (along the channel axis)
if
skip
:
x
=
torch
.
cat
([
x
,
feature
],
axis
=
1
)
# the downsampling layer
y
,
indices
=
self
.
downsample
(
x
)
# output of the convolutional layer
x
=
self
.
conv
(
x
)
return
(
y
,
x
,
indices
)
return
x
def
downsample
(
self
,
x
):
"""
Define the downsampling method.
The `~pysegcnn.core.layers.EncoderBlock.downsample` `method should
implement the spatial pooling operation.
class
Conv2dUpsample
(
nn
.
Module
):
"""
Block of convolution, batchnorm, relu and nearest neighbor upsampling.
Use one of the following functions to downsample:
- `torch.nn.functional.max_pool2d`
- `torch.nn.functional.interpolate`
See `pysegcnn.core.layers.ConvBnReluMaxPool` for an example
implementation.
Parameters
----------
x : `torch.Tensor`
Input tensor, e.g. output of a convolutional block.
Raises
------
NotImplementedError
Raised if `pysegcnn.core.layers.EncoderBlock` is not inherited.
Returns
-------
x : `torch.Tensor`
The spatially downsampled tensor.
indices : `torch.Tensor` or `None`
Optional indices of the downsampling method, e.g. indices of the
maxima when using `torch.nn.functional.max_pool2d`. Useful for
upsampling later. If no indices are required to upsample, simply
return ``indices`` = `None`.
"""
raise
NotImplementedError
(
'
Implement the downsampling function.
'
)
class
DecoderBlock
(
Block
):
"""
Block of a convolutional decoder.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of
desired
output channels
Number of output channels
.
**kwargs:
'
dict
'
[`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
Returns
-------
...
...
@@ -251,43 +292,37 @@ class Conv2dUpsample(nn.Module):
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
super
().
__init__
()
super
().
__init__
(
in_channels
,
out_channels
,
**
kwargs
)
# create the convolutional blocks for this module
self
.
conv
=
conv_bn_relu
(
in_channels
,
out_channels
,
**
kwargs
)
# create the upsampling layer
self
.
upsample
=
F
.
interpolate
# defines the forward pass
def
forward
(
self
,
x
,
feature
,
indices
,
skip
):
"""
Forward p
ropagation through this
block.
"""
Forward p
ass of a decoder
block.
Parameters
----------
x : `torch.Tensor`
Output of previous layer.
feature : `torch.Tensor`
Encoder feature used for the skip connection.
indices : `torch.Tensor`
Indices of the max pooling operation. Used in unpooling operation.
Not used here, but passed to preserve generic interface. Useful in
`pysegcnn.core.layers.Decoder`.
Input tensor.
feature : `torch.Tensor`, shape=(batch, channel, height, width)
Intermediate output of a layer in the encoder.
If ``skip`` = True, ``feature`` is concatenated (along the channel
axis) to the output of the respective upsampling layer in the
decoder (skip connection).
indices : `torch.Tensor` or `None`
Indices of the encoder downsampling method.
skip : `bool`
Whether to apply skip connection.
Whether to apply
the
skip connection.
Returns
-------
x : `torch.Tensor`
Output of th
is
block.
Output of th
e decoder
block.
"""
# upsampl
ing with pooling indices
x
=
self
.
upsample
(
x
,
size
=
feature
.
shape
[
2
:],
mode
=
'
nearest
'
)
# upsampl
e
x
=
self
.
upsample
(
x
,
feature
,
indices
)
# check whether to apply the skip connection
# skip connection: concatenate the output of a layer in the encoder to
# the corresponding layer in the decoder (along the channel axis
)
# the corresponding layer in the decoder (along the channel axis)
if
skip
:
x
=
torch
.
cat
([
x
,
feature
],
axis
=
1
)
...
...
@@ -296,20 +331,65 @@ class Conv2dUpsample(nn.Module):
return
x
def
upsample
(
self
,
x
,
feature
,
indices
):
"""
Define the upsampling method.
The `~pysegcnn.core.layers.DecoderBlock.upsample` `method should
implement the spatial upsampling operation.
Use one of the following functions to upsample:
- `torch.nn.functional.max_unpool2d`
- `torch.nn.functional.interpolate`
See `pysegcnn.core.layers.ConvBnReluMaxUnpool` or
`pysegcnn.core.layers.ConvBnReluUpsample` for an example
implementation.
Parameters
----------
x : `torch.Tensor`
Input tensor, e.g. output of a convolutional block.
feature : `torch.Tensor`, shape=(batch, channel, height, width)
Intermediate output of a layer in the encoder. Used to implement
skip connections.
indices : `torch.Tensor` or `None`
Indices of the encoder downsampling method.
Raises
------
NotImplementedError
Raised if `pysegcnn.core.layers.DecoderBlock` is not inherited.
Returns
-------
x : `torch.Tensor`
The spatially upsampled tensor.
"""
raise
NotImplementedError
(
'
Implement the upsampling function
'
)
class
Encoder
(
nn
.
Module
):
"""
Generic convolutional encoder.
When instanciating an encoder-decoder architechure, ``filters`` should be
the same for `pysegcnn.core.layers.Encoder` and
`pysegcnn.core.layers.Decoder`.
See `pysegcnn.core.models.UNet` for an example implementation.
Parameters
----------
filters : `list` [`int`]
List of input channels to each convolutional block.
block : `torch.nn.Module`
The convolutional block. ``block`` should inherit from
`torch.nn.Module`, e.g. `pysegcnn.core.layers.Conv2dPool`.
List of input channels to each convolutional block. The length of
``filters`` determines the depth of the encoder. The first element of
``filters`` has to be the number of channels of the input images.
block : `pysegcnn.core.layers.EncoderBlock`
The convolutional block defining a layer in the encoder.
A subclass of `pysegcnn.core.layers.EncoderBlock`, e.g.
`pysegcnn.core.layers.ConvBnReluMaxPool`.
**kwargs:
'
dict
'
[`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
Returns
-------
...
...
@@ -322,23 +402,34 @@ class Encoder(nn.Module):
# the number of filters for each block: the first element of filters
# has to be the number of input channels
self
.
features
=
filters
self
.
features
=
np
.
asarray
(
filters
)
# the block of operations defining a layer in the encoder
if
not
issubclass
(
block
,
EncoderBlock
):
raise
TypeError
(
'"
block
"
expected to be a subclass of {}.
'
.
format
(
repr
(
EncoderBlock
)))
self
.
block
=
block
# construct the encoder layers
self
.
layers
=
[]
for
i
,
(
l
,
lp1
)
in
enumerate
(
zip
(
self
.
features
,
self
.
features
[
1
:])
)
:
for
lyr
,
l
yr
p1
in
zip
(
self
.
features
,
self
.
features
[
1
:]):
# append blocks to the encoder layers
self
.
layers
.
append
(
self
.
block
(
l
,
lp1
,
**
kwargs
))
self
.
layers
.
append
(
self
.
block
(
l
yr
,
l
yr
p1
,
**
kwargs
))
# convert list of layers to ModuleList
self
.
layers
=
nn
.
ModuleList
(
*
[
self
.
layers
])
# forward pass through the encoder
def
forward
(
self
,
x
):
"""
Forward propagation through the encoder.
"""
Forward pass of the encoder.
Stores intermediate outputs in a dictionary. The keys of the dictionary
are the number of the network layers and the values are dictionaries
with the following (key, value) pairs:
``
"
feature
"
``
The intermediate encoder outputs (`torch.Tensor`).
``
"
indices
"
``
The indices of the max pooling layer, if required
(`torch.Tensor`).
Parameters
----------
...
...
@@ -354,7 +445,6 @@ class Encoder(nn.Module):
# initialize a dictionary that caches the intermediate outputs, i.e.
# features and pooling indices of each block in the encoder
self
.
cache
=
{}
for
i
,
layer
in
enumerate
(
self
.
layers
):
# apply current encoder layer forward pass
x
,
y
,
ind
=
layer
.
forward
(
x
)
...
...
@@ -368,18 +458,26 @@ class Encoder(nn.Module):
class
Decoder
(
nn
.
Module
):
"""
Generic convolutional decoder.
When instanciating an encoder-decoder architechure, ``filters`` should be
the same for `pysegcnn.core.layers.Encoder` and
`pysegcnn.core.layers.Decoder`.
See `pysegcnn.core.models.UNet` for an example implementation.
Parameters
----------
filters : `list` [`int`]
List of input channels to each convolutional block.
block : `torch.nn.Module`
The convolutional block. ``block`` should inherit from
`torch.nn.Module`, e.g. `pysegcnn.core.layers.Conv2dUnpool`.
List of input channels to each convolutional block. The length of
``filters`` determines the depth of the decoder. The first element of
``filters`` has to be the number of channels of the input images.
block : `pysegcnn.core.layers.DecoderBlock`
The convolutional block defining a layer in the decoder.
A subclass of `pysegcnn.core.layers.DecoderBlock`, e.g.
`pysegcnn.core.layers.ConvBnReluMaxUnpool`.
skip : `bool`
Whether to apply skip connections from the encoder to the decoder.
**kwargs:
'
dict
'
[`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
Returns
-------
...
...
@@ -391,12 +489,15 @@ class Decoder(nn.Module):
super
().
__init__
()
# the block of operations defining a layer in the decoder
if
not
issubclass
(
block
,
DecoderBlock
):
raise
TypeError
(
'"
block
"
expected to be a subclass of {}.
'
.
format
(
repr
(
DecoderBlock
)))
self
.
block
=
block
# the number of filters for each block is symmetric to the encoder:
# the last two element of filters have to be equal in order to apply
# last skip connection
self
.
features
=
filters
[::
-
1
]
self
.
features
=
np
.
asarray
(
filters
)
[::
-
1
]
self
.
features
[
-
1
]
=
self
.
features
[
-
2
]
# whether to apply skip connections
...
...
@@ -404,9 +505,7 @@ class Decoder(nn.Module):
# in case of skip connections, the number of input channels to
# each block of the decoder is doubled
n_in
=
1
if
self
.
skip
:
n_in
*=
2
n_in
=
2
if
self
.
skip
else
1
# construct decoder layers
self
.
layers
=
[]
...
...
@@ -416,20 +515,21 @@ class Decoder(nn.Module):
# convert list of layers to ModuleList
self
.
layers
=
nn
.
ModuleList
(
*
[
self
.
layers
])
# forward pass through decoder
def
forward
(
self
,
x
,
enc_cache
):
"""
Forward p
ropagation through
the decoder.
"""
Forward p
ass of
the decoder.
Parameters
----------
x : `torch.Tensor`
Output of the encoder.
enc_cache : `dict`
Cache dictionary with keys:
``
'
feature
'
``
Encoder features used for the skip connection.
``
'
indices
'
``
The indices of the max pooling operations.
enc_cache : `dict` [`dict`]
Cache dictionary. The keys of the dictionary are the number of the
network layers and the values are dictionaries with the following
(key, value) pairs:
``
"
feature
"
``
The intermediate encoder outputs (`torch.Tensor`).
``
"
indices
"
``
The indices of the max pooling layer (`torch.Tensor`).
Returns
-------
...
...
@@ -449,3 +549,175 @@ class Decoder(nn.Module):
x
=
layer
.
forward
(
x
,
feature
,
indices
,
self
.
skip
)
return
x
class
ConvBnReluMaxPool
(
EncoderBlock
):
"""
Block of convolution, batchnorm, relu and 2x2 max pool.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of output channels.
**kwargs:
'
dict
'
[`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Returns
-------
None.
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
super
().
__init__
(
in_channels
,
out_channels
,
**
kwargs
)
def
layers
(
self
):
"""
Sequence of convolution, batchnorm and relu layers.
Returns
-------
layers : `torch.nn.Sequential` [`torch.nn.Module`]
An instance of `torch.nn.Sequential` containing the sequence
of convolution, batchnorm and relu layer (`torch.nn.Module`)
instances.
"""
return
conv_bn_relu
(
self
.
in_channels
,
self
.
out_channels
,
**
self
.
kwargs
)
def
downsample
(
self
,
x
):
"""
2x2 max pooling layer, `torch.nn.functional.max_pool2d`.
Parameters
----------
x : `torch.Tensor`
Input tensor.
Returns
-------
x : `torch.Tensor`
The 2x2 max pooled tensor.
indices : `torch.Tensor` or `None`
The indices of the maxima. Useful for upsampling with
`torch.nn.functional.max_unpool2d`.
"""
x
,
indices
=
F
.
max_pool2d
(
x
,
kernel_size
=
2
,
return_indices
=
True
)
return
x
,
indices
class
ConvBnReluMaxUnpool
(
DecoderBlock
):
"""
Block of convolution, batchnorm, relu and 2x2 max unpool.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of output channels
**kwargs:
'
dict
'
[`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Returns
-------
None.
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
super
().
__init__
(
in_channels
,
out_channels
,
**
kwargs
)
def
layers
(
self
):
"""
Sequence of convolution, batchnorm and relu layers.
Returns
-------
layers : `torch.nn.Sequential` [`torch.nn.Module`]
An instance of `torch.nn.Sequential` containing the sequence
of convolution, batchnorm and relu layer (`torch.nn.Module`)
instances.
"""
return
conv_bn_relu
(
self
.
in_channels
,
self
.
out_channels
,
**
self
.
kwargs
)
def
upsample
(
self
,
x
,
feature
,
indices
):
"""
2x2 max unpooling layer.
Parameters
----------
x : `torch.Tensor`
Input tensor.
feature : `torch.Tensor`, shape=(batch, channel, height, width)
Intermediate output of a layer in the encoder. Used to determine
the output shape of the upsampling operation.
indices : `torch.Tensor`
The indices of the maxima of the max pooling operation
(as returned by `torch.nn.functional.max_pool2d`).
Returns
-------
x : `torch.Tensor`
The 2x2 max unpooled tensor.
"""
return
F
.
max_unpool2d
(
x
,
indices
,
kernel_size
=
2
,
output_size
=
feature
.
shape
[
2
:])
class
ConvBnReluUpsample
(
DecoderBlock
):
"""
Block of convolution, batchnorm, relu and nearest neighbor upsampling.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of output channels
**kwargs:
'
dict
'
[`str`]
Additional arguments passed to `pysegcnn.core.layers.Conv2dSame`.
Returns
-------
None.
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
super
().
__init__
(
in_channels
,
out_channels
,
**
kwargs
)
def
layers
(
self
):
"""
Sequence of convolution, batchnorm and relu layers.
Returns
-------
layers : `torch.nn.Sequential` [`torch.nn.Module`]
An instance of `torch.nn.Sequential` containing the sequence
of convolution, batchnorm and relu layer (`torch.nn.Module`)
instances.
"""
return
conv_bn_relu
(
self
.
in_channels
,
self
.
out_channels
,
**
self
.
kwargs
)
def
upsample
(
self
,
x
,
feature
,
indices
=
None
):
"""
Nearest neighbor upsampling.
Parameters
----------
x : `torch.Tensor`
Input tensor.
feature : `torch.Tensor`, shape=(batch, channel, height, width)
Intermediate output of a layer in the encoder. Used to determine
the output shape of the upsampling operation.
indices : `None`, optional
The indices of the maxima of the max pooling operation
(as returned by `torch.nn.functional.max_pool2d`). Not required by
this upsampling method.
Returns
-------
x : `torch.Tensor`
The 2x2 max unpooled tensor.
"""
return
F
.
interpolate
(
x
,
size
=
feature
.
shape
[
2
:],
mode
=
'
nearest
'
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment