Skip to content
Snippets Groups Projects
Commit 1e145d4d authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Divided init method into smaller submethods; included support for data augmentations

parent 5e28a1c8
No related branches found
No related tags found
No related merge requests found
......@@ -18,6 +18,7 @@ from torch.utils.data import random_split, DataLoader
# local modules
from pytorch.dataset import SparcsDataset, Cloud95Dataset
from pytorch.layers import Conv2dSame
from pytorch.constants import SupportedDatasets
class NetworkTrainer(object):
......@@ -28,105 +29,26 @@ class NetworkTrainer(object):
for k, v in config.items():
setattr(self, k, v)
# check which dataset the model is trained on
if self.dataset_name == 'Sparcs':
# instanciate the SparcsDataset
self.dataset = SparcsDataset(self.dataset_path,
use_bands=self.bands,
tile_size=self.tile_size)
elif self.dataset_name == 'Cloud95':
# instanciate the Cloud95Dataset
self.dataset = Cloud95Dataset(self.dataset_path,
use_bands=self.bands,
tile_size=self.tile_size,
exclude=self.patches)
else:
raise ValueError('{} is not a valid dataset. Available datasets '
'are "Sparcs" and "Cloud95".'
.format(self.dataset_name))
# print the bands used for the segmentation
print('------------------------ Input bands -------------------------')
print(*['Band {}: {}'.format(i, b) for i, b in
enumerate(self.dataset.use_bands)], sep='\n')
print('--------------------------------------------------------------')
# print the classes of interest
print('-------------------------- Classes ---------------------------')
print(*['Class {}: {}'.format(k, v['label']) for k, v in
self.dataset.labels.items()], sep='\n')
print('--------------------------------------------------------------')
# instanciate the segmentation network
print('------------------- Network architecture ---------------------')
if self.pretrained:
self.model = self.from_pretrained()
else:
self.model = self.net(in_channels=len(self.dataset.use_bands),
nclasses=len(self.dataset.labels),
filters=self.filters,
skip=self.skip_connection,
**self.kwargs)
print(self.model)
print('--------------------------------------------------------------')
# the training and validation dataset
print('------------------------ Dataset split -----------------------')
self.train_ds, self.valid_ds, self.test_ds = self.train_val_test_split(
self.dataset, self.tvratio, self.ttratio, self.seed)
print('--------------------------------------------------------------')
# number of batches in the validation set
self.nvbatches = int(len(self.valid_ds) / self.batch_size)
# number of batches in the training set
self.nbatches = int(len(self.train_ds) / self.batch_size)
# the training and validation dataloaders
self.train_dl = DataLoader(self.train_ds,
self.batch_size,
shuffle=True,
drop_last=True)
self.valid_dl = DataLoader(self.valid_ds,
self.batch_size,
shuffle=True,
drop_last=True)
# the optimizer used to update the model weights
self.optimizer = self.optimizer(self.model.parameters(), self.lr)
# whether to use the gpu
self.device = torch.device("cuda:0" if torch.cuda.is_available() else
"cpu")
# file to save model state to
# format: networkname_datasetname_t(tilesize)_b(batchsize)_bands.pt
bformat = ''.join([b[0] for b in self.bands]) if self.bands else 'all'
self.state_file = ('{}_{}_t{}_b{}_{}.pt'
.format(self.model.__class__.__name__,
self.dataset.__class__.__name__,
self.tile_size,
self.batch_size,
bformat))
# check whether a pretrained model was used and change state filename
# accordingly
if self.pretrained:
# add the configuration of the pretrained model to the state name
self.state_file = (self.state_file.replace('.pt', '_') +
'pretrained_' + self.pretrained_model)
# path to model state
self.state = os.path.join(self.state_path, self.state_file)
# initialize the dataset to train the model on
self._init_dataset()
# path to model loss/accuracy
self.loss_state = self.state.replace('.pt', '_loss.pt')
# initialize the model
self._init_model()
def from_pretrained(self):
# load the pretrained model
model_state = torch.load(
os.path.join(self.state_path, self.pretrained_model))
model_state = os.path.join(self.state_path, self.pretrained_model)
if not os.path.exists(model_state):
raise FileNotFoundError('Pretrained model {} does not exist.'
.format(model_state))
# load the model state
model_state = torch.load(model_state)
# get the input bands of the pretrained model
bands = model_state['bands']
......@@ -152,8 +74,10 @@ class NetworkTrainer(object):
out_channels=len(self.dataset.labels),
kernel_size=1)
return model
# adjust the number of classes in the model
model.nclasses = len(self.dataset.labels)
return model
def ds_len(self, ds, ratio):
    """Return the number of samples covering the fraction ``ratio`` of ``ds``.

    The product ``len(ds) * ratio`` is rounded with ``np.round`` and then
    converted to a builtin ``int``.

    :param ds: Any object supporting ``len()``.
    :param ratio: Fraction of ``ds`` to compute the size of.
    :return: The rounded number of samples as ``int``.
    """
    n_samples = len(ds)
    return int(np.round(n_samples * ratio))
......@@ -193,22 +117,6 @@ class NetworkTrainer(object):
def accuracy_function(self, outputs, labels):
    """Return the fraction of elements where ``outputs`` equals ``labels``.

    The element-wise comparison is cast to float via ``.float()`` and
    averaged via ``.mean()``.

    :param outputs: Predicted values.
    :param labels: Reference values compared element-wise to ``outputs``.
    :return: Mean of the element-wise equality mask.
    """
    matches = outputs == labels
    return matches.float().mean()
def _save_loss(self, training_state, checkpoint=False,
               checkpoint_state=None):
    """Save the losses and accuracies of a training run.

    The state is written with ``torch.save`` to ``self.loss_state``.

    :param training_state: Mapping of metric name to the values recorded
        during the current training run.
    :param checkpoint: Whether the current training run was resumed from
        a checkpoint.
    :param checkpoint_state: Mapping of metric name to the values stored
        in the checkpoint. Ignored unless ``checkpoint`` is true.
    """
    if checkpoint and checkpoint_state is not None:
        # append the values of the current training state to the values
        # from the checkpoint; entries are matched by key rather than by
        # position, so a differing key order does not silently drop
        # metrics. Keys missing in either mapping are not saved.
        training_state = {
            key: np.hstack([values, training_state[key]])
            for key, values in checkpoint_state.items()
            if key in training_state}

    # save losses and accuracy
    torch.save(training_state, self.loss_state)
def train(self):
# set the number of threads
......@@ -418,6 +326,117 @@ class NetworkTrainer(object):
return cm, accuracies, losses
def _init_dataset(self):
    """Instantiate the dataset, split it and build the dataloaders.

    Looks up ``self.dataset_name`` among the members of
    ``SupportedDatasets`` and instantiates the class stored under the
    ``'class'`` key of the matching member's value. The dataset is then
    split via ``self.train_val_test_split`` and the training and
    validation dataloaders are created.

    Sets ``self.dataset``, ``self.train_ds``, ``self.valid_ds``,
    ``self.test_ds``, ``self.nbatches``, ``self.nvbatches``,
    ``self.train_dl`` and ``self.valid_dl``.

    :raises ValueError: If ``self.dataset_name`` is not the name of a
        member of ``SupportedDatasets``.
    """
    # check whether the dataset is currently supported
    self.dataset = None
    for dataset in SupportedDatasets:
        if self.dataset_name == dataset.name:
            self.dataset = dataset.value['class'](self.dataset_path,
                                                  self.bands,
                                                  self.tile_size,
                                                  self.sort,
                                                  self.transforms)
            # member names are unique, no need to search any further
            break

    if self.dataset is None:
        # format the message before printing: print() returns None, so
        # calling .format() on its result raises an AttributeError
        print('{} is not a valid dataset.'.format(self.dataset_name))
        print('Available datasets are:')
        for name in SupportedDatasets.__members__:
            print(name)
        raise ValueError('Dataset not supported.')

    # print the bands used for the segmentation
    print('------------------------ Input bands -------------------------')
    print(*['Band {}: {}'.format(i, b) for i, b in
            enumerate(self.dataset.use_bands)], sep='\n')
    print('--------------------------------------------------------------')

    # print the classes of interest
    print('-------------------------- Classes ---------------------------')
    print(*['Class {}: {}'.format(k, v['label']) for k, v in
            self.dataset.labels.items()], sep='\n')
    print('--------------------------------------------------------------')

    # the training, validation and test dataset
    print('------------------------ Dataset split -----------------------')
    self.train_ds, self.valid_ds, self.test_ds = self.train_val_test_split(
        self.dataset, self.tvratio, self.ttratio, self.seed)

    # number of batches in the validation set
    self.nvbatches = int(len(self.valid_ds) / self.batch_size)

    # number of batches in the training set
    self.nbatches = int(len(self.train_ds) / self.batch_size)

    print('Number of training batches: {}'.format(self.nbatches))
    print('Number of validation batches: {}'.format(self.nvbatches))
    print('--------------------------------------------------------------')

    # the training and validation dataloaders: both are shuffled and drop
    # the last incomplete batch
    self.train_dl = DataLoader(self.train_ds,
                               self.batch_size,
                               shuffle=True,
                               drop_last=True)
    self.valid_dl = DataLoader(self.valid_ds,
                               self.batch_size,
                               shuffle=True,
                               drop_last=True)
def _init_model(self):
    """Instantiate the network, its optimizer and the state file paths.

    The model is either obtained from ``self.from_pretrained()`` (if
    ``self.pretrained`` is true) or created by calling ``self.net``.
    ``self.optimizer`` is replaced by the result of calling it with the
    model parameters and ``self.lr``.

    Sets ``self.model``, ``self.optimizer``, ``self.state_file``,
    ``self.state`` and ``self.loss_state``.
    """
    # instantiate the segmentation network
    print('------------------- Network architecture ---------------------')
    if not self.pretrained:
        n_bands = len(self.dataset.use_bands)
        n_labels = len(self.dataset.labels)
        self.model = self.net(in_channels=n_bands,
                              nclasses=n_labels,
                              filters=self.filters,
                              skip=self.skip_connection,
                              **self.kwargs)
    else:
        self.model = self.from_pretrained()
    print(self.model)
    print('--------------------------------------------------------------')

    # the optimizer used to update the model weights
    model_params = self.model.parameters()
    self.optimizer = self.optimizer(model_params, self.lr)

    # abbreviation of the bands: first character of each band name
    if self.bands:
        band_tag = ''.join(band[0] for band in self.bands)
    else:
        band_tag = 'all'

    # file to save model state to
    # format: networkname_datasetname_t(tilesize)_b(batchsize)_bands.pt
    name_parts = (self.model.__class__.__name__,
                  self.dataset.__class__.__name__,
                  self.tile_size,
                  self.batch_size,
                  band_tag)
    self.state_file = '{}_{}_t{}_b{}_{}.pt'.format(*name_parts)

    # a pretrained model is reflected in the name of the state file
    if self.pretrained:
        stem = self.state_file.replace('.pt', '_')
        self.state_file = stem + 'pretrained_' + self.pretrained_model

    # path to model state
    self.state = os.path.join(self.state_path, self.state_file)

    # path to model loss/accuracy
    self.loss_state = self.state.replace('.pt', '_loss.pt')
def _save_loss(self, training_state, checkpoint=False,
               checkpoint_state=None):
    """Save the losses and accuracies of a training run.

    The state is written with ``torch.save`` to ``self.loss_state``.

    :param training_state: Mapping of metric name to the values recorded
        during the current training run.
    :param checkpoint: Whether the current training run was resumed from
        a checkpoint.
    :param checkpoint_state: Mapping of metric name to the values stored
        in the checkpoint. Ignored unless ``checkpoint`` is true.
    """
    if checkpoint and checkpoint_state is not None:
        # append the values of the current training state to the values
        # from the checkpoint; entries are matched by key rather than by
        # position, so a differing key order does not silently drop
        # metrics. Keys missing in either mapping are not saved.
        training_state = {
            key: np.hstack([values, training_state[key]])
            for key, values in checkpoint_state.items()
            if key in training_state}

    # save losses and accuracy
    torch.save(training_state, self.loss_state)
class EarlyStopping(object):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment