From f7112679373a4fe31a6619b9b6e9e35b4eccf0bc Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Mon, 17 Aug 2020 17:24:39 +0200
Subject: [PATCH] Major refactor: Increased modularity

Split checkpoint loading and transfer learning into the dedicated static
methods ModelConfig.load_checkpoint and ModelConfig.transfer_model, drop
the separate loss-state file (the training state is now saved together
with the model state), add a LogConfig dataclass for per-model log files,
and pass the test DataLoader through to NetworkTrainer.

---
 pysegcnn/core/trainer.py | 268 +++++++++++++++++++++------------------
 1 file changed, 147 insertions(+), 121 deletions(-)

diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index cbe4d0f..855b76b 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -8,6 +8,7 @@ Created on Wed Aug 12 10:24:34 2020
 import dataclasses
 import pathlib
 import logging
+import datetime
 
 # externals
 import numpy as np
@@ -199,7 +200,6 @@ class ModelConfig(BaseConfig):
     skip_connection: bool = True
     kwargs: dict = dataclasses.field(
         default_factory=lambda: {'kernel_size': 3, 'stride': 1, 'dilation': 1})
-    state_path: pathlib.Path = pathlib.Path(HERE).joinpath('_models/')
     batch_size: int = 64
     checkpoint: bool = False
     transfer: bool = False
@@ -226,11 +226,16 @@ class ModelConfig(BaseConfig):
         # check whether the loss function is currently supported
         self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions)
 
+        # path to model states
+        self.state_path = pathlib.Path(HERE).joinpath('_models/')
+
         # path to pretrained model
         self.pretrained_path = self.state_path.joinpath(self.pretrained_model)
 
     def init_optimizer(self, model):
 
+        LOGGER.info('Optimizer: {}.'.format(repr(self.optim_class)))
+
         # initialize the optimizer for the specified model
         optimizer = self.optim_class(model.parameters(), self.lr)
 
@@ -238,17 +243,28 @@ class ModelConfig(BaseConfig):
 
     def init_loss_function(self):
 
+        LOGGER.info('Loss function: {}.'.format(repr(self.loss_class)))
+
+        # instantiate the loss function
         loss_function = self.loss_class()
 
         return loss_function
 
-    def init_model(self, ds):
+    def init_model(self, ds, state_file):
+
+        # write an initialization string to the log file
+        # now = datetime.datetime.strftime(datetime.datetime.now(),
+        #                                  '%Y-%m-%dT%H:%M:%S')
+        # LOGGER.info(80 * '-')
+        # LOGGER.info('{}: Initializing model run. '.format(now) + 35 * '-')
+        # LOGGER.info(80 * '-')
 
         # case (1): build a new model
         if not self.transfer:
 
             # set the random seed for reproducibility
             torch.manual_seed(self.torch_seed)
+            LOGGER.info('Initializing model: {}'.format(state_file.name))
 
             # instanciate the model
             model = self.model_class(
@@ -261,104 +277,86 @@ class ModelConfig(BaseConfig):
         # case (2): load a pretrained model for transfer learning
         else:
             # load pretrained model
-            model, _ = self.load_pretrained(self.pretrained_path, new_ds=ds)
+            LOGGER.info('Loading pretrained model for transfer learning from: '
+                        '{}'.format(self.pretrained_path))
+            model = self.transfer_model(self.pretrained_path, ds)
 
-        return model
-
-    def from_checkpoint(self, model, optimizer, state_file, loss_state):
+        # initialize the optimizer
+        optimizer = self.init_optimizer(model)
 
         # whether to resume training from an existing model checkpoint
         checkpoint_state = {}
-        max_accuracy = 0
         if self.checkpoint:
+            model, optimizer, checkpoint_state = self.load_checkpoint(
+                model, optimizer, state_file)
 
-            # check whether the checkpoint exists
-            if state_file.exists() and loss_state.exists():
-                # load model checkpoint
-                model, optimizer = self.load_pretrained(state_file, optimizer,
-                                                        new_ds=None)
-                (checkpoint_state, max_accuracy) = self.load_checkpoint(
-                    loss_state)
-            else:
-                LOGGER.info('Checkpoint for model {} does not exist. '
-                            'Initializing new model.'.format(state_file.name))
 
-        return model, optimizer, checkpoint_state, max_accuracy
+        return model, optimizer, checkpoint_state
 
     @staticmethod
-    def load_pretrained(state_file, optimizer=None, new_ds=None):
-
-        # load the pretrained model
-        if not state_file.exists():
-            raise FileNotFoundError('Pretrained model {} does not exist.'
-                                    .format(state_file))
-
-        LOGGER.info('Loading pretrained model: {}'.format(state_file.name))
-
-        # load the model state
-        model_state = torch.load(state_file)
+    def load_checkpoint(model, optimizer, state_file):
 
-        # the model class
-        model_class = model_state['cls']
-
-        # instanciate pretrained model architecture
-        model = model_class(**model_state['params'], **model_state['kwargs'])
-
-        # load pretrained model weights
-        _ = model.load(state_file.name, optimizer=optimizer,
-                       inpath=str(state_file.parent))
-        LOGGER.info('Model epoch: {:d}'.format(model.epoch))
+        # whether to resume training from an existing model checkpoint
+        checkpoint_state = {}
 
-        # check whether to apply pretrained model on a new dataset
-        if new_ds is not None:
-            LOGGER.info('Configuring model for new dataset: {}.'
-                        .format(new_ds.__class__.__name__))
+        # if no checkpoint exists, log a warning and continue with a model
+        # initialized from scratch
+        if not state_file.exists():
+            LOGGER.warning('Checkpoint for model {} does not exist. '
+                           'Initializing new model.'
+                           .format(state_file.name))
+        else:
+            # load model checkpoint
+            model, optimizer, model_state = Network.load(state_file, optimizer)
 
-            # the bands the model was trained with
-            bands = model_state['bands']
+            # load the model loss and accuracy
 
-            # check whether the current dataset uses the correct spectral bands
-            if new_ds.use_bands != bands:
-                raise ValueError('The pretrained network was trained with the '
-                                 'bands {}, not with: {}'
-                                 .format(bands, new_ds.use_bands))
+            # get all non-zero elements, i.e. the number of epochs trained
+            # before the early stop
+            checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
+                                for k, v in model_state['state'].items()}
 
-            # get the number of convolutional filters
-            filters = model_state['params']['filters']
+        return model, optimizer, checkpoint_state
 
-            # reset model epoch to 0, since the model is trained on a different
-            # dataset
-            model.epoch = 0
+    @staticmethod
+    def transfer_model(state_file, ds):
 
-            # adjust the number of classes in the model
-            model.nclasses = len(new_ds.labels)
-            LOGGER.info('Replacing classification layer to classes: {}.'
-                        .format(', '.join('({}, {})'.format(k, v['label'])
-                                for k, v in new_ds.labels.items())))
+        # check input type
+        if not isinstance(ds, ImageDataset):
+            raise TypeError('Expected "ds" to be {}.'
                            .format('.'.join([ImageDataset.__module__,
+                                              ImageDataset.__name__])))
 
-            # adjust the classification layer to the number of classes of the
-            # current dataset
-            model.classifier = Conv2dSame(in_channels=filters[0],
-                                          out_channels=model.nclasses,
-                                          kernel_size=1)
+        # load the pretrained model
+        model, _, model_state = Network.load(state_file)
+        LOGGER.info('Configuring model for new dataset: {}.'.format(
+            ds.__class__.__name__))
 
-        return model, optimizer
+        # check whether the current dataset uses the correct spectral bands
+        if ds.use_bands != model_state['bands']:
+            raise ValueError('The pretrained network was trained with '
+                             'bands {}, not with bands {}.'
+                             .format(model_state['bands'], ds.use_bands))
 
-    @staticmethod
-    def load_checkpoint(loss_state):
+        # get the number of convolutional filters
+        filters = model_state['params']['filters']
 
-        # load the model loss and accuracy
-        checkpoint_state = torch.load(loss_state)
+        # reset model epoch to 0, since the model is trained on a different
+        # dataset
+        model.epoch = 0
 
-        # get all non-zero elements, i.e. get number of epochs trained
-        # before the early stop
-        checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
-                            for k, v in checkpoint_state.items()}
+        # adjust the number of classes in the model
+        model.nclasses = len(ds.labels)
+        LOGGER.info('Replacing classification layer for classes: {}.'
+                    .format(', '.join('({}, {})'.format(k, v['label'])
+                            for k, v in ds.labels.items())))
 
-        # maximum accuracy on the validation set
-        max_accuracy = checkpoint_state['va'][:, -1].mean().item()
+        # adjust the classification layer to the number of classes of the
+        # current dataset
+        model.classifier = Conv2dSame(in_channels=filters[0],
+                                      out_channels=model.nclasses,
+                                      kernel_size=1)
 
-        return checkpoint_state, max_accuracy
+        return model
 
 
 @dataclasses.dataclass
@@ -420,14 +418,12 @@ class StateConfig(BaseConfig):
         # path to model state
         state = self.mc.state_path.joinpath(state_file)
 
-        # path to model loss/accuracy
-        loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt'))
-
-        return state, loss_state
+        return state
 
 
 @dataclasses.dataclass
 class EvalConfig(BaseConfig):
+    state_file: pathlib.Path
     test: object
     predict_scene: bool = False
     plot_samples: bool = False
@@ -446,6 +442,32 @@ class EvalConfig(BaseConfig):
             raise TypeError('Expected "test" to be None, True or False, got '
                             '{}.'.format(self.test))
 
+        # the output paths for the different graphics
+        self.base_path = pathlib.Path(HERE)
+        self.sample_path = self.base_path.joinpath('_samples')
+        self.scenes_path = self.base_path.joinpath('_scenes')
+        self.models_path = self.base_path.joinpath('_graphics')
+
+        # write initialization string to log file
+        # LOGGER.info(80 * '-')
+        # LOGGER.info('{}')
+        # LOGGER.info(80 * '-')
+
+
+@dataclasses.dataclass
+class LogConfig(BaseConfig):
+    state_file: pathlib.Path
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        # the path to store model logs
+        self.log_path = pathlib.Path(HERE).joinpath('_logs')
+
+        # the log file of the current model
+        self.log_file = self.log_path.joinpath(
+            self.state_file.name.replace('.pt', '.log'))
+
 
 @dataclasses.dataclass
 class NetworkTrainer(BaseConfig):
@@ -454,15 +476,14 @@
     loss_function: nn.Module
     train_dl: DataLoader
     valid_dl: DataLoader
+    test_dl: DataLoader
     state_file: pathlib.Path
-    loss_state: pathlib.Path
     epochs: int = 1
     nthreads: int = torch.get_num_threads()
     early_stop: bool = False
     mode: str = 'max'
     delta: float = 0
     patience: int = 10
-    max_accuracy: float = 0
     checkpoint_state: dict = dataclasses.field(default_factory=dict)
     save: bool = True
 
@@ -473,16 +494,21 @@
         self.device = torch.device("cuda:0" if torch.cuda.is_available()
                                    else "cpu")
 
+        # maximum accuracy on the validation dataset
+        self.max_accuracy = 0
+        if self.checkpoint_state:
+            self.max_accuracy = self.checkpoint_state['va'].mean(
+                axis=0).max().item()
+
         # whether to use early stopping
         self.es = None
         if self.early_stop:
             self.es = EarlyStopping(self.mode, self.max_accuracy, self.delta,
                                     self.patience)
-
 
     def train(self):
-        LOGGER.info(30 * '-' + ' Training ' + 30 * '-')
+        LOGGER.info(35 * '-' + ' Training ' + 35 * '-')
 
         # set the number of threads
         LOGGER.info('Device: {}'.format(self.device))
 
@@ -493,11 +519,11 @@
         # training and validation dataset
         tshape = (len(self.train_dl), self.epochs)
         vshape = (len(self.valid_dl), self.epochs)
-        training_state = {'tl': np.zeros(shape=tshape),
-                          'ta': np.zeros(shape=tshape),
-                          'vl': np.zeros(shape=vshape),
-                          'va': np.zeros(shape=vshape)
-                          }
+        self.training_state = {'tl': np.zeros(shape=tshape),
+                               'ta': np.zeros(shape=tshape),
+                               'vl': np.zeros(shape=vshape),
+                               'va': np.zeros(shape=vshape)
+                               }
 
         # send the model to the gpu if available
         self.model = self.model.to(self.device)
@@ -525,7 +551,7 @@
                     # compute loss
                     loss = self.loss_function(outputs, labels.long())
                     observed_loss = loss.detach().numpy().item()
-                    training_state['tl'][batch, epoch] = observed_loss
+                    self.training_state['tl'][batch, epoch] = observed_loss
 
                     # compute the gradients of the loss function w.r.t.
                     # the network weights
@@ -539,7 +565,7 @@
 
                     # calculate accuracy on current batch
                     observed_accuracy = accuracy_function(ypred, labels)
-                    training_state['ta'][batch, epoch] = observed_accuracy
+                    self.training_state['ta'][batch, epoch] = observed_accuracy
 
                     # print progress
                     LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
@@ -562,8 +588,8 @@
                 vacc, vloss = self.predict()
 
                 # append observed accuracy and loss to arrays
-                training_state['va'][:, epoch] = vacc.squeeze()
-                training_state['vl'][:, epoch] = vloss.squeeze()
+                self.training_state['va'][:, epoch] = vacc.squeeze()
+                self.training_state['vl'][:, epoch] = vloss.squeeze()
 
                 # metric to assess model performance on the validation set
                 epoch_acc = vacc.squeeze().mean()
@@ -574,7 +600,7 @@
 
                     # save model state if the model improved with
                     # respect to the previous epoch
-                    self.save_state(training_state)
+                    self.save_state()
 
                     # whether the early stopping criterion is met
                     if self.es.stop(epoch_acc):
@@ -583,15 +609,13 @@
             else:
                 # if no early stopping is required, the model state is
                 # saved after each epoch
-                self.save_state(training_state)
+                self.save_state()
 
-        return training_state
+        return self.training_state
 
     def predict(self):
 
-        LOGGER.info(30 * '-' + ' Predicting ' + 30 * '-')
-
         # send the model to the gpu if available
         self.model = self.model.to(self.device)
@@ -631,37 +655,38 @@
                     .format(batch + 1, len(self.valid_dl), acc))
 
         # calculate overall accuracy on the validation/test set
-        LOGGER.info('Epoch {:d}, Overall accuracy: {:.2f}%.'
+        LOGGER.info('Epoch: {:d}, Mean accuracy: {:.2f}%.'
                     .format(self.model.epoch, accuracies.mean() * 100))
 
         return accuracies, losses
 
-    def save_state(self, training_state):
+    def save_state(self):
 
         # whether to save the model state
         if self.save:
-            # save model state
-            state = self.model.save(self.state_file.name,
-                                    self.optimizer,
-                                    self.train_dl.dataset.dataset.use_bands,
-                                    self.state_file.parent)
-
-            # save losses and accuracy
-            self._save_loss(training_state)
 
-    def _save_loss(self, training_state):
+            # append the model performance before the checkpoint to the model
+            # state, if a checkpoint is passed
+            if self.checkpoint_state:
 
-        # save losses and accuracy
-        state = training_state
-        if self.checkpoint_state:
+                # append values from checkpoint to current training state
+                state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
+                         zip(self.checkpoint_state.items(),
+                             self.training_state.items()) if k1 == k2}
+            else:
+                state = self.training_state
 
-            # append values from checkpoint to current training state
-            state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
-                     zip(self.checkpoint_state.items(), training_state.items())
-                     if k1 == k2}
+            # save model state
+            _ = self.model.save(
+                self.state_file,
+                self.optimizer,
+                bands=self.train_dl.dataset.dataset.use_bands,
+                train_ds=self.train_dl.dataset,
+                valid_ds=self.valid_dl.dataset,
+                test_ds=self.test_dl.dataset,
+                state=state,
+            )
 
-            # save the model loss and accuracies to file
-            torch.save(state, self.loss_state)
 
     def __repr__(self):
 
@@ -687,6 +712,7 @@
         fs += '\n    (split):'
         fs += '\n        ' + repr(self.train_dl.dataset)
         fs += '\n        ' + repr(self.valid_dl.dataset)
+        fs += '\n        ' + repr(self.test_dl.dataset)
 
         # model
         fs += '\n    (model):\n        '
@@ -764,6 +790,6 @@ class EarlyStopping(object):
 
     def __repr__(self):
         fs = self.__class__.__name__
-        fs += '(mode={}, best={}, delta={}, patience={})'.format(
+        fs += '(mode={}, best={:.2f}, delta={}, patience={})'.format(
             self.mode, self.best, self.min_delta, self.patience)
         return fs
-- 
GitLab
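
Reviewer sketch (appended after the patch, not part of it): the checkpoint
handling above relies on the training-state arrays being allocated as zeros
for every planned epoch, so after an early stop the untrained trailing
columns stay zero. ModelConfig.load_checkpoint strips those columns,
NetworkTrainer.__post_init__ derives max_accuracy from them, and save_state
re-joins the old and new runs with np.hstack. A minimal, self-contained
numpy illustration with made-up shapes and values:

    import numpy as np

    # validation accuracy: 3 mini-batches x 5 planned epochs; training
    # stopped early after epoch 2, so columns 3 to 5 remain zero-filled
    va = np.zeros((3, 5))
    va[:, 0] = [0.50, 0.60, 0.55]
    va[:, 1] = [0.70, 0.72, 0.69]

    # load_checkpoint: keep only the epochs that were actually trained
    trimmed = va[np.nonzero(va)].reshape(va.shape[0], -1)
    print(trimmed.shape)             # (3, 2)

    # __post_init__: resume early stopping from the best mean validation
    # accuracy observed before the checkpoint
    max_accuracy = trimmed.mean(axis=0).max().item()
    print(round(max_accuracy, 4))    # 0.7033

    # save_state: prepend the checkpoint columns to the current run
    new_va = np.full((3, 1), 0.80)
    print(np.hstack([trimmed, new_va]).shape)   # (3, 3)

Note that this scheme assumes every recorded loss/accuracy value is strictly
non-zero; a genuine 0.0 observed in an early epoch would be stripped as well.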