From bac6bc17571fae0b2cfd1b3a9b41219f687f8e7b Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 15 Jul 2020 17:14:01 +0200
Subject: [PATCH] Improved saving method to rebuild the model in case of
 checkpointing or transfer learning

---
 pytorch/models.py | 38 +++++++++++++++++++++++++++++++-------
 1 file changed, 31 insertions(+), 7 deletions(-)

diff --git a/pytorch/models.py b/pytorch/models.py
index a0eb58a..e1add80 100644
--- a/pytorch/models.py
+++ b/pytorch/models.py
@@ -35,20 +35,37 @@ class Network(nn.Module):
         for param in self.parameters():
             param.requires_grad = True
 
-    def save(self, optimizer, state_file,
+    def save(self, state_file, optimizer, bands,
              outpath=os.path.join(os.getcwd(), '_models')):
 
         # check if the output path exists and if not, create it
         if not os.path.isdir(outpath):
             os.makedirs(outpath, exist_ok=True)
 
-        # create a dictionary that stores the model state
-        model_state = {
-            'epoch': self.epoch,
-            'model_state_dict': self.state_dict(),
-            'optim_state_dict': optimizer.state_dict()
+        # initialize dictionary to store network parameters
+        model_state = {}
+
+        # store input bands
+        model_state['bands'] = bands
+
+        # store construction parameters to instantiate the network
+        model_state['params'] = {
+            'skip': self.skip,
+            'filters': self.nfilters,
+            'nclasses': self.nclasses,
+            'in_channels': self.in_channels
             }
 
+        # store optional keyword arguments
+        model_state['kwargs'] = self.kwargs
+
+        # store model epoch
+        model_state['epoch'] = self.epoch
+
+        # store model and optimizer state
+        model_state['model_state_dict'] = self.state_dict()
+        model_state['optim_state_dict'] = optimizer.state_dict()
+
         # model state dictionary stores the values of all trainable parameters
         state = os.path.join(outpath, state_file)
         torch.save(model_state, state)
@@ -87,9 +104,16 @@ class UNet(Network):
         # number of classes
         self.nclasses = nclasses
 
-        # get the configuration for the convolutional layers of the encoder
+        # configuration of the convolutional layers in the network
+        self.kwargs = kwargs
+        self.nfilters = filters
+
+        # convolutional layers of the encoder
        self.filters = np.hstack([np.array(in_channels), np.array(filters)])
 
+        # whether to apply skip connections
+        self.skip = skip
+
         # number of epochs trained
         self.epoch = 0
 
-- 
GitLab
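
For context, here is a minimal sketch of how a checkpoint written by the extended Network.save() could be consumed to rebuild the model for resuming training or for transfer learning. The load_pretrained() helper, the "from pytorch.models import UNet" import path, the keyword-only UNet constructor call and the Adam optimizer are assumptions made for illustration; they are not part of this patch.

import torch
import torch.optim as optim

# assumed import path, matching the file touched by this patch (pytorch/models.py)
from pytorch.models import UNet


def load_pretrained(state_file, lr=0.001):
    # hypothetical helper: rebuild a UNet from a checkpoint written by Network.save()

    # read the saved dictionary from disk
    model_state = torch.load(state_file, map_location='cpu')

    # rebuild the network from the stored construction parameters;
    # the keyword names mirror the 'params' dictionary saved above
    params = model_state['params']
    model = UNet(in_channels=params['in_channels'],
                 nclasses=params['nclasses'],
                 filters=params['filters'],
                 skip=params['skip'],
                 **model_state['kwargs'])

    # restore the trained weights and the last training epoch
    model.load_state_dict(model_state['model_state_dict'])
    model.epoch = model_state['epoch']

    # restore the optimizer state to resume training where it stopped
    # (Adam is an assumption; any torch.optim optimizer works the same way)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    optimizer.load_state_dict(model_state['optim_state_dict'])

    # the saved input bands are returned so the caller can verify that a new
    # dataset provides the same bands the model was trained on
    return model, optimizer, model_state['bands']

A counterpart like this would typically live next to save() in pytorch/models.py, so that the stored construction parameters and the state dictionaries stay in sync with the constructor signature.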