diff --git a/pytorch/trainer.py b/pytorch/trainer.py index e6ee801df9cd6ff8866305fbb3021e1eecd7e7f2..dbcfb805442d2ad20cf40e10aeb2fc7db3f84262 100755 --- a/pytorch/trainer.py +++ b/pytorch/trainer.py @@ -20,7 +20,7 @@ sys.path.append('..') # local modules from pytorch.dataset import SparcsDataset, Cloud95Dataset -from pytorch.constants import SparcsLabels, Cloud95Labels +from pytorch.layers import Conv2dSame class NetworkTrainer(object): @@ -31,8 +31,6 @@ class NetworkTrainer(object): for k, v in config.items(): setattr(self, k, v) - def initialize(self): - # check which dataset the model is trained on if self.dataset_name == 'Sparcs': # instanciate the SparcsDataset @@ -81,6 +79,12 @@ class NetworkTrainer(object): self.dataset, self.tvratio, self.ttratio, self.seed) print('--------------------------------------------------------------') + # number of batches in the validation set + self.nvbatches = int(len(self.valid_ds) / self.batch_size) + + # number of batches in the training set + self.nbatches = int(len(self.train_ds) / self.batch_size) + # the training and validation dataloaders self.train_dl = DataLoader(self.train_ds, self.batch_size, @@ -108,48 +112,49 @@ class NetworkTrainer(object): self.batch_size, bformat)) + # check whether a pretrained model was used and change state filename + # accordingly + if self.pretrained: + # add the configuration of the pretrained model to the state name + self.state_file = (self.state_file.replace('.pt', '_') + + 'pretrained_' + self.pretrained_model) + # path to model state self.state = os.path.join(self.state_path, self.state_file) # path to model loss/accuracy self.loss_state = self.state.replace('.pt', '_loss.pt') - def from_pretrained(self): - # name of the dataset the pretrained model was trained on - dataset_name = self.pretrained_model.split('_')[1] + # load the pretrained model + model_state = torch.load( + os.path.join(self.state_path, self.pretrained_model)) - # input bands of the pretrained model - bands = self.pretrained_model.split('_')[-1].split('.')[0] + # get the input bands of the pretrained model + bands = model_state['bands'] - if dataset_name == SparcsDataset.__name__: + # get the number of convolutional filters + filters = model_state['params']['filters'] - # number of input channels - in_channels = len(bands) if bands != 'all' else 10 + # check whether the current dataset uses the correct spectral bands + if self.bands != bands: + raise ValueError('The bands of the pretrained network do not ' + 'match the specified bands: {}' + .format(self.bands)) - # instanciate pretrained model architecture - model = self.net(in_channels=in_channels, - nclasses=len(SparcsLabels), - filters=self.filters, - skip=self.skip_connection, - **self.kwargs) - - if dataset_name == Cloud95Dataset.__name__: - - # number of input channels - in_channels = len(bands) if bands != 'all' else 4 - - # instanciate pretrained model architecture - model = self.net(in_channels=in_channels, - nclasses=len(Cloud95Labels), - filters=self.filters, - skip=self.skip_connection, - **self.kwargs) + # instanciate pretrained model architecture + model = self.net(**model_state['params'], **model_state['kwargs']) # load pretrained model weights model.load(self.pretrained_model, inpath=self.state_path) + # adjust the classification layer to the number of classes of the + # current dataset + model.classifier = Conv2dSame(in_channels=filters[0], + out_channels=len(self.dataset.labels), + kernel_size=1) + return model @@ -191,6 +196,22 @@ class NetworkTrainer(object): def accuracy_function(self, outputs, labels): return (outputs == labels).float().mean() + def _save_loss(self, training_state, checkpoint=False, + checkpoint_state=None): + + # save losses and accuracy + if checkpoint and checkpoint_state is not None: + + # append values from checkpoint to current training + # state + torch.save({ + k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in + zip(checkpoint_state.items(), training_state.items()) + if k1 == k2}, + self.loss_state) + else: + torch.save(training_state, self.loss_state) + def train(self): # set the number of threads @@ -206,31 +227,36 @@ class NetworkTrainer(object): # initial accuracy on the validation set max_accuracy = 0 + # create dictionary of the observed losses and accuracies on the + # training and validation dataset + training_state = {'tl': np.zeros(shape=(self.nbatches, self.epochs)), + 'ta': np.zeros(shape=(self.nbatches, self.epochs)), + 'vl': np.zeros(shape=(self.nvbatches, self.epochs)), + 'va': np.zeros(shape=(self.nvbatches, self.epochs)) + } + # whether to resume training from an existing model + checkpoint_state = None if os.path.exists(self.state) and self.checkpoint: + # load the model state state = self.model.load(self.state_file, self.optimizer, self.state_path) print('Resuming training from {} ...'.format(state)) print('Model epoch: {:d}'.format(self.model.epoch)) - # send the model to the gpu if available - self.model = self.model.to(self.device) + # load the model loss and accuracy + checkpoint_state = torch.load(self.loss_state) - # number of batches in the validation set - nvbatches = int(len(self.valid_ds) / self.batch_size) - - # number of batches in the training set - nbatches = int(len(self.train_ds) / self.batch_size) + # get all non-zero elements, i.e. get number of epochs trained + # before the early stop + checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for + k, v in checkpoint_state.items()} - # create arrays of the observed losses and accuracies on the - # training set - losses = np.zeros(shape=(nbatches, self.epochs)) - accuracies = np.zeros(shape=(nbatches, self.epochs)) + # maximum accuracy on the validation set + max_accuracy = checkpoint_state['va'][:, -1].mean().item() - # create arrays of the observed losses and accuracies on the - # validation set - vlosses = np.zeros(shape=(nvbatches, self.epochs)) - vaccuracies = np.zeros(shape=(nvbatches, self.epochs)) + # send the model to the gpu if available + self.model = self.model.to(self.device) # initialize the training: iterate over the entire training data set for epoch in range(self.epochs): @@ -254,7 +280,8 @@ class NetworkTrainer(object): # compute loss loss = self.loss_function(outputs, labels.long()) - losses[batch, epoch] = loss.detach().numpy().item() + observed_loss = loss.detach().numpy().item() + training_state['tl'][batch, epoch] = observed_loss # compute the gradients of the loss function w.r.t. # the network weights @@ -267,14 +294,17 @@ class NetworkTrainer(object): ypred = F.softmax(outputs, dim=1).argmax(dim=1) # calculate accuracy on current batch - acc = self.accuracy_function(ypred, labels) - accuracies[batch, epoch] = acc + observed_accuracy = self.accuracy_function(ypred, labels) + training_state['ta'][batch, epoch] = observed_accuracy # print progress print('Epoch: {:d}/{:d}, Batch: {:d}/{:d}, Loss: {:.2f}, ' - 'Accuracy: {:.2f}'.format(epoch, self.epochs, batch, - nbatches, losses[batch, epoch], - acc)) + 'Accuracy: {:.2f}'.format(epoch, + self.epochs, + batch, + self.nbatches, + observed_loss, + observed_accuracy)) # update the number of epochs trained self.model.epoch += 1 @@ -287,8 +317,8 @@ class NetworkTrainer(object): _, vacc, vloss = self.predict() # append observed accuracy and loss to arrays - vaccuracies[:, epoch] = vacc.squeeze() - vlosses[:, epoch] = vloss.squeeze() + training_state['va'][:, epoch] = vacc.squeeze() + training_state['vl'][:, epoch] = vloss.squeeze() # metric to assess model performance on the validation set epoch_acc = vacc.squeeze().mean() @@ -298,37 +328,34 @@ class NetworkTrainer(object): max_accuracy = epoch_acc # save model state if the model improved with # respect to the previous epoch - _ = self.model.save(self.optimizer, self.state_file, + _ = self.model.save(self.state_file, + self.optimizer, + self.bands, self.state_path) + # save losses and accuracy + self._save_loss(training_state, + self.checkpoint, + checkpoint_state) + # whether the early stopping criterion is met if es.stop(epoch_acc): - - # save losses and accuracy before exiting training - torch.save({'epoch': epoch, - 'training_loss': losses, - 'training_accuracy': accuracies, - 'validation_loss': vlosses, - 'validation_accuracy': vaccuracies}, - self.loss_state) - break else: # if no early stopping is required, the model state is saved # after each epoch - _ = self.model.save(self.optimizer, self.state_file, + _ = self.model.save(self.state_file, + self.optimizer, + self.bands, self.state_path) - # save losses and accuracy after each epoch to file - torch.save({'epoch': epoch, - 'training_loss': losses, - 'training_accuracy': accuracies, - 'validation_loss': vlosses, - 'validation_accuracy': vaccuracies}, - self.loss_state) + # save losses and accuracy after each epoch + self._save_loss(training_state, + self.checkpoint, + checkpoint_state) - return losses, accuracies, vlosses, vaccuracies + return training_state def predict(self, pretrained=False, confusion=False): @@ -347,12 +374,9 @@ class NetworkTrainer(object): # initialize confusion matrix cm = torch.zeros(self.model.nclasses, self.model.nclasses) - # number of batches in the validation set - nbatches = int(len(self.valid_ds) / self.batch_size) - # create arrays of the observed losses and accuracies - accuracies = np.zeros(shape=(nbatches, 1)) - losses = np.zeros(shape=(nbatches, 1)) + accuracies = np.zeros(shape=(self.nvbatches, 1)) + losses = np.zeros(shape=(self.nvbatches, 1)) # iterate over the validation/test set print('Calculating accuracy on validation set ...') @@ -379,7 +403,8 @@ class NetworkTrainer(object): # print progress print('Batch: {:d}/{:d}, Accuracy: {:.2f}'.format(batch, - nbatches, acc)) + self.nvbatches, + acc)) # update confusion matrix if confusion: