Skip to content
Snippets Groups Projects
Commit bac6bc17 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Improved saving method to rebuild model in case of checkpoint or transfer learning

parent 7c1732db
No related branches found
No related tags found
No related merge requests found
......@@ -35,20 +35,37 @@ class Network(nn.Module):
for param in self.parameters():
param.requires_grad = True
def save(self, state_file, optimizer, bands,
         outpath=os.path.join(os.getcwd(), '_models')):
    """Save the model state to disk.

    Stores everything required to rebuild the network when resuming
    training from a checkpoint or when transfer learning: the
    construction parameters, the input bands, the current epoch and
    the state dictionaries of both the model and the optimizer.

    Parameters
    ----------
    state_file : str
        Name of the file the model state is written to.
    optimizer : torch.optim.Optimizer
        The optimizer used to train the model.
    bands : sequence
        The input bands the model was trained on.
    outpath : str, optional
        Directory the state file is written to; created if it does
        not exist. Defaults to ``<cwd>/_models``.
    """
    # check if the output path exists and if not, create it
    if not os.path.isdir(outpath):
        os.makedirs(outpath, exist_ok=True)

    # initialize dictionary to store network parameters
    model_state = {}

    # store input bands
    model_state['bands'] = bands

    # store construction parameters to instantiate the network
    # NOTE(review): assumes self.skip, self.nfilters, self.nclasses and
    # self.in_channels are set by the subclass constructor (see UNet)
    model_state['params'] = {
        'skip': self.skip,
        'filters': self.nfilters,
        'nclasses': self.nclasses,
        'in_channels': self.in_channels
        }

    # store optional keyword arguments
    model_state['kwargs'] = self.kwargs

    # store model epoch
    model_state['epoch'] = self.epoch

    # store model and optimizer state
    model_state['model_state_dict'] = self.state_dict()
    model_state['optim_state_dict'] = optimizer.state_dict()

    # model state dictionary stores the values of all trainable parameters
    state = os.path.join(outpath, state_file)
    torch.save(model_state, state)
......@@ -87,9 +104,16 @@ class UNet(Network):
# number of classes
self.nclasses = nclasses
# get the configuration for the convolutional layers of the encoder
# configuration of the convolutional layers in the network
self.kwargs = kwargs
self.nfilters = filters
# convolutional layers of the encoder
self.filters = np.hstack([np.array(in_channels), np.array(filters)])
# whether to apply skip connections
self.skip = skip
# number of epochs trained
self.epoch = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment