From 463e73d3d32f4aea3de9014d5a7968bdc795cbbb Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Mon, 24 Aug 2020 10:49:58 +0200 Subject: [PATCH] Network.load() now instanciates the optimizer saved by Network.save(). --- pysegcnn/core/models.py | 33 ++++++++++++++++++++------------- pysegcnn/core/trainer.py | 2 +- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/pysegcnn/core/models.py b/pysegcnn/core/models.py index c04fa4c..39ba1a1 100644 --- a/pysegcnn/core/models.py +++ b/pysegcnn/core/models.py @@ -116,8 +116,9 @@ class Network(nn.Module): # store the spectral bands the model is trained with model_state['bands'] = bands - # store model class + # store model and optimizer class model_state['cls'] = self.__class__ + model_state['optim_cls'] = optimizer.__class__ # store construction parameters to instanciate the network model_state['params'] = { @@ -127,6 +128,9 @@ class Network(nn.Module): 'in_channels': self.in_channels } + # store optimizer construction parameters + model_state['optim_params'] = optimizer.defaults + # store optional keyword arguments model_state['params'] = {**model_state['params'], **self.kwargs} @@ -144,21 +148,17 @@ class Network(nn.Module): return model_state @staticmethod - def load(state_file, optimizer=None): + def load(state_file): """Load a model state. - Returns the model in ``state_file`` with the pretrained model weights. - If ``optimizer`` is specified, the optimizer parameters are also loaded - from ``state_file``. This is useful when resuming training an existing - model. + Returns the model in ``state_file`` with the pretrained model and + optimizer weights. Useful when resuming training an existing model. Parameters ---------- state_file : `str` or `pathlib.Path` The model state file. Model state files are stored in pysegcnn/main/_models. - optimizer : `torch.optim.Optimizer` or `None`, optional - The optimizer used to train the model. Raises ------ @@ -169,7 +169,7 @@ class Network(nn.Module): ------- model : `pysegcnn.core.models.Network` The pretrained model. - optimizer : `torch.optim.Optimizer` or `None` + optimizer : `torch.optim.Optimizer` The optimizer used to train the model. model_state : '`dict` A dictionary containing the model and optimizer state, as @@ -185,8 +185,9 @@ class Network(nn.Module): # load the model state model_state = torch.load(state_file) - # the model class + # the model and optimizer class model_class = model_state['cls'] + optim_class = model_state['optim_cls'] # instanciate pretrained model architecture model = model_class(**model_state['params']) @@ -200,9 +201,11 @@ class Network(nn.Module): model.epoch = model_state['epoch'] # resume optimizer parameters - if optimizer is not None: - LOGGER.info('Loading optimizer parameters ...') - optimizer.load_state_dict(model_state['optim_state_dict']) + LOGGER.info('Loading optimizer parameters ...') + optimizer = optim_class(model.parameters(), + **model_state['optim_params']) + optimizer.load_state_dict(model_state['optim_state_dict']) + LOGGER.info('Model epoch: {:d}'.format(model.epoch)) return model, optimizer, model_state @@ -223,6 +226,10 @@ class Network(nn.Module): class UNet(Network): """A PyTorch implementation of `U-Net`_. + Slightly modified version of U-Net: + - each convolution is followed by a batch normalization layer + - the upsampling is implemented by a 2x2 max unpooling operation + .. _U-Net: https://arxiv.org/abs/1505.04597 diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 92014fb..bc9714f 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -680,7 +680,7 @@ class ModelConfig(BaseConfig): .format(state_file.name)) else: # load model checkpoint - model, optimizer, model_state = Network.load(state_file, optimizer) + model, optimizer, model_state = Network.load(state_file) # load model loss and accuracy -- GitLab