diff --git a/pysegcnn/core/models.py b/pysegcnn/core/models.py index df3c3a697cdb124a1d64992e80d33ad4a8ae2a37..c04fa4cae58352f1b9c1bd00fa36ab207421389c 100644 --- a/pysegcnn/core/models.py +++ b/pysegcnn/core/models.py @@ -26,8 +26,8 @@ import torch.nn as nn import torch.optim as optim # locals -from pysegcnn.core.layers import (Encoder, Decoder, Conv2dPool, Conv2dUnpool, - Conv2dSame) +from pysegcnn.core.layers import (Encoder, Decoder, ConvBnReluMaxPool, + ConvBnReluMaxUnpool, Conv2dSame) # module level logger LOGGER = logging.getLogger(__name__) @@ -269,11 +269,11 @@ class UNet(Network): self.epoch = 0 # construct the encoder - self.encoder = Encoder(filters=self.filters, block=Conv2dPool, + self.encoder = Encoder(filters=self.filters, block=ConvBnReluMaxPool, **kwargs) # construct the decoder - self.decoder = Decoder(filters=self.filters, block=Conv2dUnpool, + self.decoder = Decoder(filters=self.filters, block=ConvBnReluMaxUnpool, skip=skip, **kwargs) # construct the classifier