Skip to content
Snippets Groups Projects
Commit 6e330d11 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Improved initialization of dataset and model

parent 7cce87b3
No related branches found
No related tags found
No related merge requests found
......@@ -22,7 +22,6 @@ from pysegcnn.core.split import (RandomTileSplit, RandomSceneSplit, DateSplit,
VALID_SPLIT_MODES)
class NetworkTrainer(object):
def __init__(self, config):
......@@ -38,12 +37,11 @@ class NetworkTrainer(object):
# initialize the dataset to train the model on
self._init_dataset()
# initialize the model
self._init_model()
# initialize the model state files
self._init_state()
# initialize the model
self._init_model()
def from_pretrained(self):
......@@ -69,7 +67,7 @@ class NetworkTrainer(object):
.format(self.bands))
# instantiate pretrained model architecture
model = self.net(**model_state['params'], **model_state['kwargs'])
model = self.model(**model_state['params'], **model_state['kwargs'])
# load pretrained model weights
model.load(self.pretrained_model, inpath=self.state_path)
......@@ -78,39 +76,41 @@ class NetworkTrainer(object):
# dataset
model.epoch = 0
# adjust the number of classes in the model
model.nclasses = len(self.dataset.labels)
# adjust the classification layer to the number of classes of the
# current dataset
model.classifier = Conv2dSame(in_channels=filters[0],
out_channels=len(self.dataset.labels),
out_channels=model.nclasses,
kernel_size=1)
# adjust the number of classes in the model
model.nclasses = len(self.dataset.labels)
return model
def from_checkpoint(self):
# whether to resume training from an existing model
checkpoint_state = None
max_accuracy = 0
if os.path.exists(self.state) and self.checkpoint:
# load the model state
state = self.model.load(self.state_file, self.optimizer,
self.state_path)
print('Resuming training from {} ...'.format(state))
print('Model epoch: {:d}'.format(self.model.epoch))
if not os.path.exists(self.state):
raise FileNotFoundError('Model checkpoint {} does not exist.'
.format(self.state))
# load the model state
state = self.model.load(self.state_file, self.optimizer,
self.state_path)
print('Resuming training from {} ...'.format(state))
print('Model epoch: {:d}'.format(self.model.epoch))
# load the model loss and accuracy
checkpoint_state = torch.load(self.loss_state)
# load the model loss and accuracy
checkpoint_state = torch.load(self.loss_state)
# get all non-zero elements, i.e. get number of epochs trained
# before the early stop
checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for
k, v in checkpoint_state.items()}
# get all non-zero elements, i.e. get number of epochs trained
# before the early stop
checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for
k, v in checkpoint_state.items()}
# maximum accuracy on the validation set
max_accuracy = checkpoint_state['va'][:, -1].mean().item()
# maximum accuracy on the validation set
max_accuracy = checkpoint_state['va'][:, -1].mean().item()
return checkpoint_state, max_accuracy
......@@ -128,9 +128,6 @@ class NetworkTrainer(object):
print('mode = {}, delta = {}, patience = {} epochs ...'
.format(self.mode, self.delta, self.patience))
# initial accuracy on the validation set
max_accuracy = 0
# create dictionary of the observed losses and accuracies on the
# training and validation dataset
tshape = (len(self.train_dl), self.epochs)
......@@ -141,9 +138,6 @@ class NetworkTrainer(object):
'va': np.zeros(shape=vshape)
}
# whether to resume training from an existing model
checkpoint_state, max_accuracy = self.from_checkpoint()
# send the model to the gpu if available
self.model = self.model.to(self.device)
......@@ -188,7 +182,8 @@ class NetworkTrainer(object):
# print progress
print('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, Loss: {:.2f}, '
'Accuracy: {:.2f}'.format(epoch + 1, self.epochs,
'Accuracy: {:.2f}'.format(epoch + 1,
self.epochs,
batch + 1,
len(self.train_dl),
observed_loss,
......@@ -212,8 +207,8 @@ class NetworkTrainer(object):
epoch_acc = vacc.squeeze().mean()
# whether the model improved with respect to the previous epoch
if es.increased(epoch_acc, max_accuracy, self.delta):
max_accuracy = epoch_acc
if es.increased(epoch_acc, self.max_accuracy, self.delta):
self.max_accuracy = epoch_acc
# save model state if the model improved with
# respect to the previous epoch
_ = self.model.save(self.state_file,
......@@ -224,7 +219,7 @@ class NetworkTrainer(object):
# save losses and accuracy
self._save_loss(training_state,
self.checkpoint,
checkpoint_state)
self.checkpoint_state)
# whether the early stopping criterion is met
if es.stop(epoch_acc):
......@@ -241,7 +236,7 @@ class NetworkTrainer(object):
# save losses and accuracy after each epoch
self._save_loss(training_state,
self.checkpoint,
checkpoint_state)
self.checkpoint_state)
return training_state
......@@ -300,7 +295,7 @@ class NetworkTrainer(object):
# format: networkname_datasetname_t(tilesize)_b(batchsize)_bands.pt
bformat = ''.join([b[0] for b in self.bands]) if self.bands else 'all'
self.state_file = ('{}_{}_t{}_b{}_{}.pt'
.format(self.model.__class__.__name__,
.format(self.model.__name__,
self.dataset.__class__.__name__,
self.tile_size,
self.batch_size,
......@@ -321,34 +316,38 @@ class NetworkTrainer(object):
def _init_dataset(self):
# check whether the dataset is currently supported
self.dataset = None
for dataset in SupportedDatasets:
if self.dataset_name == dataset.name:
self.dataset = dataset.value['class'](
self.dataset_path,
use_bands=self.bands,
tile_size=self.tile_size,
sort=self.sort,
transforms=self.transforms,
pad=self.pad,
cval=self.cval,
gt_pattern=self.gt_pattern)
# the dataset name
self.dataset_name = os.path.basename(self.root_dir)
if self.dataset is None:
# check whether the dataset is currently supported
if self.dataset_name not in SupportedDatasets.__members__:
raise ValueError('{} is not a valid dataset. '
.format(self.dataset_name) +
'Available datasets are: \n' +
'\n'.join(name for name, _ in
SupportedDatasets.__members__.items()))
else:
self.dataset_class = SupportedDatasets.__members__[
self.dataset_name].value
# instantiate the dataset
self.dataset = self.dataset_class(
self.root_dir,
use_bands=self.bands,
tile_size=self.tile_size,
sort=self.sort,
transforms=self.transforms,
pad=self.pad,
cval=self.cval,
gt_pattern=self.gt_pattern
)
# the mode to split
if self.split_mode not in VALID_SPLIT_MODES:
raise ValueError('{} is not supported. Valid modes are {}, see '
'pysegcnn.main.config.py for a description of '
'each mode.'.format(self.split_mode,
VALID_SPLIT_MODES))
'pysegcnn.main.config.py for a description of '
'each mode.'.format(self.split_mode,
VALID_SPLIT_MODES))
if self.split_mode == 'random':
self.subset = RandomTileSplit(self.dataset,
self.ttratio,
......@@ -364,12 +363,12 @@ class NetworkTrainer(object):
self.date,
self.dateformat)
# the training, validation and dataset
# the training, validation and test dataset
self.train_ds, self.valid_ds, self.test_ds = self.subset.split()
# whether to drop training samples with a fraction of pixels equal to
# the constant padding value self.cval >= self.drop
if self.pad:
if self.pad and self.drop:
self._drop(self.train_ds)
# the shape of a single batch
......@@ -400,18 +399,53 @@ class NetworkTrainer(object):
def _init_model(self):
# instantiate the segmentation network
if self.pretrained:
# initial accuracy on the validation set
self.max_accuracy = 0
# set the model checkpoint to None, overwritten when resuming
# training from an existing model checkpoint
self.checkpoint_state = None
# case (1): build a model for the specified dataset
if not self.pretrained and not self.checkpoint:
# instantiate the model
self.model = self.model(in_channels=len(self.dataset.use_bands),
nclasses=len(self.dataset.labels),
filters=self.filters,
skip=self.skip_connection,
**self.kwargs)
# the optimizer used to update the model weights
self.optimizer = self.optimizer(self.model.parameters(), self.lr)
# case (2): using a pretrained model without an existing checkpoint on
# a new dataset, i.e. transfer learning
if self.pretrained and not self.checkpoint:
# load pretrained model
self.model = self.from_pretrained()
else:
self.model = self.net(in_channels=len(self.dataset.use_bands),
nclasses=len(self.dataset.labels),
filters=self.filters,
skip=self.skip_connection,
**self.kwargs)
# the optimizer used to update the model weights
self.optimizer = self.optimizer(self.model.parameters(), self.lr)
# the optimizer used to update the model weights
self.optimizer = self.optimizer(self.model.parameters(), self.lr)
# case (3): using a pretrained model with existing checkpoint on the
# same dataset the pretrained model was trained on
elif self.checkpoint:
# instantiate the model
self.model = self.model(in_channels=len(self.dataset.use_bands),
nclasses=len(self.dataset.labels),
filters=self.filters,
skip=self.skip_connection,
**self.kwargs)
# the optimizer used to update the model weights
self.optimizer = self.optimizer(self.model.parameters(), self.lr)
# whether to resume training from an existing model checkpoint
if self.checkpoint:
(self.checkpoint_state,
self.max_accuracy) = self.from_checkpoint()
# function to drop samples with a fraction of pixels equal to the constant
# padding value self.cval >= self.drop
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment