Commit f7112679 authored by Frisinghelli Daniel

Major refactor: Increased modularity

parent a2f71786
@@ -8,6 +8,7 @@ Created on Wed Aug 12 10:24:34 2020
import dataclasses
import pathlib
import logging
import datetime
# externals
import numpy as np
@@ -199,7 +200,6 @@ class ModelConfig(BaseConfig):
skip_connection: bool = True
kwargs: dict = dataclasses.field(
default_factory=lambda: {'kernel_size': 3, 'stride': 1, 'dilation': 1})
state_path: pathlib.Path = pathlib.Path(HERE).joinpath('_models/')
batch_size: int = 64
checkpoint: bool = False
transfer: bool = False
@@ -226,11 +226,16 @@ class ModelConfig(BaseConfig):
# check whether the loss function is currently supported
self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions)
# path to model states
self.state_path = pathlib.Path(HERE).joinpath('_models/')
# path to pretrained model
self.pretrained_path = self.state_path.joinpath(self.pretrained_model)
def init_optimizer(self, model):
LOGGER.info('Optimizer: {}.'.format(repr(self.optim_class)))
# initialize the optimizer for the specified model
optimizer = self.optim_class(model.parameters(), self.lr)
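As a side note, the optimizer construction above only needs a parameter iterator and a learning rate. A minimal, self-contained sketch of the same call pattern; the torch.optim.Adam choice and the toy model are assumptions for illustration, not part of this commit:

import torch.nn as nn
import torch.optim as optim

toy_model = nn.Linear(4, 2)                             # stand-in for the configured model
optimizer = optim.Adam(toy_model.parameters(), 0.001)   # mirrors self.optim_class(model.parameters(), self.lr)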
@@ -238,17 +243,28 @@ class ModelConfig(BaseConfig):
def init_loss_function(self):
LOGGER.info('Loss function: {}.'.format(repr(self.loss_class)))
# instantiate the loss function
loss_function = self.loss_class()
return loss_function
def init_model(self, ds):
def init_model(self, ds, state_file):
# write an initialization string to the log file
# now = datetime.datetime.strftime(datetime.datetime.now(),
# '%Y-%m-%dT%H:%M:%S')
# LOGGER.info(80 * '-')
# LOGGER.info('{}: Initializing model run. '.format(now) + 35 * '-')
# LOGGER.info(80 * '-')
# case (1): build a new model
if not self.transfer:
# set the random seed for reproducibility
torch.manual_seed(self.torch_seed)
LOGGER.info('Initializing model: {}'.format(state_file.name))
# instantiate the model
model = self.model_class(
@@ -261,104 +277,86 @@ class ModelConfig(BaseConfig):
# case (2): load a pretrained model for transfer learning
else:
# load pretrained model
model, _ = self.load_pretrained(self.pretrained_path, new_ds=ds)
LOGGER.info('Loading pretrained model for transfer learning from: '
'{}'.format(self.pretrained_path))
model = self.transfer_model(self.pretrained_path, ds)
return model
def from_checkpoint(self, model, optimizer, state_file, loss_state):
# initialize the optimizer
optimizer = self.init_optimizer(model)
# whether to resume training from an existing model checkpoint
checkpoint_state = {}
max_accuracy = 0
if self.checkpoint:
model, optimizer, checkpoint_state = self.load_checkpoint(
model, optimizer, state_file)
# check whether the checkpoint exists
if state_file.exists() and loss_state.exists():
# load model checkpoint
model, optimizer = self.load_pretrained(state_file, optimizer,
new_ds=None)
(checkpoint_state, max_accuracy) = self.load_checkpoint(
loss_state)
else:
LOGGER.info('Checkpoint for model {} does not exist. '
'Initializing new model.'.format(state_file.name))
return model, optimizer, checkpoint_state, max_accuracy
return model, optimizer, checkpoint_state
@staticmethod
def load_pretrained(state_file, optimizer=None, new_ds=None):
# load the pretrained model
if not state_file.exists():
raise FileNotFoundError('Pretrained model {} does not exist.'
.format(state_file))
LOGGER.info('Loading pretrained model: {}'.format(state_file.name))
# load the model state
model_state = torch.load(state_file)
def load_checkpoint(model, optimizer, state_file):
# the model class
model_class = model_state['cls']
# instantiate the pretrained model architecture
model = model_class(**model_state['params'], **model_state['kwargs'])
# load pretrained model weights
_ = model.load(state_file.name, optimizer=optimizer,
inpath=str(state_file.parent))
LOGGER.info('Model epoch: {:d}'.format(model.epoch))
# whether to resume training from an existing model checkpoint
checkpoint_state = {}
# check whether to apply pretrained model on a new dataset
if new_ds is not None:
LOGGER.info('Configuring model for new dataset: {}.'
.format(new_ds.__class__.__name__))
# if no checkpoint exists, issue a warning and continue with a model
# initialized from scratch
if not state_file.exists():
LOGGER.warning('Checkpoint for model {} does not exist. '
'Initializing new model.'
.format(state_file.name))
else:
# load model checkpoint
model, optimizer, model_state = Network.load(state_file, optimizer)
# the bands the model was trained with
bands = model_state['bands']
# load model loss and accuracy
# check whether the current dataset uses the correct spectral bands
if new_ds.use_bands != bands:
raise ValueError('The pretrained network was trained with the '
'bands {}, not with: {}'
.format(bands, new_ds.use_bands))
# get all non-zero elements, i.e. the number of epochs trained
# before the early stop
checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
for k, v in model_state['state'].items()}
# get the number of convolutional filters
filters = model_state['params']['filters']
return model, optimizer, checkpoint_state
# reset model epoch to 0, since the model is trained on a different
# dataset
model.epoch = 0
@staticmethod
def transfer_model(state_file, ds):
# adjust the number of classes in the model
model.nclasses = len(new_ds.labels)
LOGGER.info('Replacing classification layer with classes: {}.'
.format(', '.join('({}, {})'.format(k, v['label'])
for k, v in new_ds.labels.items())))
# check input type
if not isinstance(ds, ImageDataset):
raise TypeError('Expected "ds" to be {}.'
.format('.'.join([ImageDataset.__module__,
ImageDataset.__name__])))
# adjust the classification layer to the number of classes of the
# current dataset
model.classifier = Conv2dSame(in_channels=filters[0],
out_channels=model.nclasses,
kernel_size=1)
# load the pretrained model
model, _, model_state = Network.load(state_file)
LOGGER.info('Configuring model for new dataset: {}.'.format(
ds.__class__.__name__))
return model, optimizer
# check whether the current dataset uses the correct spectral bands
if new_ds.use_bands != model_state['bands']:
raise ValueError('The pretrained network was trained with '
'bands {}, not with bands {}.'
.format(model_state['bands'], new_ds.use_bands))
@staticmethod
def load_checkpoint(loss_state):
# get the number of convolutional filters
filters = model_state['params']['filters']
# load the model loss and accuracy
checkpoint_state = torch.load(loss_state)
# reset model epoch to 0, since the model is trained on a different
# dataset
model.epoch = 0
# get all non-zero elements, i.e. the number of epochs trained
# before the early stop
checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
for k, v in checkpoint_state.items()}
# adjust the number of classes in the model
model.nclasses = len(ds.labels)
LOGGER.info('Replacing classification layer with classes: {}.'
.format(', '.join('({}, {})'.format(k, v['label'])
for k, v in ds.labels.items())))
# maximum accuracy on the validation set
max_accuracy = checkpoint_state['va'][:, -1].mean().item()
# adjust the classification layer to the number of classes of the
# current dataset
model.classifier = Conv2dSame(in_channels=filters[0],
out_channels=model.nclasses,
kernel_size=1)
return checkpoint_state, max_accuracy
return model
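The non-zero filtering in load_checkpoint above is what discards the epochs never reached before an early stop. A minimal sketch with made-up toy values (not part of the commit) of what that comprehension does to one tracked array:

import numpy as np

# pretend 3 mini-batches per epoch, 4 planned epochs, training stopped after 2
va = np.array([[0.7, 0.8, 0.0, 0.0],
               [0.6, 0.9, 0.0, 0.0],
               [0.8, 0.7, 0.0, 0.0]])
trimmed = va[np.nonzero(va)].reshape(va.shape[0], -1)
# trimmed.shape == (3, 2): only the epochs that were actually trained remain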
@dataclasses.dataclass
@@ -420,14 +418,12 @@ class StateConfig(BaseConfig):
# path to model state
state = self.mc.state_path.joinpath(state_file)
# path to model loss/accuracy
loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt'))
return state, loss_state
return state
@dataclasses.dataclass
class EvalConfig(BaseConfig):
state_file: pathlib.Path
test: object
predict_scene: bool = False
plot_samples: bool = False
@@ -446,6 +442,32 @@ class EvalConfig(BaseConfig):
raise TypeError('Expected "test" to be None, True or False, got '
'{}.'.format(self.test))
# the output paths for the different graphics
self.base_path = pathlib.Path(HERE)
self.sample_path = self.base_path.joinpath('_samples')
self.scenes_path = self.base_path.joinpath('_scenes')
self.models_path = self.base_path.joinpath('_graphics')
# write initialization string to log file
# LOGGER.info(80 * '-')
# LOGGER.info('{}')
# LOGGER.info(80 * '-')
@dataclasses.dataclass
class LogConfig(BaseConfig):
state_file: pathlib.Path
def __post_init__(self):
super().__post_init__()
# the path to store model logs
self.log_path = pathlib.Path(HERE).joinpath('_logs')
# the log file of the current model
self.log_file = self.log_path.joinpath(
self.state_file.name.replace('.pt', '.log'))
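For reference, the log file name is derived purely from the model state file name. A small illustration of the same joinpath/replace pattern; the file name below is hypothetical:

import pathlib

state_file = pathlib.Path('_models/example_model.pt')   # made-up state file
log_file = pathlib.Path('_logs').joinpath(state_file.name.replace('.pt', '.log'))
# log_file -> _logs/example_model.log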
@dataclasses.dataclass
class NetworkTrainer(BaseConfig):
@@ -454,15 +476,14 @@ class NetworkTrainer(BaseConfig):
loss_function: nn.Module
train_dl: DataLoader
valid_dl: DataLoader
test_dl: DataLoader
state_file: pathlib.Path
loss_state: pathlib.Path
epochs: int = 1
nthreads: int = torch.get_num_threads()
early_stop: bool = False
mode: str = 'max'
delta: float = 0
patience: int = 10
max_accuracy: float = 0
checkpoint_state: dict = dataclasses.field(default_factory=dict)
save: bool = True
@@ -473,16 +494,21 @@ class NetworkTrainer(BaseConfig):
self.device = torch.device("cuda:0" if torch.cuda.is_available()
else "cpu")
# maximum accuracy on the validation dataset
self.max_accuracy = 0
if self.checkpoint_state:
self.max_accuracy = self.checkpoint_state['va'].mean(
axis=0).max().item()
# whether to use early stopping
self.es = None
if self.early_stop:
self.es = EarlyStopping(self.mode, self.max_accuracy, self.delta,
self.patience)
def train(self):
LOGGER.info(30 * '-' + ' Training ' + 30 * '-')
LOGGER.info(35 * '-' + ' Training ' + 35 * '-')
# set the number of threads
LOGGER.info('Device: {}'.format(self.device))
@@ -493,11 +519,11 @@ class NetworkTrainer(BaseConfig):
# training and validation dataset
tshape = (len(self.train_dl), self.epochs)
vshape = (len(self.valid_dl), self.epochs)
training_state = {'tl': np.zeros(shape=tshape),
'ta': np.zeros(shape=tshape),
'vl': np.zeros(shape=vshape),
'va': np.zeros(shape=vshape)
}
self.training_state = {'tl': np.zeros(shape=tshape),
'ta': np.zeros(shape=tshape),
'vl': np.zeros(shape=vshape),
'va': np.zeros(shape=vshape)
}
# send the model to the gpu if available
self.model = self.model.to(self.device)
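The shapes chosen above mean each tracked quantity is stored per (mini-batch, epoch); from the later assignments, 'tl'/'ta' appear to hold training loss/accuracy and 'vl'/'va' their validation counterparts. A toy illustration with made-up sizes:

import numpy as np

n_train_batches, n_valid_batches, epochs = 8, 2, 3       # illustrative sizes only
training_state = {'tl': np.zeros((n_train_batches, epochs)),
                  'ta': np.zeros((n_train_batches, epochs)),
                  'vl': np.zeros((n_valid_batches, epochs)),
                  'va': np.zeros((n_valid_batches, epochs))}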
@@ -525,7 +551,7 @@ class NetworkTrainer(BaseConfig):
# compute loss
loss = self.loss_function(outputs, labels.long())
observed_loss = loss.detach().numpy().item()
training_state['tl'][batch, epoch] = observed_loss
self.training_state['tl'][batch, epoch] = observed_loss
# compute the gradients of the loss function w.r.t.
# the network weights
@@ -539,7 +565,7 @@ class NetworkTrainer(BaseConfig):
# calculate accuracy on current batch
observed_accuracy = accuracy_function(ypred, labels)
training_state['ta'][batch, epoch] = observed_accuracy
self.training_state['ta'][batch, epoch] = observed_accuracy
# print progress
LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
@@ -562,8 +588,8 @@ class NetworkTrainer(BaseConfig):
vacc, vloss = self.predict()
# append observed accuracy and loss to arrays
training_state['va'][:, epoch] = vacc.squeeze()
training_state['vl'][:, epoch] = vloss.squeeze()
self.training_state['va'][:, epoch] = vacc.squeeze()
self.training_state['vl'][:, epoch] = vloss.squeeze()
# metric to assess model performance on the validation set
epoch_acc = vacc.squeeze().mean()
@@ -574,7 +600,7 @@ class NetworkTrainer(BaseConfig):
# save model state if the model improved with
# respect to the previous epoch
self.save_state(training_state)
self.save_state()
# whether the early stopping criterion is met
if self.es.stop(epoch_acc):
@@ -583,15 +609,13 @@ class NetworkTrainer(BaseConfig):
else:
# if no early stopping is required, the model state is
# saved after each epoch
self.save_state(training_state)
self.save_state()
return training_state
return self.training_state
def predict(self):
LOGGER.info(30 * '-' + ' Predicting ' + 30 * '-')
# send the model to the gpu if available
self.model = self.model.to(self.device)
@@ -631,37 +655,38 @@ class NetworkTrainer(BaseConfig):
.format(batch + 1, len(self.valid_dl), acc))
# calculate overall accuracy on the validation/test set
LOGGER.info('Epoch {:d}, Overall accuracy: {:.2f}%.'
LOGGER.info('Epoch: {:d}, Mean accuracy: {:.2f}%.'
.format(self.model.epoch, accuracies.mean() * 100))
return accuracies, losses
def save_state(self, training_state):
def save_state(self):
# whether to save the model state
if self.save:
# save model state
state = self.model.save(self.state_file.name,
self.optimizer,
self.train_dl.dataset.dataset.use_bands,
self.state_file.parent)
# save losses and accuracy
self._save_loss(training_state)
def _save_loss(self, training_state):
# append the model performance before the checkpoint to the model
# state, if a checkpoint is passed
if self.checkpoint_state:
# save losses and accuracy
state = training_state
if self.checkpoint_state:
# append values from checkpoint to current training state
state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
zip(self.checkpoint_state.items(),
self.training_state.items()) if k1 == k2}
else:
state = self.training_state
# append values from checkpoint to current training state
state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
zip(self.checkpoint_state.items(), training_state.items())
if k1 == k2}
# save model state
_ = self.model.save(
self.state_file,
self.optimizer,
bands=self.train_dl.dataset.dataset.use_bands,
train_ds=self.train_dl.dataset,
valid_ds=self.valid_dl.dataset,
test_ds=self.test_dl.dataset,
state=state,
)
# save the model loss and accuracies to file
torch.save(state, self.loss_state)
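The merge of checkpoint history and current history above relies on both dicts sharing the same keys and batch dimension. A minimal sketch with toy arrays (shapes and values made up, only one key shown):

import numpy as np

checkpoint_state = {'tl': np.zeros((3, 5))}   # 3 batches x 5 epochs from the checkpoint
training_state   = {'tl': np.ones((3, 2))}    # 3 batches x 2 newly trained epochs
state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
         zip(checkpoint_state.items(), training_state.items()) if k1 == k2}
# state['tl'].shape == (3, 7): checkpoint epochs followed by the new epochs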
def __repr__(self):
@@ -687,6 +712,7 @@ class NetworkTrainer(BaseConfig):
fs += '\n (split):'
fs += '\n ' + repr(self.train_dl.dataset)
fs += '\n ' + repr(self.valid_dl.dataset)
fs += '\n ' + repr(self.test_dl.dataset)
# model
fs += '\n (model):\n '
@@ -764,6 +790,6 @@ class EarlyStopping(object):
def __repr__(self):
fs = self.__class__.__name__
fs += '(mode={}, best={}, delta={}, patience={})'.format(
fs += '(mode={}, best={:.2f}, delta={}, patience={})'.format(
self.mode, self.best, self.min_delta, self.patience)
return fs
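Pulling together the calls visible in this diff, the trainer builds the early-stopping monitor from its mode, best metric, delta and patience, and queries stop() once per epoch. A hedged usage sketch; the accuracy values are made up and whether stop() fires depends on the implementation not shown in this hunk:

# EarlyStopping is the class defined in this module; arguments follow the
# positional order used in NetworkTrainer.__post_init__ above
es = EarlyStopping('max', 0, 0, 10)           # mode, best, delta, patience
for epoch_acc in (0.71, 0.74, 0.74, 0.73):    # illustrative validation accuracies
    if es.stop(epoch_acc):
        break                                  # early-stopping criterion met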