diff --git a/pysegcnn/core/models.py b/pysegcnn/core/models.py index aa55c9b1ff186db9192055ff75643405d781fda3..e0145003872a901b16fbc2a5217dfb7fff892860 100644 --- a/pysegcnn/core/models.py +++ b/pysegcnn/core/models.py @@ -47,12 +47,19 @@ class Network(nn.Module): """ - def __init__(self): - """Initialize.""" + def __init__(self, state_file=None): + """Initialize. + + Parameters + ---------- + state_file : `str` or `None` or :py:class:`pathlib.Path` + The model state file, where the model parameters are saved. + + """ super().__init__() # initialize state file - self.state_file = None + self.state_file = state_file # number of epochs trained self.epoch = 0 @@ -103,7 +110,7 @@ class Network(nn.Module): for param in getattr(self, str(name)).parameters(): param.requires_grad = True - def save(self, state_file, optimizer, bands=None, **kwargs): + def save(self, state_file, optimizer, **kwargs): """Save the model state. Saves the model and optimizer states together with the model @@ -117,8 +124,6 @@ class Network(nn.Module): Path to save the model state. optimizer : :py:class:`torch.optim.Optimizer` The optimizer used to train the model. - bands : `list` [`str`] or `None`, optional - List of bands the model is trained with. The default is None. **kwargs Arbitrary keyword arguments. Each keyword argument will be saved as (key, value) pair in ``state_file``. @@ -130,7 +135,7 @@ class Network(nn.Module): """ # check if the output path exists and if not, create it - state_file = check_filename_length(state_file) + state_file = pathlib.Path(check_filename_length(state_file)) if not state_file.parent.is_dir(): state_file.parent.mkdir(parents=True, exist_ok=True) @@ -138,25 +143,25 @@ class Network(nn.Module): model_state = {**kwargs} # store the spectral bands the model is trained with - model_state['bands'] = bands + # model_state['bands'] = bands # store model and optimizer class - model_state['cls'] = self.__class__ - model_state['optim_cls'] = optimizer.__class__ + # model_state['cls'] = self.__class__ + # model_state['optim_cls'] = optimizer.__class__ # store construction parameters to instanciate the network - model_state['params'] = { - 'skip': self.skip, - 'filters': self.filters[1:], - 'nclasses': self.nclasses, - 'in_channels': self.in_channels - } + # model_state['params'] = { + # 'skip': self.skip, + # 'filters': self.filters[1:], + # 'nclasses': self.nclasses, + # 'in_channels': self.in_channels + # } # store optimizer construction parameters - model_state['optim_params'] = optimizer.defaults + # model_state['optim_params'] = optimizer.defaults # store optional keyword arguments - model_state['params'] = {**model_state['params'], **self.kwargs} + # model_state['params'] = {**model_state['params'], **self.kwargs} # store model epoch model_state['epoch'] = self.epoch @@ -172,7 +177,7 @@ class Network(nn.Module): return model_state @staticmethod - def load(state_file): + def load(model, optimizer, state_file): """Load a model state. Returns the model in ``state_file`` with the pretrained model and @@ -180,9 +185,14 @@ class Network(nn.Module): Parameters ---------- + model : :py:class:`pysegcnn.core.models.Network` + An instance of the model for which the pretrained weights are + stored in ``state_file``. + optimizer : :py:class:`torch.optim.Optimizer` + An instance of the optimizer used to train ``model``. state_file : `str` or :py:class:`pathlib.Path` - The model state file. Model state files are stored in - pysegcnn/main/_models. + The model state file containing the pretrained parameters for + ``model`` and ``optimizer``. Raises ------ @@ -191,17 +201,13 @@ class Network(nn.Module): Returns ------- - model : :py:class:`pysegcnn.core.models.Network` - The pretrained model. - optimizer : :py:class:`torch.optim.Optimizer` - The optimizer used to train the model. model_state : `dict` A dictionary containing the model and optimizer state, as constructed by :py:meth:`~pysegcnn.core.Network.save`. """ # load the pretrained model - state_file = check_filename_length(pathlib.Path(state_file)) + state_file = pathlib.Path(check_filename_length(state_file)) if not state_file.exists(): raise FileNotFoundError('{} does not exist.'.format(state_file)) LOGGER.info('Loading pretrained weights from: {}'.format(state_file)) @@ -210,11 +216,11 @@ class Network(nn.Module): model_state = torch.load(state_file) # the model and optimizer class - model_class = model_state['cls'] - optim_class = model_state['optim_cls'] + # model_class = model_state['cls'] + # optim_class = model_state['optim_cls'] # instanciate pretrained model architecture - model = model_class(**model_state['params']) + # model = model_class(**model_state['params']) # store state file as instance attribute model.state_file = state_file @@ -226,13 +232,11 @@ class Network(nn.Module): # resume optimizer parameters LOGGER.info('Loading optimizer parameters ...') - optimizer = optim_class(model.parameters(), - **model_state['optim_params']) optimizer.load_state_dict(model_state['optim_state_dict']) LOGGER.info('Model epoch: {:d}'.format(model.epoch)) - return model, optimizer, model_state + return model_state @property def state(self):