From 6e330d114a7ca8f2575dc900952a161e739af211 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Tue, 11 Aug 2020 16:41:15 +0200
Subject: [PATCH] Improved initialization of dataset and model

---
 pysegcnn/core/trainer.py | 166 +++++++++++++++++++++++----------------
 1 file changed, 100 insertions(+), 66 deletions(-)

diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 96bb55e..09239c9 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -22,7 +22,6 @@ from pysegcnn.core.split import (RandomTileSplit, RandomSceneSplit, DateSplit,
                                  VALID_SPLIT_MODES)
 
 
-
 class NetworkTrainer(object):
 
     def __init__(self, config):
@@ -38,12 +37,11 @@ class NetworkTrainer(object):
         # initialize the dataset to train the model on
         self._init_dataset()
 
-        # initialize the model
-        self._init_model()
-
         # initialize the model state files
         self._init_state()
 
+        # initialize the model
+        self._init_model()
 
     def from_pretrained(self):
 
@@ -69,7 +67,7 @@ class NetworkTrainer(object):
                              .format(self.bands))
 
         # instanciate pretrained model architecture
-        model = self.net(**model_state['params'], **model_state['kwargs'])
+        model = self.model(**model_state['params'], **model_state['kwargs'])
 
         # load pretrained model weights
         model.load(self.pretrained_model, inpath=self.state_path)
@@ -78,39 +76,41 @@ class NetworkTrainer(object):
         # dataset
         model.epoch = 0
 
+        # adjust the number of classes in the model
+        model.nclasses = len(self.dataset.labels)
+
         # adjust the classification layer to the number of classes of the
         # current dataset
         model.classifier = Conv2dSame(in_channels=filters[0],
-                                      out_channels=len(self.dataset.labels),
+                                      out_channels=model.nclasses,
                                       kernel_size=1)
 
-        # adjust the number of classes in the model
-        model.nclasses = len(self.dataset.labels)
 
         return model
 
     def from_checkpoint(self):
 
         # whether to resume training from an existing model
-        checkpoint_state = None
-        max_accuracy = 0
-        if os.path.exists(self.state) and self.checkpoint:
-            # load the model state
-            state = self.model.load(self.state_file, self.optimizer,
-                                    self.state_path)
-            print('Resuming training from {} ...'.format(state))
-            print('Model epoch: {:d}'.format(self.model.epoch))
+        if not os.path.exists(self.state):
+            raise FileNotFoundError('Model checkpoint {} does not exist.'
+                                    .format(self.state))
+
+        # load the model state
+        state = self.model.load(self.state_file, self.optimizer,
+                                self.state_path)
+        print('Resuming training from {} ...'.format(state))
+        print('Model epoch: {:d}'.format(self.model.epoch))
 
-            # load the model loss and accuracy
-            checkpoint_state = torch.load(self.loss_state)
+        # load the model loss and accuracy
+        checkpoint_state = torch.load(self.loss_state)
 
-            # get all non-zero elements, i.e. get number of epochs trained
-            # before the early stop
-            checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for
-                                k, v in checkpoint_state.items()}
+        # get all non-zero elements, i.e. get number of epochs trained
+        # before the early stop
+        checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for
+                            k, v in checkpoint_state.items()}
 
-            # maximum accuracy on the validation set
-            max_accuracy = checkpoint_state['va'][:, -1].mean().item()
+        # maximum accuracy on the validation set
+        max_accuracy = checkpoint_state['va'][:, -1].mean().item()
 
         return checkpoint_state, max_accuracy
 
@@ -128,9 +128,6 @@ class NetworkTrainer(object):
         print('mode = {}, delta = {}, patience = {} epochs ...'
              .format(self.mode, self.delta, self.patience))
 
-        # initial accuracy on the validation set
-        max_accuracy = 0
-
         # create dictionary of the observed losses and accuracies on the
         # training and validation dataset
         tshape = (len(self.train_dl), self.epochs)
@@ -141,9 +138,6 @@ class NetworkTrainer(object):
             'va': np.zeros(shape=vshape)
             }
 
-        # whether to resume training from an existing model
-        checkpoint_state, max_accuracy = self.from_checkpoint()
-
         # send the model to the gpu if available
         self.model = self.model.to(self.device)
 
@@ -188,7 +182,8 @@ class NetworkTrainer(object):
 
             # print progress
             print('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, Loss: {:.2f}, '
-                  'Accuracy: {:.2f}'.format(epoch + 1, self.epochs,
+                  'Accuracy: {:.2f}'.format(epoch + 1,
+                                            self.epochs,
                                             batch + 1,
                                             len(self.train_dl),
                                             observed_loss,
@@ -212,8 +207,8 @@ class NetworkTrainer(object):
                 epoch_acc = vacc.squeeze().mean()
 
                 # whether the model improved with respect to the previous epoch
-                if es.increased(epoch_acc, max_accuracy, self.delta):
-                    max_accuracy = epoch_acc
+                if es.increased(epoch_acc, self.max_accuracy, self.delta):
+                    self.max_accuracy = epoch_acc
                     # save model state if the model improved with
                     # respect to the previous epoch
                     _ = self.model.save(self.state_file,
@@ -224,7 +219,7 @@ class NetworkTrainer(object):
                     # save losses and accuracy
                     self._save_loss(training_state,
                                     self.checkpoint,
-                                    checkpoint_state)
+                                    self.checkpoint_state)
 
                 # whether the early stopping criterion is met
                 if es.stop(epoch_acc):
@@ -241,7 +236,7 @@ class NetworkTrainer(object):
             # save losses and accuracy after each epoch
             self._save_loss(training_state,
                             self.checkpoint,
-                            checkpoint_state)
+                            self.checkpoint_state)
 
         return training_state
 
@@ -300,7 +295,7 @@ class NetworkTrainer(object):
         # format: networkname_datasetname_t(tilesize)_b(batchsize)_bands.pt
         bformat = ''.join([b[0] for b in self.bands]) if self.bands else 'all'
         self.state_file = ('{}_{}_t{}_b{}_{}.pt'
-                           .format(self.model.__class__.__name__,
+                           .format(self.model.__name__,
                                    self.dataset.__class__.__name__,
                                    self.tile_size,
                                    self.batch_size,
@@ -321,34 +316,38 @@ class NetworkTrainer(object):
     def _init_dataset(self):
 
-        # check whether the dataset is currently supported
-        self.dataset = None
-        for dataset in SupportedDatasets:
-            if self.dataset_name == dataset.name:
-                self.dataset = dataset.value['class'](
-                    self.dataset_path,
-                    use_bands=self.bands,
-                    tile_size=self.tile_size,
-                    sort=self.sort,
-                    transforms=self.transforms,
-                    pad=self.pad,
-                    cval=self.cval,
-                    gt_pattern=self.gt_pattern)
+        # the dataset name
+        self.dataset_name = os.path.basename(self.root_dir)
 
-        if self.dataset is None:
+        # check whether the dataset is currently supported
+        if self.dataset_name not in SupportedDatasets.__members__:
             raise ValueError('{} is not a valid dataset. '
                              .format(self.dataset_name) +
                              'Available datasets are: \n' +
                              '\n'.join(name for name, _ in
                                        SupportedDatasets.__members__.items()))
+        else:
+            self.dataset_class = SupportedDatasets.__members__[
+                self.dataset_name].value
 
+        # instantiate the dataset
+        self.dataset = self.dataset_class(
+            self.root_dir,
+            use_bands=self.bands,
+            tile_size=self.tile_size,
+            sort=self.sort,
+            transforms=self.transforms,
+            pad=self.pad,
+            cval=self.cval,
+            gt_pattern=self.gt_pattern
+            )
 
         # the mode to split
         if self.split_mode not in VALID_SPLIT_MODES:
             raise ValueError('{} is not supported. Valid modes are {}, see '
-                              'pysegcnn.main.config.py for a description of '
-                              'each mode.'.format(self.split_mode,
-                                                  VALID_SPLIT_MODES))
+                             'pysegcnn.main.config.py for a description of '
+                             'each mode.'.format(self.split_mode,
+                                                 VALID_SPLIT_MODES))
 
         if self.split_mode == 'random':
             self.subset = RandomTileSplit(self.dataset,
                                           self.ttratio,
@@ -364,12 +363,12 @@ class NetworkTrainer(object):
                                        self.date,
                                        self.dateformat)
 
-        # the training, validation and dataset
+        # the training, validation and test dataset
        self.train_ds, self.valid_ds, self.test_ds = self.subset.split()
 
         # whether to drop training samples with a fraction of pixels equal to
         # the constant padding value self.cval >= self.drop
-        if self.pad:
+        if self.pad and self.drop:
             self._drop(self.train_ds)
 
         # the shape of a single batch
@@ -400,18 +399,53 @@ class NetworkTrainer(object):
 
     def _init_model(self):
 
-        # instanciate the segmentation network
-        if self.pretrained:
+        # initial accuracy on the validation set
+        self.max_accuracy = 0
+
+        # set the model checkpoint to None, overwritten when resuming
+        # training from an existing model checkpoint
+        self.checkpoint_state = None
+
+        # case (1): build a model for the specified dataset
+        if not self.pretrained and not self.checkpoint:
+
+            # instantiate the model
+            self.model = self.model(in_channels=len(self.dataset.use_bands),
+                                    nclasses=len(self.dataset.labels),
+                                    filters=self.filters,
+                                    skip=self.skip_connection,
+                                    **self.kwargs)
+
+            # the optimizer used to update the model weights
+            self.optimizer = self.optimizer(self.model.parameters(), self.lr)
+
+        # case (2): using a pretrained model without an existing checkpoint
+        # on a new dataset, i.e. transfer learning
+        if self.pretrained and not self.checkpoint:
+            # load pretrained model
             self.model = self.from_pretrained()
-        else:
-            self.model = self.net(in_channels=len(self.dataset.use_bands),
-                                  nclasses=len(self.dataset.labels),
-                                  filters=self.filters,
-                                  skip=self.skip_connection,
-                                  **self.kwargs)
-
-        # the optimizer used to update the model weights
-        self.optimizer = self.optimizer(self.model.parameters(), self.lr)
+
+            # the optimizer used to update the model weights
+            self.optimizer = self.optimizer(self.model.parameters(), self.lr)
+
+        # case (3): using a pretrained model with existing checkpoint on the
+        # same dataset the pretrained model was trained on
+        elif self.checkpoint:
+
+            # instantiate the model
+            self.model = self.model(in_channels=len(self.dataset.use_bands),
+                                    nclasses=len(self.dataset.labels),
+                                    filters=self.filters,
+                                    skip=self.skip_connection,
+                                    **self.kwargs)
+
+            # the optimizer used to update the model weights
+            self.optimizer = self.optimizer(self.model.parameters(), self.lr)
+
+        # whether to resume training from an existing model checkpoint
+        if self.checkpoint:
+            (self.checkpoint_state,
+             self.max_accuracy) = self.from_checkpoint()
 
     # function to drop samples with a fraction of pixels equal to the constant
     # padding value self.cval >= self.drop
-- 
GitLab
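
Reviewer note: the snippet below is a minimal, self-contained sketch of the
pretrained/checkpoint decision table this patch introduces in _init_model().
It is illustrative only; init_path() is a hypothetical helper written for
this note and is not part of pysegcnn.

# Illustrative sketch, not part of the patch: mirrors the three
# initialization cases in _init_model() so the flag logic can be
# checked in isolation.
def init_path(pretrained: bool, checkpoint: bool) -> str:
    if not pretrained and not checkpoint:
        # case (1): build a new model for the specified dataset
        return 'new'
    if pretrained and not checkpoint:
        # case (2): pretrained weights on a new dataset (transfer learning)
        return 'pretrained'
    # case (3): resume from an existing checkpoint; _init_model() then
    # restores checkpoint_state and max_accuracy via from_checkpoint()
    return 'checkpoint'

# the four possible flag combinations and the path each one takes
assert init_path(False, False) == 'new'
assert init_path(True, False) == 'pretrained'
assert init_path(True, True) == 'checkpoint'
assert init_path(False, True) == 'checkpoint'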