From a2f717864feff6608ef51db81b2258134146e009 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Mon, 17 Aug 2020 17:23:48 +0200
Subject: [PATCH] Improved functionality of generic Network class

---
 pysegcnn/core/models.py | 72 +++++++++++++++++++++++++++--------------
 1 file changed, 47 insertions(+), 25 deletions(-)

diff --git a/pysegcnn/core/models.py b/pysegcnn/core/models.py
index 2d5cc84..1f9fa2b 100644
--- a/pysegcnn/core/models.py
+++ b/pysegcnn/core/models.py
@@ -9,6 +9,7 @@ Created on Fri Jun 26 16:31:36 2020
 import os
 import enum
 import logging
+import pathlib
 
 # externals
 import numpy as np
@@ -29,6 +30,9 @@ class Network(nn.Module):
     def __init__(self):
         super().__init__()
 
+        # initialize state file
+        self.state_file = None
+
     def freeze(self):
         for param in self.parameters():
             param.requires_grad = False
@@ -37,22 +41,22 @@ class Network(nn.Module):
         for param in self.parameters():
             param.requires_grad = True
 
-    def save(self, state_file, optimizer, bands=None,
-             outpath=os.path.join(os.getcwd(), '_models/')):
+    def save(self, state_file, optimizer, bands=None, **kwargs):
 
         # check if the output path exists and if not, create it
-        if not os.path.isdir(outpath):
-            os.makedirs(outpath, exist_ok=True)
+        state_file = pathlib.Path(state_file)
+        if not state_file.parent.is_dir():
+            state_file.parent.mkdir(parents=True, exist_ok=True)
 
         # initialize dictionary to store network parameters
-        model_state = {}
-
-        # store model name
-        model_state['cls'] = self.__class__
+        model_state = {**kwargs}
 
-        # store the bands the model was trained with
+        # store the spectral bands the model is trained with
         model_state['bands'] = bands
 
+        # store model class
+        model_state['cls'] = self.__class__
+
         # store construction parameters to instantiate the network
         model_state['params'] = {
             'skip': self.skip,
@@ -62,7 +66,7 @@
         }
 
         # store optional keyword arguments
-        model_state['kwargs'] = self.kwargs
+        model_state['params'] = {**model_state['params'], **self.kwargs}
 
         # store model epoch
         model_state['epoch'] = self.epoch
@@ -72,30 +76,48 @@
         model_state['optim_state_dict'] = optimizer.state_dict()
 
         # model state dictionary stores the values of all trainable parameters
-        state = os.path.join(outpath, state_file)
-        torch.save(model_state, state)
-        LOGGER.info('Network parameters saved in {}'.format(state))
+        torch.save(model_state, state_file)
+        LOGGER.info('Network parameters saved in {}'.format(state_file))
 
-        return state
+        return state_file
 
-    def load(self, state_file, optimizer=None,
-             inpath=os.path.join(os.getcwd(), '_models/')):
+    @staticmethod
+    def load(state_file, optimizer=None):
 
-        # load the model state file
-        state = os.path.join(inpath, state_file)
-        model_state = torch.load(state)
+        # load the pretrained model
+        state_file = pathlib.Path(state_file)
+        if not state_file.exists():
+            raise FileNotFoundError('{} does not exist.'.format(state_file))
+        LOGGER.info('Loading pretrained weights from: {}'.format(state_file))
 
-        # resume network parameters
-        LOGGER.info('Loading model parameters ...'.format(state))
-        self.load_state_dict(model_state['model_state_dict'])
-        self.epoch = model_state['epoch']
+        # load the model state
+        model_state = torch.load(state_file)
+
+        # the model class
+        model_class = model_state['cls']
+
+        # instantiate pretrained model architecture
+        model = model_class(**model_state['params'])
+
+        # store state file as instance attribute
+        model.state_file = state_file
+
+        # load pretrained model weights
+        LOGGER.info('Loading model parameters ...')
+        model.load_state_dict(model_state['model_state_dict'])
+        model.epoch = model_state['epoch']
 
         # resume optimizer parameters
         if optimizer is not None:
-            LOGGER.info('Loading optimizer parameters ...'.format(state))
+            LOGGER.info('Loading optimizer parameters ...')
             optimizer.load_state_dict(model_state['optim_state_dict'])
 
+        LOGGER.info('Model epoch: {:d}'.format(model.epoch))
+
+        return model, optimizer, model_state
 
-        return state
+    @property
+    def state(self):
+        return self.state_file
 
 
 class UNet(Network):
-- 
GitLab
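
Below is a minimal usage sketch of the revised save/load API introduced by this patch. The UNet constructor arguments, the state file path, and the band names are hypothetical placeholders, not part of the patch; substitute the parameters of your own configuration.

# Usage sketch for the revised Network.save/Network.load API.
# The UNet keyword arguments, path and band names below are assumed
# for illustration only.
import torch.optim as optim

from pysegcnn.core.models import UNet

# build a model and an optimizer; the constructor arguments are
# assumptions about the UNet signature
model = UNet(in_channels=4, nclasses=2, filters=[32, 64, 128], skip=True)
model.epoch = 0  # save() stores self.epoch, so make sure it is set
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# save() now accepts a full path, creates missing parent directories
# and returns the state file as a pathlib.Path
state_file = model.save('_models/unet.pt', optimizer,
                        bands=['red', 'green', 'blue', 'nir'])

# load() is now a static method: it rebuilds the network from the
# class and construction parameters stored in the state file, so no
# pre-instantiated model is required
model, optimizer, model_state = UNet.load(state_file, optimizer)

# the new state property returns the file the model was loaded from
assert model.state == state_file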