diff --git a/pysegcnn/core/models.py b/pysegcnn/core/models.py index f35fd2992a1d73e90160d28e823edffe3719f115..d37daca54a6ee9eb52f3c5b252d7fc67d0970eea 100644 --- a/pysegcnn/core/models.py +++ b/pysegcnn/core/models.py @@ -276,11 +276,12 @@ class Network(nn.Module): """ # get the model class of the pretrained model - model_class = item_in_enum(str(state_file).split('_')[0], + state_file = pathlib.Path(state_file) + model_class = item_in_enum(str(state_file.stem).split('_')[0], SupportedModels) # get the optimizer class of the pretrained model - optim_class = item_in_enum(str(state_file).split('_')[1], + optim_class = item_in_enum(str(state_file.stem).split('_')[1], SupportedOptimizers) # load the pretrained model configuration