Commit a2f71786 authored by Frisinghelli Daniel
Improved functionality of generic Network class

parent eac2851e
@@ -9,6 +9,7 @@ Created on Fri Jun 26 16:31:36 2020
 import os
 import enum
 import logging
+import pathlib
 
 # externals
 import numpy as np
@@ -29,6 +30,9 @@ class Network(nn.Module):
     def __init__(self):
         super().__init__()
 
+        # initialize state file
+        self.state_file = None
+
     def freeze(self):
         for param in self.parameters():
             param.requires_grad = False
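
Note: the new state_file attribute is initialized to None here and is only populated once a checkpoint is written or restored. For context, freeze and unfreeze simply toggle requires_grad on every parameter. A minimal usage sketch; the TinyNet subclass is a hypothetical stand-in, not part of this commit:

    import torch.nn as nn

    # hypothetical minimal subclass of the generic Network class
    class TinyNet(Network):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3)

    model = TinyNet()
    model.freeze()    # requires_grad=False: parameters excluded from training
    assert not any(p.requires_grad for p in model.parameters())
    model.unfreeze()  # requires_grad=True: all parameters trainable again
    assert all(p.requires_grad for p in model.parameters())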
@@ -37,22 +41,22 @@ class Network(nn.Module):
         for param in self.parameters():
             param.requires_grad = True
 
-    def save(self, state_file, optimizer, bands=None,
-             outpath=os.path.join(os.getcwd(), '_models/')):
+    def save(self, state_file, optimizer, bands=None, **kwargs):
 
         # check if the output path exists and if not, create it
-        if not os.path.isdir(outpath):
-            os.makedirs(outpath, exist_ok=True)
+        state_file = pathlib.Path(state_file)
+        if not state_file.parent.is_dir():
+            state_file.parent.mkdir(parents=True, exist_ok=True)
 
         # initialize dictionary to store network parameters
-        model_state = {}
-
-        # store model name
-        model_state['cls'] = self.__class__
+        model_state = {**kwargs}
 
-        # store the bands the model was trained with
+        # store the spectral bands the model is trained with
         model_state['bands'] = bands
 
+        # store model class
+        model_state['cls'] = self.__class__
+
         # store construction parameters to instantiate the network
         model_state['params'] = {
             'skip': self.skip,
@@ -62,7 +66,7 @@
         }
 
         # store optional keyword arguments
-        model_state['kwargs'] = self.kwargs
+        model_state['params'] = {**model_state['params'], **self.kwargs}
 
         # store model epoch
         model_state['epoch'] = self.epoch
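
With these two hunks, save resolves the checkpoint path itself (creating parent directories via pathlib), seeds the state dictionary from any extra keyword arguments, and merges the constructor kwargs into model_state['params'] so that load can rebuild the network from the checkpoint alone. A hedged usage sketch, assuming model is a Network subclass instance that defines the construction attributes (skip, ...) that save records; the path, band names, and the extra metadata key are illustrative only:

    # hypothetical call to the new save signature
    state = model.save('_models/tinynet.pt', optimizer,
                       bands=['B2', 'B3', 'B4'],  # spectral bands used for training
                       doy=True)                  # stored as model_state['doy']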
@@ -72,30 +76,48 @@
             model_state['optim_state_dict'] = optimizer.state_dict()
 
         # model state dictionary stores the values of all trainable parameters
-        state = os.path.join(outpath, state_file)
-        torch.save(model_state, state)
-        LOGGER.info('Network parameters saved in {}'.format(state))
+        torch.save(model_state, state_file)
+        LOGGER.info('Network parameters saved in {}'.format(state_file))
 
-        return state
+        return state_file
 
-    def load(self, state_file, optimizer=None,
-             inpath=os.path.join(os.getcwd(), '_models/')):
+    @staticmethod
+    def load(state_file, optimizer=None):
 
-        # load the model state file
-        state = os.path.join(inpath, state_file)
-        model_state = torch.load(state)
+        # load the pretrained model
+        state_file = pathlib.Path(state_file)
+        if not state_file.exists():
+            raise FileNotFoundError('{} does not exist.'.format(state_file))
+
+        LOGGER.info('Loading pretrained weights from: {}'.format(state_file))
 
-        # resume network parameters
-        LOGGER.info('Loading model parameters ...'.format(state))
-        self.load_state_dict(model_state['model_state_dict'])
-        self.epoch = model_state['epoch']
+        # load the model state
+        model_state = torch.load(state_file)
+
+        # the model class
+        model_class = model_state['cls']
+
+        # instantiate pretrained model architecture
+        model = model_class(**model_state['params'])
+
+        # store state file as instance attribute
+        model.state_file = state_file
+
+        # load pretrained model weights
+        LOGGER.info('Loading model parameters ...')
+        model.load_state_dict(model_state['model_state_dict'])
+        model.epoch = model_state['epoch']
 
         # resume optimizer parameters
         if optimizer is not None:
-            LOGGER.info('Loading optimizer parameters ...'.format(state))
+            LOGGER.info('Loading optimizer parameters ...')
             optimizer.load_state_dict(model_state['optim_state_dict'])
 
-        return state
+        LOGGER.info('Model epoch: {:d}'.format(model.epoch))
+
+        return model, optimizer, model_state
+
+    @property
+    def state(self):
+        return self.state_file
 
 
 class UNet(Network):
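
Because load is now a static method that recovers both the model class and its construction parameters from the checkpoint, callers no longer instantiate the architecture by hand before restoring weights. A minimal sketch, reusing the hypothetical checkpoint path from the save example above:

    import torch.optim as optim

    # first pass: rebuild the model from the checkpoint alone
    model, _, model_state = Network.load('_models/tinynet.pt')

    # second pass: also resume the optimizer state (hypothetical optimizer)
    optimizer = optim.Adam(model.parameters())
    model, optimizer, model_state = Network.load('_models/tinynet.pt', optimizer)

    print(model.state)  # the new state property returns the checkpoint path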