Skip to content
Snippets Groups Projects
Commit 92cd6ba6 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Number of input channels and output classes are saved to state file;...

Number of input channels and output classes are saved to state file; configurable loss function support dropped.
parent 49be0399
No related branches found
No related tags found
No related merge requests found
...@@ -158,6 +158,10 @@ class Network(nn.Module): ...@@ -158,6 +158,10 @@ class Network(nn.Module):
# store model epoch # store model epoch
model_state['epoch'] = self.epoch model_state['epoch'] = self.epoch
# store model construction parameters
model_state['in_channels'] = self.in_channels
model_state['nclasses'] = self.nclasses
# store model and optimizer state # store model and optimizer state
model_state['model_state_dict'] = self.state_dict() model_state['model_state_dict'] = self.state_dict()
model_state['optim_state_dict'] = optimizer.state_dict() model_state['optim_state_dict'] = optimizer.state_dict()
...@@ -284,7 +288,7 @@ class Network(nn.Module): ...@@ -284,7 +288,7 @@ class Network(nn.Module):
# instanciate the pretrained model architecture # instanciate the pretrained model architecture
model = model_class(state_file=state_file, model = model_class(state_file=state_file,
in_channels=len(model_state['bands']), in_channels=model_state['in_channels'],
nclasses=model_state['nclasses']) nclasses=model_state['nclasses'])
# instanciate the optimizer # instanciate the optimizer
...@@ -487,9 +491,3 @@ class SupportedOptimizers(enum.Enum): ...@@ -487,9 +491,3 @@ class SupportedOptimizers(enum.Enum):
Adam = optim.Adam Adam = optim.Adam
AdamW = optim.AdamW AdamW = optim.AdamW
class SupportedLossFunctions(enum.Enum):
"""Names and corresponding classes of the tested loss functions."""
CrossEntropy = nn.CrossEntropyLoss
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment