Commit 804234d7 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Moved architecture to climax.

parent 8173421c
......@@ -27,8 +27,7 @@ import torch.optim as optim
# locals
from pysegcnn.core.layers import (Encoder, Decoder, ConvBnReluMaxPool,
ConvBnReluMaxUnpool, ConvBnReluUpsample,
Conv2dSame)
ConvBnReluMaxUnpool, Conv2dSame)
from pysegcnn.core.utils import check_filename_length
# module level logger
......@@ -487,74 +486,6 @@ class SegNet(ConvolutionalAutoEncoder):
**kwargs)
class USegNet(ConvolutionalAutoEncoder):
"""An implementation of `SegNet`_ with interpolation in PyTorch.
.. _SegNet:
https://arxiv.org/abs/1511.00561
Attributes
----------
state_file : `str` or `None` or :py:class:`pathlib.Path`
The model state file, where the model parameters are saved.
in_channels : `int`
Number of channels of the input images.
nclasses : `int`
Number of classes.
filters : `list` [`int`]
List of the number of convolutional filters in each block.
skip : `bool`
Whether to apply skip connections from the encoder to the decoder.
kwargs : `dict` [`str`]
Additional keyword arguments passed to
:py:class:`pysegcnn.core.layers.Conv2dSame`.
epoch : `int`
Number of epochs the model was trained.
encoder : :py:class:`pysegcnn.core.layers.Encoder`
The convolutional encoder.
decoder : :py:class:`pysegcnn.core.layers.Decoder`
The convolutional decoder.
classifier : :py:class:`pysegcnn.core.layers.Conv2dSame`
The classification layer, a 1x1 convolution.
"""
def __init__(self, state_file, in_channels, nclasses,
filters=[32, 64, 128, 256], skip=True,
kwargs={'kernel_size': 3, 'stride': 1, 'dilation': 1}):
"""Initialize.
Parameters
----------
state_file : `str` or `None` or :py:class:`pathlib.Path`
The model state file, where the model parameters are saved.
in_channels : `int`
Number of channels of the input images.
nclasses : `int`
Number of classes.
filters : `list` [`int`], optional
List of input channels to each convolutional block. The default is
`[32, 64, 128, 256]`.
skip : `bool`, optional
Whether to apply skip connections from the encoder to the decoder.
The default is `True`.
kwargs: `dict` [`str`: `int`]
Additional keyword arguments passed to
:py:class:`pysegcnn.core.layers.Conv2dSame`. The default is
`{'kernel_size': 3, 'stride': 1, 'dilation': 1}`.
"""
super().__init__(state_file=state_file,
in_channels=in_channels,
nclasses=nclasses,
encoder_block=ConvBnReluMaxPool,
decoder_block=ConvBnReluUpsample,
filters=filters,
skip=skip,
**kwargs)
class SupportedModels(enum.Enum):
"""Names and corresponding classes of the implemented models."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment