Commit d5a08ef8 authored by Frisinghelli Daniel

Added support to use a pretrained model on a different dataset for transfer learning tasks

parent bac6bc17
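The diff below wires the new transfer-learning options into NetworkTrainer. A minimal usage sketch follows; the module path, the constructor call and the exact values are assumptions, while the keys mirror attributes referenced in the diff (dataset_name, bands, batch_size, epochs, checkpoint, pretrained, pretrained_model).

from pytorch.trainer import NetworkTrainer  # module path assumed

config = {
    'dataset_name': 'Sparcs',                 # dataset to fine-tune on
    'bands': ['red', 'green', 'blue', 'nir'],
    'batch_size': 64,
    'epochs': 50,
    'checkpoint': True,
    # new in this commit: reuse a model trained on a different dataset
    'pretrained': True,
    'pretrained_model': 'UNet_Cloud95Dataset_b64.pt',  # hypothetical state file
}

trainer = NetworkTrainer(config)   # attributes are set via setattr(self, k, v)
trainer.initialize()               # builds the datasets, dataloaders and model
training_state = trainer.train()   # train() now returns the training_state dict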
@@ -20,7 +20,7 @@ sys.path.append('..')
# local modules
from pytorch.dataset import SparcsDataset, Cloud95Dataset
from pytorch.constants import SparcsLabels, Cloud95Labels
from pytorch.layers import Conv2dSame
class NetworkTrainer(object):
@@ -31,8 +31,6 @@ class NetworkTrainer(object):
for k, v in config.items():
setattr(self, k, v)
def initialize(self):
# check which dataset the model is trained on
if self.dataset_name == 'Sparcs':
# instantiate the SparcsDataset
@@ -81,6 +79,12 @@ class NetworkTrainer(object):
self.dataset, self.tvratio, self.ttratio, self.seed)
print('--------------------------------------------------------------')
# number of batches in the validation set
self.nvbatches = int(len(self.valid_ds) / self.batch_size)
# number of batches in the training set
self.nbatches = int(len(self.train_ds) / self.batch_size)
# the training and validation dataloaders
self.train_dl = DataLoader(self.train_ds,
self.batch_size,
@@ -108,48 +112,49 @@ class NetworkTrainer(object):
self.batch_size,
bformat))
# check whether a pretrained model was used and change state filename
# accordingly
if self.pretrained:
# add the configuration of the pretrained model to the state name
self.state_file = (self.state_file.replace('.pt', '_') +
'pretrained_' + self.pretrained_model)
# path to model state
self.state = os.path.join(self.state_path, self.state_file)
# path to model loss/accuracy
self.loss_state = self.state.replace('.pt', '_loss.pt')
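Sketch of the resulting state-file name when a pretrained model is used; both file names below are hypothetical.

state_file = 'UNet_SparcsDataset_b64.pt'          # assumed current state name
pretrained_model = 'UNet_Cloud95Dataset_b64.pt'   # assumed pretrained state name

state_file = state_file.replace('.pt', '_') + 'pretrained_' + pretrained_model
print(state_file)
# UNet_SparcsDataset_b64_pretrained_UNet_Cloud95Dataset_b64.pt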
def from_pretrained(self):
# name of the dataset the pretrained model was trained on
dataset_name = self.pretrained_model.split('_')[1]
# load the pretrained model
model_state = torch.load(
os.path.join(self.state_path, self.pretrained_model))
# input bands of the pretrained model
bands = self.pretrained_model.split('_')[-1].split('.')[0]
# get the input bands of the pretrained model
bands = model_state['bands']
if dataset_name == SparcsDataset.__name__:
# get the number of convolutional filters
filters = model_state['params']['filters']
# number of input channels
in_channels = len(bands) if bands != 'all' else 10
# check whether the current dataset uses the correct spectral bands
if self.bands != bands:
raise ValueError('The bands of the pretrained network do not '
'match the specified bands: {}'
.format(self.bands))
# instantiate pretrained model architecture
model = self.net(in_channels=in_channels,
nclasses=len(SparcsLabels),
filters=self.filters,
skip=self.skip_connection,
**self.kwargs)
if dataset_name == Cloud95Dataset.__name__:
# number of input channels
in_channels = len(bands) if bands != 'all' else 4
# instantiate pretrained model architecture
model = self.net(in_channels=in_channels,
nclasses=len(Cloud95Labels),
filters=self.filters,
skip=self.skip_connection,
**self.kwargs)
# instantiate pretrained model architecture
model = self.net(**model_state['params'], **model_state['kwargs'])
# load pretrained model weights
model.load(self.pretrained_model, inpath=self.state_path)
# adjust the classification layer to the number of classes of the
# current dataset
model.classifier = Conv2dSame(in_channels=filters[0],
out_channels=len(self.dataset.labels),
kernel_size=1)
return model
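The essence of from_pretrained() is to rebuild the pretrained architecture, load its weights and replace only the classification layer so the output matches the number of classes of the current dataset. A standalone sketch of that head-swap pattern, using a plain nn.Conv2d in place of the repository's Conv2dSame layer and a purely illustrative backbone:

import torch
import torch.nn as nn

class TinySegmentationNet(nn.Module):
    def __init__(self, in_channels, nclasses, filters=32):
        super().__init__()
        # feature extractor whose weights are reused on the new dataset
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, filters, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # 1x1 convolution acting as the per-pixel classification layer
        self.classifier = nn.Conv2d(filters, nclasses, kernel_size=1)

    def forward(self, x):
        return self.classifier(self.features(x))

# "pretrained" network, e.g. trained on a 4-band cloud dataset with 2 classes
pretrained = TinySegmentationNet(in_channels=4, nclasses=2)

# new task: same input bands but 7 classes -> copy the weights, swap the head
model = TinySegmentationNet(in_channels=4, nclasses=2)
model.load_state_dict(pretrained.state_dict())
model.classifier = nn.Conv2d(32, 7, kernel_size=1)

print(model(torch.randn(1, 4, 16, 16)).shape)  # torch.Size([1, 7, 16, 16])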
@@ -191,6 +196,22 @@ class NetworkTrainer(object):
def accuracy_function(self, outputs, labels):
return (outputs == labels).float().mean()
def _save_loss(self, training_state, checkpoint=False,
checkpoint_state=None):
# save losses and accuracy
if checkpoint and checkpoint_state is not None:
# append values from checkpoint to current training
# state
torch.save({
k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
zip(checkpoint_state.items(), training_state.items())
if k1 == k2},
self.loss_state)
else:
torch.save(training_state, self.loss_state)
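A minimal sketch of the merge performed by _save_loss() when resuming from a checkpoint: for matching keys, the per-epoch columns of the checkpointed arrays and the current arrays are concatenated with np.hstack. The shapes below are illustrative (2 batches, 3 epochs before and 2 epochs after the checkpoint).

import numpy as np

checkpoint_state = {'tl': np.ones((2, 3)), 'ta': np.ones((2, 3))}
training_state = {'tl': np.zeros((2, 2)), 'ta': np.zeros((2, 2))}

merged = {k1: np.hstack([v1, v2])
          for (k1, v1), (k2, v2) in
          zip(checkpoint_state.items(), training_state.items())
          if k1 == k2}

print(merged['tl'].shape)  # (2, 5): epochs before and after the checkpoint
# Note: zipping dict.items() assumes both dictionaries share the same key
# order, which holds here because they are built with the same keys.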
def train(self):
# set the number of threads
@@ -206,31 +227,36 @@ class NetworkTrainer(object):
# initial accuracy on the validation set
max_accuracy = 0
# create dictionary of the observed losses and accuracies on the
# training and validation dataset
training_state = {'tl': np.zeros(shape=(self.nbatches, self.epochs)),
'ta': np.zeros(shape=(self.nbatches, self.epochs)),
'vl': np.zeros(shape=(self.nvbatches, self.epochs)),
'va': np.zeros(shape=(self.nvbatches, self.epochs))
}
# whether to resume training from an existing model
checkpoint_state = None
if os.path.exists(self.state) and self.checkpoint:
# load the model state
state = self.model.load(self.state_file, self.optimizer,
self.state_path)
print('Resuming training from {} ...'.format(state))
print('Model epoch: {:d}'.format(self.model.epoch))
# send the model to the gpu if available
self.model = self.model.to(self.device)
# load the model loss and accuracy
checkpoint_state = torch.load(self.loss_state)
# number of batches in the validation set
nvbatches = int(len(self.valid_ds) / self.batch_size)
# number of batches in the training set
nbatches = int(len(self.train_ds) / self.batch_size)
# get all non-zero elements, i.e. get number of epochs trained
# before the early stop
checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for
k, v in checkpoint_state.items()}
# create arrays of the observed losses and accuracies on the
# training set
losses = np.zeros(shape=(nbatches, self.epochs))
accuracies = np.zeros(shape=(nbatches, self.epochs))
# maximum accuracy on the validation set
max_accuracy = checkpoint_state['va'][:, -1].mean().item()
# create arrays of the observed losses and accuracies on the
# validation set
vlosses = np.zeros(shape=(nvbatches, self.epochs))
vaccuracies = np.zeros(shape=(nvbatches, self.epochs))
# send the model to the gpu if available
self.model = self.model.to(self.device)
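Sketch of the trimming applied to the checkpointed arrays above: they were allocated for self.epochs columns, so the columns of epochs never trained (for example after an early stop) are still all zero and are dropped via np.nonzero(). This relies on no observed loss or accuracy being exactly zero; the shapes are illustrative.

import numpy as np

va = np.zeros((3, 10))                                 # 3 batches, 10 allocated epochs
va[:, :4] = np.random.uniform(0.5, 0.9, size=(3, 4))   # only 4 epochs were trained

trimmed = va[np.nonzero(va)].reshape(va.shape[0], -1)
print(trimmed.shape)                  # (3, 4)
print(trimmed[:, -1].mean().item())   # mean accuracy of the last trained epoch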
# initialize the training: iterate over the entire training data set
for epoch in range(self.epochs):
@@ -254,7 +280,8 @@ class NetworkTrainer(object):
# compute loss
loss = self.loss_function(outputs, labels.long())
losses[batch, epoch] = loss.detach().numpy().item()
observed_loss = loss.detach().numpy().item()
training_state['tl'][batch, epoch] = observed_loss
# compute the gradients of the loss function w.r.t.
# the network weights
@@ -267,14 +294,17 @@ class NetworkTrainer(object):
ypred = F.softmax(outputs, dim=1).argmax(dim=1)
# calculate accuracy on current batch
acc = self.accuracy_function(ypred, labels)
accuracies[batch, epoch] = acc
observed_accuracy = self.accuracy_function(ypred, labels)
training_state['ta'][batch, epoch] = observed_accuracy
# print progress
print('Epoch: {:d}/{:d}, Batch: {:d}/{:d}, Loss: {:.2f}, '
'Accuracy: {:.2f}'.format(epoch, self.epochs, batch,
nbatches, losses[batch, epoch],
acc))
'Accuracy: {:.2f}'.format(epoch,
self.epochs,
batch,
self.nbatches,
observed_loss,
observed_accuracy))
# update the number of epochs trained
self.model.epoch += 1
@@ -287,8 +317,8 @@ class NetworkTrainer(object):
_, vacc, vloss = self.predict()
# append observed accuracy and loss to arrays
vaccuracies[:, epoch] = vacc.squeeze()
vlosses[:, epoch] = vloss.squeeze()
training_state['va'][:, epoch] = vacc.squeeze()
training_state['vl'][:, epoch] = vloss.squeeze()
# metric to assess model performance on the validation set
epoch_acc = vacc.squeeze().mean()
@@ -298,37 +328,34 @@ class NetworkTrainer(object):
max_accuracy = epoch_acc
# save model state if the model improved with
# respect to the previous epoch
_ = self.model.save(self.optimizer, self.state_file,
_ = self.model.save(self.state_file,
self.optimizer,
self.bands,
self.state_path)
# save losses and accuracy
self._save_loss(training_state,
self.checkpoint,
checkpoint_state)
# whether the early stopping criterion is met
if es.stop(epoch_acc):
# save losses and accuracy before exiting training
torch.save({'epoch': epoch,
'training_loss': losses,
'training_accuracy': accuracies,
'validation_loss': vlosses,
'validation_accuracy': vaccuracies},
self.loss_state)
break
else:
# if no early stopping is required, the model state is saved
# after each epoch
_ = self.model.save(self.optimizer, self.state_file,
_ = self.model.save(self.state_file,
self.optimizer,
self.bands,
self.state_path)
# save losses and accuracy after each epoch to file
torch.save({'epoch': epoch,
'training_loss': losses,
'training_accuracy': accuracies,
'validation_loss': vlosses,
'validation_accuracy': vaccuracies},
self.loss_state)
# save losses and accuracy after each epoch
self._save_loss(training_state,
self.checkpoint,
checkpoint_state)
return losses, accuracies, vlosses, vaccuracies
return training_state
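The dictionary returned by train() stores one row per batch and one column per epoch for each of the four keys, so per-epoch learning curves are simply column means. A small sketch with illustrative shapes, where 'history' stands for the value returned by trainer.train():

import numpy as np

history = {'tl': np.random.rand(8, 5), 'ta': np.random.rand(8, 5),
           'vl': np.random.rand(4, 5), 'va': np.random.rand(4, 5)}

mean_train_loss = history['tl'].mean(axis=0)      # one value per epoch
mean_valid_accuracy = history['va'].mean(axis=0)
print(mean_train_loss.shape, mean_valid_accuracy.shape)  # (5,) (5,)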
def predict(self, pretrained=False, confusion=False):
@@ -347,12 +374,9 @@ class NetworkTrainer(object):
# initialize confusion matrix
cm = torch.zeros(self.model.nclasses, self.model.nclasses)
# number of batches in the validation set
nbatches = int(len(self.valid_ds) / self.batch_size)
# create arrays of the observed losses and accuracies
accuracies = np.zeros(shape=(nbatches, 1))
losses = np.zeros(shape=(nbatches, 1))
accuracies = np.zeros(shape=(self.nvbatches, 1))
losses = np.zeros(shape=(self.nvbatches, 1))
# iterate over the validation/test set
print('Calculating accuracy on validation set ...')
@@ -379,7 +403,8 @@ class NetworkTrainer(object):
# print progress
print('Batch: {:d}/{:d}, Accuracy: {:.2f}'.format(batch,
nbatches, acc))
self.nvbatches,
acc))
# update confusion matrix
if confusion:
......