Skip to content
Snippets Groups Projects
Commit a2f71786 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Improved functionality of generic Network class

parent eac2851e
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,7 @@ Created on Fri Jun 26 16:31:36 2020 ...@@ -9,6 +9,7 @@ Created on Fri Jun 26 16:31:36 2020
import os import os
import enum import enum
import logging import logging
import pathlib
# externals # externals
import numpy as np import numpy as np
...@@ -29,6 +30,9 @@ class Network(nn.Module): ...@@ -29,6 +30,9 @@ class Network(nn.Module):
def __init__(self):
    """Initialize the generic network base class.

    Runs the :class:`torch.nn.Module` initializer and creates the
    ``state_file`` attribute, which holds the path of the model state
    file once the model is saved or loaded (``None`` until then).
    """
    super().__init__()

    # path of the model state file; populated by save()/load()
    self.state_file = None
def freeze(self):
    """Freeze the network.

    Disables gradient computation for every parameter, excluding the
    whole network from training until :py:meth:`unfreeze` is called.
    """
    for weight in self.parameters():
        weight.requires_grad = False
...@@ -37,22 +41,22 @@ class Network(nn.Module): ...@@ -37,22 +41,22 @@ class Network(nn.Module):
for param in self.parameters(): for param in self.parameters():
param.requires_grad = True param.requires_grad = True
def save(self, state_file, optimizer, bands=None, def save(self, state_file, optimizer, bands=None, **kwargs):
outpath=os.path.join(os.getcwd(), '_models/')):
# check if the output path exists and if not, create it # check if the output path exists and if not, create it
if not os.path.isdir(outpath): state_file = pathlib.Path(state_file)
os.makedirs(outpath, exist_ok=True) if not state_file.parent.is_dir():
state_file.parent.mkdir(parents=True, exist_ok=True)
# initialize dictionary to store network parameters # initialize dictionary to store network parameters
model_state = {} model_state = {**kwargs}
# store model name
model_state['cls'] = self.__class__
# store the bands the model was trained with # store the spectral bands the model is trained with
model_state['bands'] = bands model_state['bands'] = bands
# store model class
model_state['cls'] = self.__class__
# store construction parameters to instanciate the network # store construction parameters to instanciate the network
model_state['params'] = { model_state['params'] = {
'skip': self.skip, 'skip': self.skip,
...@@ -62,7 +66,7 @@ class Network(nn.Module): ...@@ -62,7 +66,7 @@ class Network(nn.Module):
} }
# store optional keyword arguments # store optional keyword arguments
model_state['kwargs'] = self.kwargs model_state['params'] = {**model_state['params'], **self.kwargs}
# store model epoch # store model epoch
model_state['epoch'] = self.epoch model_state['epoch'] = self.epoch
...@@ -72,30 +76,48 @@ class Network(nn.Module): ...@@ -72,30 +76,48 @@ class Network(nn.Module):
model_state['optim_state_dict'] = optimizer.state_dict() model_state['optim_state_dict'] = optimizer.state_dict()
# model state dictionary stores the values of all trainable parameters # model state dictionary stores the values of all trainable parameters
state = os.path.join(outpath, state_file) torch.save(model_state, state_file)
torch.save(model_state, state) LOGGER.info('Network parameters saved in {}'.format(state_file))
LOGGER.info('Network parameters saved in {}'.format(state))
return state return state_file
@staticmethod
def load(state_file, optimizer=None):
    """Load a pretrained model from a state file.

    Re-instantiates the model class stored in the state file with its
    saved construction parameters, restores the trained weights and,
    optionally, the optimizer state.

    Parameters
    ----------
    state_file : str or pathlib.Path
        Path to the model state file written by ``save``.
    optimizer : torch.optim.Optimizer or None
        If given, its state is restored from the state file as well.

    Returns
    -------
    tuple
        ``(model, optimizer, model_state)`` — the instantiated model,
        the (possibly updated) optimizer, and the raw state dictionary.

    Raises
    ------
    FileNotFoundError
        If ``state_file`` does not exist.
    """
    # normalize to a pathlib.Path and fail early on a missing file
    state_file = pathlib.Path(state_file)
    if not state_file.exists():
        raise FileNotFoundError('{} does not exist.'.format(state_file))
    LOGGER.info('Loading pretrained weights from: {}'.format(state_file))

    # NOTE(review): torch.load deserializes via pickle — only load
    # state files from trusted sources.
    model_state = torch.load(state_file)

    # instantiate the pretrained model architecture from the stored
    # class and construction parameters
    model_class = model_state['cls']
    model = model_class(**model_state['params'])

    # remember where this model was loaded from
    model.state_file = state_file

    # restore the trained weights and the epoch counter
    LOGGER.info('Loading model parameters ...')
    model.load_state_dict(model_state['model_state_dict'])
    model.epoch = model_state['epoch']

    # optionally resume the optimizer state
    if optimizer is not None:
        LOGGER.info('Loading optimizer parameters ...')
        optimizer.load_state_dict(model_state['optim_state_dict'])
    LOGGER.info('Model epoch: {:d}'.format(model.epoch))

    return model, optimizer, model_state
@property
def state(self):
    """pathlib.Path or None: path of the model state file, if any."""
    return self.state_file
class UNet(Network): class UNet(Network):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment