diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 73d3a4399f3bfa6ca3ce4c011e48cf71a3ad143c..757fb3f37e2492be8db4b394997d1ba717232ee1 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -7,6 +7,7 @@ Created on Wed Aug 12 10:24:34 2020
 # builtins
 import dataclasses
 import pathlib
+import logging
 
 # externals
 import numpy as np
@@ -16,7 +17,6 @@ import torch.nn.functional as F
 from torch.utils.data import DataLoader, Dataset
 from torch.optim import Optimizer
 
-
 # locals
 from pysegcnn.core.dataset import SupportedDatasets, ImageDataset
 from pysegcnn.core.transforms import Augment
@@ -27,6 +27,9 @@ from pysegcnn.core.models import (SupportedModels, SupportedOptimizers,
 from pysegcnn.core.layers import Conv2dSame
 from pysegcnn.main.config import HERE
 
+# module level logger
+LOGGER = logging.getLogger(__name__)
+
 
 @dataclasses.dataclass
 class BaseConfig:
@@ -60,7 +63,6 @@ class DatasetConfig(BaseConfig):
     sort: bool = False
     transforms: list = dataclasses.field(default_factory=list)
     pad: bool = False
-    cval: int = 99
 
     def __post_init__(self):
         # check input types
@@ -80,11 +82,6 @@ class DatasetConfig(BaseConfig):
                             ' of {}.'.format('.'.join([Augment.__module__,
                                                        Augment.__name__])))
 
-        # check whether the constant padding value is within the valid range
-        if not 0 < self.cval < 255:
-            raise ValueError('Expecting 0 <= cval <= 255, got cval={}.'
-                             .format(self.cval))
-
     def init_dataset(self):
 
         # instanciate the dataset
@@ -96,7 +93,6 @@ class DatasetConfig(BaseConfig):
             sort=self.sort,
             transforms=self.transforms,
             pad=self.pad,
-            cval=self.cval,
             gt_pattern=self.gt_pattern
             )
 
@@ -121,7 +117,8 @@ class SplitConfig(BaseConfig):
 
     # function to drop samples with a fraction of pixels equal to the constant
    # padding value self.cval >= self.drop
-    def _drop_samples(self, ds, drop_threshold=1):
+    @staticmethod
+    def _drop_samples(ds, drop_threshold=1):
 
         # iterate over the scenes returned by self.compose_scenes()
         dropped = []
@@ -139,8 +136,8 @@ class SplitConfig(BaseConfig):
 
             # drop samples where npixels >= self.drop
             if npixels >= drop_threshold:
-                print('Skipping scene {}, tile {}: {:.2f}% padded pixels ...'
-                      .format(s['id'], s['tile'], npixels * 100))
+                LOGGER.info('Skipping scene {}, tile {}: {:.2f}% padded pixels'
+                            ' ...'.format(s['id'], s['tile'], npixels * 100))
                 dropped.append(s)
                 _ = ds.indices.pop(pos)
 
@@ -197,14 +194,24 @@ class ModelConfig(BaseConfig):
     model_name: str
     filters: list
     torch_seed: int
+    optim_name: str
+    loss_name: str
     skip_connection: bool = True
     kwargs: dict = dataclasses.field(
         default_factory=lambda: {'kernel_size': 3, 'stride': 1, 'dilation': 1})
     state_path: pathlib.Path = pathlib.Path(HERE).joinpath('_models/')
     batch_size: int = 64
     checkpoint: bool = False
-    pretrained: bool = False
+    transfer: bool = False
     pretrained_model: str = ''
+    lr: float = 0.001
+    early_stop: bool = False
+    mode: str = 'max'
+    delta: float = 0
+    patience: int = 10
+    epochs: int = 50
+    nthreads: int = torch.get_num_threads()
+    save: bool = True
 
     def __post_init__(self):
         # check input types
@@ -213,65 +220,32 @@ class ModelConfig(BaseConfig):
         # check whether the model is currently supported
         self.model_class = item_in_enum(self.model_name, SupportedModels)
 
-    def init_state(self, ds, sc, tc):
-
-        # file to save model state to:
-        # network_dataset_optim_split_splitparams_tilesize_batchsize_bands.pt
+        # check whether the optimizer is currently supported
+        self.optim_class = item_in_enum(self.optim_name, SupportedOptimizers)
 
-        # model state filename
-        state_file = '{}_{}_{}_{}Split_{}_t{}_b{}_{}.pt'
+        # check whether the loss function is currently supported
+        self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions)
 
-        # get the band numbers
-        bformat = ''.join(band[0] +
-                          str(ds.sensor.__members__[band].value) for
-                          band in ds.use_bands)
+        # path to pretrained model
+        self.pretrained_path = self.state_path.joinpath(self.pretrained_model)
 
-        # check which split mode was used
-        if sc.split_mode == 'date':
-            # store the date that was used to split the dataset
-            state_file = state_file.format(self.model_class.__name__,
-                                           ds.__class__.__name__,
-                                           tc.optim_name,
-                                           sc.split_mode.capitalize(),
-                                           sc.date,
-                                           ds.tile_size,
-                                           self.batch_size,
-                                           bformat)
-        else:
-            # store the random split parameters
-            split_params = 's{}_t{}v{}'.format(
-                ds.seed, str(sc.ttratio).replace('.', ''),
-                str(sc.tvratio).replace('.', ''))
+    def init_optimizer(self, model):
 
-            # model state filename
-            state_file = state_file.format(self.model_class.__name__,
-                                           ds.__class__.__name__,
-                                           tc.optim_name,
-                                           sc.split_mode.capitalize(),
-                                           split_params,
-                                           ds.tile_size,
-                                           self.batch_size,
-                                           bformat)
+        # initialize the optimizer for the specified model
+        optimizer = self.optim_class(model.parameters(), self.lr)
 
-        # check whether a pretrained model was used and change state filename
-        # accordingly
-        if self.pretrained:
-            # add the configuration of the pretrained model to the state name
-            state_file = (state_file.replace('.pt', '_') +
-                          'pretrained_' + self.pretrained_model)
+        return optimizer
 
-        # path to model state
-        state = self.state_path.joinpath(state_file)
+    def init_loss_function(self):
 
-        # path to model loss/accuracy
-        loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt'))
+        loss_function = self.loss_class()
 
-        return state, loss_state
+        return loss_function
 
     def init_model(self, ds):
 
         # case (1): build a new model
-        if not self.pretrained:
+        if not self.transfer:
 
             # set the random seed for reproducibility
             torch.manual_seed(self.torch_seed)
@@ -284,130 +258,172 @@ class ModelConfig(BaseConfig):
                                      skip=self.skip_connection,
                                      **self.kwargs)
 
-        # case (2): load a pretrained model
+        # case (2): load a pretrained model for transfer learning
         else:
-            # load pretrained model
-            model = self.load_pretrained()
+            model, _ = self.load_pretrained(self.pretrained_path, new_ds=ds)
 
         return model
 
-    def load_checkpoint(self, state_file, loss_state, model, optimizer):
+    def from_checkpoint(self, model, optimizer, state_file, loss_state):
 
-        # initial accuracy on the validation set
+        # whether to resume training from an existing model checkpoint
+        checkpoint_state = {}
         max_accuracy = 0
+        if self.checkpoint:
 
-        # set the model checkpoint to None, overwritten when resuming
-        # training from an existing model checkpoint
-        checkpoint_state = {}
+            # check whether the checkpoint exists
+            if state_file.exists() and loss_state.exists():
+                # load model checkpoint
+                model, optimizer = self.load_pretrained(state_file, optimizer,
+                                                        new_ds=None)
+                (checkpoint_state, max_accuracy) = self.load_checkpoint(
+                    loss_state)
+            else:
+                LOGGER.info('Checkpoint for model {} does not exist. '
+                            'Initializing new model.'.format(state_file.name))
 
-        # whether to resume training from an existing model
-        if self.checkpoint:
+        return model, optimizer, checkpoint_state, max_accuracy
 
-            # check if a model checkpoint exists
-            if not state_file.exists():
-                raise FileNotFoundError('Model checkpoint {} does not exist.'
-                                        .format(state_file))
+    @staticmethod
+    def load_pretrained(state_file, optimizer=None, new_ds=None):
 
-            # load the model state
-            state = model.load(state_file.name, optimizer, self.state_path)
-            print('Found checkpoint: {}'.format(state))
-            print('Resuming training from checkpoint ...'.format(state))
-            print('Model epoch: {:d}'.format(model.epoch))
+        # load the pretrained model
+        if not state_file.exists():
+            raise FileNotFoundError('Pretrained model {} does not exist.'
+                                    .format(state_file))
 
-            # load the model loss and accuracy
-            checkpoint_state = torch.load(loss_state)
+        LOGGER.info('Loading pretrained model: {}'.format(state_file.name))
 
-            # get all non-zero elements, i.e. get number of epochs trained
-            # before the early stop
-            checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
-                                for k, v in checkpoint_state.items()}
+        # load the model state
+        model_state = torch.load(state_file)
 
-            # maximum accuracy on the validation set
-            max_accuracy = checkpoint_state['va'][:, -1].mean().item()
+        # the model class
+        model_class = model_state['cls']
 
-        return checkpoint_state, max_accuracy
+        # instanciate pretrained model architecture
+        model = model_class(**model_state['params'], **model_state['kwargs'])
 
-    def load_pretrained(self, ds):
+        # load pretrained model weights
+        _ = model.load(state_file.name, optimizer=optimizer,
+                       inpath=str(state_file.parent))
+        LOGGER.info('Model epoch: {:d}'.format(model.epoch))
 
-        # load the pretrained model
-        model_state = self.state_path.joinpath(self.pretrained_model)
-        if not model_state.exists():
-            raise FileNotFoundError('Pretrained model {} does not exist.'
-                                    .format(model_state))
+        # check whether to apply pretrained model on a new dataset
+        if new_ds is not None:
+            LOGGER.info('Configuring model for new dataset: {}.'
+                        .format(new_ds.__class__.__name__))
 
-        # load the model state
-        model_state = torch.load(model_state)
+            # the bands the model was trained with
+            bands = model_state['bands']
 
-        # get the input bands of the pretrained model
-        bands = model_state['bands']
+            # check whether the current dataset uses the correct spectral bands
+            if new_ds.use_bands != bands:
+                raise ValueError('The pretrained network was trained with the '
                                 'bands {}, not with: {}'
+                                 .format(bands, new_ds.use_bands))
 
-        # get the number of convolutional filters
-        filters = model_state['params']['filters']
+            # get the number of convolutional filters
+            filters = model_state['params']['filters']
 
-        # check whether the current dataset uses the correct spectral bands
-        if ds.use_bands != bands:
-            raise ValueError('The bands of the pretrained network do not '
-                             'match the specified bands: {}'
-                             .format(bands))
+            # reset model epoch to 0, since the model is trained on a different
+            # dataset
+            model.epoch = 0
 
-        # instanciate pretrained model architecture
-        model = self.model_class(**model_state['params'],
-                                 **model_state['kwargs'])
+            # adjust the number of classes in the model
+            model.nclasses = len(new_ds.labels)
+            LOGGER.info('Replacing classification layer to classes: {}.'
+                        .format(', '.join('({}, {})'.format(k, v['label'])
+                                          for k, v in new_ds.labels.items())))
 
-        # load pretrained model weights
-        model.load(self.pretrained_model, inpath=str(self.state_path))
+            # adjust the classification layer to the number of classes of the
+            # current dataset
+            model.classifier = Conv2dSame(in_channels=filters[0],
+                                          out_channels=model.nclasses,
+                                          kernel_size=1)
 
-        # reset model epoch to 0, since the model is trained on a different
-        # dataset
-        model.epoch = 0
+        return model, optimizer
 
-        # adjust the number of classes in the model
-        model.nclasses = len(ds.labels)
+    @staticmethod
+    def load_checkpoint(loss_state):
 
-        # adjust the classification layer to the number of classes of the
-        # current dataset
-        model.classifier = Conv2dSame(in_channels=filters[0],
-                                      out_channels=model.nclasses,
-                                      kernel_size=1)
+        # load the model loss and accuracy
+        checkpoint_state = torch.load(loss_state)
 
-        return model
+        # get all non-zero elements, i.e. get number of epochs trained
+        # before the early stop
+        checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
+                            for k, v in checkpoint_state.items()}
+
+        # maximum accuracy on the validation set
+        max_accuracy = checkpoint_state['va'][:, -1].mean().item()
+
+        return checkpoint_state, max_accuracy
 
 
 @dataclasses.dataclass
-class TrainConfig(BaseConfig):
-    optim_name: str
-    loss_name: str
-    lr: float = 0.001
-    early_stop: bool = False
-    mode: str = 'max'
-    delta: float = 0
-    patience: int = 10
-    epochs: int = 50
-    nthreads: int = torch.get_num_threads()
-    save: bool = True
+class StateConfig(BaseConfig):
+    ds: ImageDataset
+    sc: SplitConfig
+    mc: ModelConfig
 
     def __post_init__(self):
         super().__post_init__()
 
-        # check whether the optimizer is currently supported
-        self.optim_class = item_in_enum(self.optim_name, SupportedOptimizers)
+    def init_state(self):
 
-        # check whether the loss function is currently supported
-        self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions)
+        # file to save model state to:
+        # network_dataset_optim_split_splitparams_tilesize_batchsize_bands.pt
 
-    def init_optimizer(self, model):
+        # model state filename
+        state_file = '{}_{}_{}_{}Split_{}_t{}_b{}_{}.pt'
 
-        # initialize the optimizer for the specified model
-        optimizer = self.optim_class(model.parameters(), self.lr)
+        # get the band numbers
+        bformat = ''.join(band[0] +
+                          str(self.ds.sensor.__members__[band].value) for
+                          band in self.ds.use_bands)
 
-        return optimizer
+        # check which split mode was used
+        if self.sc.split_mode == 'date':
+            # store the date that was used to split the dataset
+            state_file = state_file.format(self.mc.model_name,
+                                           self.ds.__class__.__name__,
+                                           self.mc.optim_name,
+                                           self.sc.split_mode.capitalize(),
+                                           self.sc.date,
+                                           self.ds.tile_size,
+                                           self.mc.batch_size,
+                                           bformat)
+        else:
+            # store the random split parameters
+            split_params = 's{}_t{}v{}'.format(
+                self.ds.seed, str(self.sc.ttratio).replace('.', ''),
+                str(self.sc.tvratio).replace('.', ''))
 
-    def init_loss_function(self):
+            # model state filename
+            state_file = state_file.format(self.mc.model_name,
+                                           self.ds.__class__.__name__,
+                                           self.mc.optim_name,
+                                           self.sc.split_mode.capitalize(),
+                                           split_params,
+                                           self.ds.tile_size,
+                                           self.mc.batch_size,
+                                           bformat)
 
-        loss_function = self.loss_class()
+        # check whether a pretrained model was used and change state filename
+        # accordingly
+        if self.mc.transfer:
+            # add the configuration of the pretrained model to the state name
+            state_file = (state_file.replace('.pt', '_') +
+                          'pretrained_' + self.mc.pretrained_model)
 
-        return loss_function
+        # path to model state
+        state = self.mc.state_path.joinpath(state_file)
+
+        # path to model loss/accuracy
+        loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt'))
+
+        return state, loss_state
 
 
 @dataclasses.dataclass
@@ -428,6 +444,7 @@ class EvalConfig(BaseConfig):
             raise TypeError('Expected "test" to be None, True or False, got '
                             '{}.'.format(self.test))
 
+
 @dataclasses.dataclass
 class NetworkTrainer(BaseConfig):
     model: Network
@@ -457,15 +474,16 @@ class NetworkTrainer(BaseConfig):
        # whether to use early stopping
        self.es = None
        if self.early_stop:
-            self.es = EarlyStopping(self.mode, self.delta, self.patience)
+            self.es = EarlyStopping(self.mode, self.max_accuracy, self.delta,
+                                    self.patience)
 
     def train(self):
 
-        print('------------------------- Training ---------------------------')
+        LOGGER.info(30 * '-' + ' Training ' + 30 * '-')
 
         # set the number of threads
-        print('Device: {}'.format(self.device))
-        print('Number of cpu threads: {}'.format(self.nthreads))
+        LOGGER.info('Device: {}'.format(self.device))
+        LOGGER.info('Number of cpu threads: {}'.format(self.nthreads))
         torch.set_num_threads(self.nthreads)
 
         # create dictionary of the observed losses and accuracies on the
@@ -485,7 +503,7 @@ class NetworkTrainer(BaseConfig):
         for epoch in range(self.epochs):
 
             # set the model to training mode
-            print('Setting model to training mode ...')
+            LOGGER.info('Setting model to training mode ...')
             self.model.train()
 
             # iterate over the dataloader object
@@ -521,13 +539,14 @@ class NetworkTrainer(BaseConfig):
                 training_state['ta'][batch, epoch] = observed_accuracy
 
                 # print progress
-                print('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, Loss: {:.2f}, '
-                      'Accuracy: {:.2f}'.format(epoch + 1,
-                                                self.epochs,
-                                                batch + 1,
-                                                len(self.train_dl),
-                                                observed_loss,
-                                                observed_accuracy))
+                LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
+                            'Loss: {:.2f}, Accuracy: {:.2f}'.format(
+                                epoch + 1,
+                                self.epochs,
+                                batch + 1,
+                                len(self.train_dl),
+                                observed_loss,
+                                observed_accuracy))
 
             # update the number of epochs trained
             self.model.epoch += 1
@@ -568,13 +587,13 @@ class NetworkTrainer(BaseConfig):
 
     def predict(self):
 
-        print('------------------------ Predicting --------------------------')
+        LOGGER.info(30 * '-' + ' Predicting ' + 30 * '-')
 
         # send the model to the gpu if available
         self.model = self.model.to(self.device)
 
         # set the model to evaluation mode
-        print('Setting model to evaluation mode ...')
+        LOGGER.info('Setting model to evaluation mode ...')
         self.model.eval()
 
         # create arrays of the observed losses and accuracies
@@ -582,7 +601,7 @@ class NetworkTrainer(BaseConfig):
         losses = np.zeros(shape=(len(self.valid_dl), 1))
 
         # iterate over the validation/test set
-        print('Calculating accuracy on the validation set ...')
+        LOGGER.info('Calculating accuracy on the validation set ...')
         for batch, (inputs, labels) in enumerate(self.valid_dl):
 
             # send the data to the gpu if available
@@ -605,12 +624,12 @@ class NetworkTrainer(BaseConfig):
             accuracies[batch, 0] = acc
 
             # print progress
-            print('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}'
-                  .format(batch + 1, len(self.valid_dl), acc))
+            LOGGER.info('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}'
+                        .format(batch + 1, len(self.valid_dl), acc))
 
         # calculate overall accuracy on the validation/test set
-        print('Epoch {:d}, Overall accuracy: {:.2f}%.'
-              .format(self.model.epoch, accuracies.mean() * 100))
+        LOGGER.info('Epoch {:d}, Overall accuracy: {:.2f}%.'
+                    .format(self.model.epoch, accuracies.mean() * 100))
 
         return accuracies, losses
 
@@ -649,7 +668,7 @@ class NetworkTrainer(BaseConfig):
         # dataset
         fs += ' (dataset):\n '
         fs += ''.join(
-            repr(self.train_dl.dataset.dataset)).replace('\n','\n ')
+            repr(self.train_dl.dataset.dataset)).replace('\n', '\n ')
 
         # batch size
         fs += '\n (batch):\n '
@@ -684,7 +703,7 @@ class NetworkTrainer(BaseConfig):
 
 class EarlyStopping(object):
 
-    def __init__(self, mode='max', min_delta=0, patience=10):
+    def __init__(self, mode='max', best=0, min_delta=0, patience=10):
 
         # check if mode is correctly specified
         if mode not in ['min', 'max']:
@@ -707,7 +726,7 @@ class EarlyStopping(object):
         self.patience = patience
 
         # initialize best metric
-        self.best = None
+        self.best = best
 
         # initialize early stopping flag
         self.early_stop = False
@@ -717,25 +736,20 @@ class EarlyStopping(object):
 
     def stop(self, metric):
 
-        if self.best is not None:
-
-            # if the metric improved, reset the epochs counter, else, advance
-            if self.is_better(metric, self.best, self.min_delta):
-                self.counter = 0
-                self.best = metric
-            else:
-                self.counter += 1
-                print('Early stopping counter: {}/{}'.format(self.counter,
-                                                             self.patience))
-
-                # if the metric did not improve over the last patience epochs,
-                # the early stopping criterion is met
-                if self.counter >= self.patience:
-                    print('Early stopping criterion met, exiting training ...')
-                    self.early_stop = True
-
-        else:
+        # if the metric improved, reset the epochs counter, else, advance
+        if self.is_better(metric, self.best, self.min_delta):
+            self.counter = 0
             self.best = metric
+        else:
+            self.counter += 1
+            LOGGER.info('Early stopping counter: {}/{}'.format(
+                self.counter, self.patience))
+
+            # if the metric did not improve over the last patience epochs,
+            # the early stopping criterion is met
+            if self.counter >= self.patience:
+                LOGGER.info('Early stopping criterion met, stopping training.')
+                self.early_stop = True
 
         return self.early_stop
 
@@ -746,6 +760,7 @@ class EarlyStopping(object):
         return metric > best + min_delta
 
     def __repr__(self):
-        fs = (self.__class__.__name__ + '(mode={}, delta={}, patience={})'
-              .format(self.mode, self.min_delta, self.patience))
+        fs = self.__class__.__name__
+        fs += '(mode={}, best={}, delta={}, patience={})'.format(
+            self.mode, self.best, self.min_delta, self.patience)
         return fs
diff --git a/pysegcnn/main/train.py b/pysegcnn/main/train.py
index 6c036ea6e10ae589595af6c98ff4246eae19afd0..bcffb54dd38220e5e308df477dc7428fe65db614 100644
--- a/pysegcnn/main/train.py
+++ b/pysegcnn/main/train.py
@@ -5,11 +5,14 @@ Created on Tue Jun 30 09:33:38 2020
 @author: Daniel
 """
 
+# builtins
+import logging
+
 # locals
 from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig,
-                                   TrainConfig, NetworkTrainer)
-from pysegcnn.main.config import (dataset_config, split_config,
-                                  model_config, train_config)
+                                   StateConfig, NetworkTrainer)
+from pysegcnn.core.logging import log_conf
+from pysegcnn.main.config import (dataset_config, split_config, model_config)
 
 
 if __name__ == '__main__':
@@ -20,35 +23,36 @@ if __name__ == '__main__':
     dc = DatasetConfig(**dataset_config)
     sc = SplitConfig(**split_config)
     mc = ModelConfig(**model_config)
-    tc = TrainConfig(**train_config)
 
     # (ii) instanciate the dataset
     ds = dc.init_dataset()
     ds
 
-    # (iii) instanciate the training, validation and test datasets and
+    # (iii) instanciate the model state
+    state = StateConfig(ds, sc, mc)
+    state_file, loss_state = state.init_state()
+
+    # initialize logging
+    log_file = str(state_file).replace('.pt', '_train.log')
+    logging.config.dictConfig(log_conf(log_file))
+
+    # (iv) instanciate the training, validation and test datasets and
     # dataloaders
     train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
-    train_dl, valid_dl, test_dl = sc.dataloaders(train_ds,
-                                                 valid_ds,
-                                                 test_ds,
-                                                 batch_size=mc.batch_size,
-                                                 shuffle=True,
-                                                 drop_last=False)
-
-    # (iv) instanciate the model state files
-    state_file, loss_state = mc.init_state(ds, sc, tc)
+    train_dl, valid_dl, test_dl = sc.dataloaders(
+        train_ds, valid_ds, test_ds, batch_size=mc.batch_size, shuffle=True,
+        drop_last=False)
 
-    # (v) instanciate the model
+    # (iv) instanciate the model
     model = mc.init_model(ds)
 
     # (vi) instanciate the optimizer and the loss function
-    optimizer = tc.init_optimizer(model)
-    loss_function = tc.init_loss_function()
+    optimizer = mc.init_optimizer(model)
+    loss_function = mc.init_loss_function()
 
     # (vii) resume training from an existing model checkpoint
-    checkpoint_state, max_accuracy = mc.load_checkpoint(state_file, loss_state,
-                                                        model, optimizer)
+    (model, optimizer, checkpoint_state, max_accuracy) = mc.from_checkpoint(
+        model, optimizer, state_file, loss_state)
 
     # (viii) initialize network trainer class for eays model training
     trainer = NetworkTrainer(model=model,
@@ -58,15 +62,15 @@ if __name__ == '__main__':
                              valid_dl=valid_dl,
                              state_file=state_file,
                              loss_state=loss_state,
-                             epochs=tc.epochs,
-                             nthreads=tc.nthreads,
-                             early_stop=tc.early_stop,
-                             mode=tc.mode,
-                             delta=tc.delta,
-                             patience=tc.patience,
+                             epochs=mc.epochs,
+                             nthreads=mc.nthreads,
+                             early_stop=mc.early_stop,
+                             mode=mc.mode,
+                             delta=mc.delta,
+                             patience=mc.patience,
                              max_accuracy=max_accuracy,
                              checkpoint_state=checkpoint_state,
-                             save=tc.save
+                             save=mc.save
                              )
 
     # (ix) train model
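
Note (not part of the patch): a minimal usage sketch of the refactored
EarlyStopping class, seeded with the best validation accuracy recovered from a
model checkpoint, i.e. the max_accuracy value returned by
ModelConfig.from_checkpoint() above. The import path follows the new
pysegcnn/core/trainer.py module; the numeric values are purely hypothetical.

    from pysegcnn.core.trainer import EarlyStopping

    # hypothetical best validation accuracy restored from a checkpoint
    max_accuracy = 0.82

    # seed the stopper with the checkpoint accuracy instead of starting at 0
    es = EarlyStopping(mode='max', best=max_accuracy, min_delta=0, patience=3)

    # feed the accuracy observed after each epoch; stop() returns True once
    # three consecutive epochs fail to improve on the best value seen so far
    for accuracy in [0.83, 0.81, 0.80, 0.79]:
        if es.stop(accuracy):
            break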