diff --git a/pysegcnn/core/models.py b/pysegcnn/core/models.py index cbf195b14d845797a270f60dd5d318bf3e484f41..6d840fc67393d86b0bef7086c27561feb96aee8a 100644 --- a/pysegcnn/core/models.py +++ b/pysegcnn/core/models.py @@ -7,11 +7,13 @@ Created on Fri Jun 26 16:31:36 2020 """ # builtins import os +import enum # externals import numpy as np import torch import torch.nn as nn +import torch.optim as optim # locals from pysegcnn.core.layers import (Encoder, Decoder, Conv2dPool, Conv2dUnpool, @@ -139,3 +141,14 @@ class UNet(Network): # classification return self.classifier(x) + + +class SupportedModels(enum.Enum): + Unet = UNet + + +class SupportedOptimizers(enum.Enum): + Adam = optim.Adam + +class SupportedLossFunctions(enum.Enum): + CrossEntropy = nn.CrossEntropyLoss