diff --git a/pytorch/trainer.py b/pytorch/trainer.py index 7376ccdee5ef2e0b9cc396a47de4647d369f46cb..3bf269588f7fd5aed77e44387d2be1dfedcf02ce 100755 --- a/pytorch/trainer.py +++ b/pytorch/trainer.py @@ -18,6 +18,7 @@ from torch.utils.data import random_split, DataLoader # local modules from pytorch.dataset import SparcsDataset, Cloud95Dataset from pytorch.layers import Conv2dSame +from pytorch.constants import SupportedDatasets class NetworkTrainer(object): @@ -28,105 +29,26 @@ class NetworkTrainer(object): for k, v in config.items(): setattr(self, k, v) - # check which dataset the model is trained on - if self.dataset_name == 'Sparcs': - # instanciate the SparcsDataset - self.dataset = SparcsDataset(self.dataset_path, - use_bands=self.bands, - tile_size=self.tile_size) - elif self.dataset_name == 'Cloud95': - # instanciate the Cloud95Dataset - self.dataset = Cloud95Dataset(self.dataset_path, - use_bands=self.bands, - tile_size=self.tile_size, - exclude=self.patches) - else: - raise ValueError('{} is not a valid dataset. Available datasets ' - 'are "Sparcs" and "Cloud95".' - .format(self.dataset_name)) - - # print the bands used for the segmentation - print('------------------------ Input bands -------------------------') - print(*['Band {}: {}'.format(i, b) for i, b in - enumerate(self.dataset.use_bands)], sep='\n') - print('--------------------------------------------------------------') - - # print the classes of interest - print('-------------------------- Classes ---------------------------') - print(*['Class {}: {}'.format(k, v['label']) for k, v in - self.dataset.labels.items()], sep='\n') - print('--------------------------------------------------------------') - - # instanciate the segmentation network - print('------------------- Network architecture ---------------------') - if self.pretrained: - self.model = self.from_pretrained() - else: - self.model = self.net(in_channels=len(self.dataset.use_bands), - nclasses=len(self.dataset.labels), - filters=self.filters, - skip=self.skip_connection, - **self.kwargs) - print(self.model) - print('--------------------------------------------------------------') - - # the training and validation dataset - print('------------------------ Dataset split -----------------------') - self.train_ds, self.valid_ds, self.test_ds = self.train_val_test_split( - self.dataset, self.tvratio, self.ttratio, self.seed) - print('--------------------------------------------------------------') - - # number of batches in the validation set - self.nvbatches = int(len(self.valid_ds) / self.batch_size) - - # number of batches in the training set - self.nbatches = int(len(self.train_ds) / self.batch_size) - - # the training and validation dataloaders - self.train_dl = DataLoader(self.train_ds, - self.batch_size, - shuffle=True, - drop_last=True) - self.valid_dl = DataLoader(self.valid_ds, - self.batch_size, - shuffle=True, - drop_last=True) - - # the optimizer used to update the model weights - self.optimizer = self.optimizer(self.model.parameters(), self.lr) - # whether to use the gpu self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - # file to save model state to - # format: networkname_datasetname_t(tilesize)_b(batchsize)_bands.pt - bformat = ''.join([b[0] for b in self.bands]) if self.bands else 'all' - self.state_file = ('{}_{}_t{}_b{}_{}.pt' - .format(self.model.__class__.__name__, - self.dataset.__class__.__name__, - self.tile_size, - self.batch_size, - bformat)) - - # check whether a pretrained model was used and change state filename - # accordingly - if self.pretrained: - # add the configuration of the pretrained model to the state name - self.state_file = (self.state_file.replace('.pt', '_') + - 'pretrained_' + self.pretrained_model) - - # path to model state - self.state = os.path.join(self.state_path, self.state_file) + # initialize the dataset to train the model on + self._init_dataset() - # path to model loss/accuracy - self.loss_state = self.state.replace('.pt', '_loss.pt') + # initialize the model + self._init_model() def from_pretrained(self): # load the pretrained model - model_state = torch.load( - os.path.join(self.state_path, self.pretrained_model)) + model_state = os.path.join(self.state_path, self.pretrained_model) + if not os.path.exists(model_state): + raise FileNotFoundError('Pretrained model {} does not exist.' + .format(model_state)) + + # load the model state + model_state = torch.load(model_state) # get the input bands of the pretrained model bands = model_state['bands'] @@ -152,8 +74,10 @@ class NetworkTrainer(object): out_channels=len(self.dataset.labels), kernel_size=1) - return model + # adjust the number of classes in the model + model.nclasses = len(self.dataset.labels) + return model def ds_len(self, ds, ratio): return int(np.round(len(ds) * ratio)) @@ -193,22 +117,6 @@ class NetworkTrainer(object): def accuracy_function(self, outputs, labels): return (outputs == labels).float().mean() - def _save_loss(self, training_state, checkpoint=False, - checkpoint_state=None): - - # save losses and accuracy - if checkpoint and checkpoint_state is not None: - - # append values from checkpoint to current training - # state - torch.save({ - k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in - zip(checkpoint_state.items(), training_state.items()) - if k1 == k2}, - self.loss_state) - else: - torch.save(training_state, self.loss_state) - def train(self): # set the number of threads @@ -418,6 +326,117 @@ class NetworkTrainer(object): return cm, accuracies, losses + def _init_dataset(self): + + # check whether the dataset is currently supported + self.dataset = None + for dataset in SupportedDatasets: + if self.dataset_name == dataset.name: + self.dataset = dataset.value['class'](self.dataset_path, + self.bands, + self.tile_size, + self.sort, + self.transforms) + if self.dataset is None: + print('{} is not a valid dataset.').format(self.dataset_name) + print('Available datasets are:') + for name, _ in SupportedDatasets.__members__.items(): + print(name) + raise ValueError('Dataset not supported.') + + # print the bands used for the segmentation + print('------------------------ Input bands -------------------------') + print(*['Band {}: {}'.format(i, b) for i, b in + enumerate(self.dataset.use_bands)], sep='\n') + print('--------------------------------------------------------------') + + # print the classes of interest + print('-------------------------- Classes ---------------------------') + print(*['Class {}: {}'.format(k, v['label']) for k, v in + self.dataset.labels.items()], sep='\n') + print('--------------------------------------------------------------') + + # the training and validation dataset + print('------------------------ Dataset split -----------------------') + self.train_ds, self.valid_ds, self.test_ds = self.train_val_test_split( + self.dataset, self.tvratio, self.ttratio, self.seed) + + # number of batches in the validation set + self.nvbatches = int(len(self.valid_ds) / self.batch_size) + + # number of batches in the training set + self.nbatches = int(len(self.train_ds) / self.batch_size) + print('Number of training batches: {}'.format(self.nbatches)) + print('Number of validation batches: {}'.format(self.nvbatches)) + print('--------------------------------------------------------------') + + # the training and validation dataloaders + self.train_dl = DataLoader(self.train_ds, + self.batch_size, + shuffle=True, + drop_last=True) + self.valid_dl = DataLoader(self.valid_ds, + self.batch_size, + shuffle=True, + drop_last=True) + + def _init_model(self): + + # instanciate the segmentation network + print('------------------- Network architecture ---------------------') + if self.pretrained: + self.model = self.from_pretrained() + else: + self.model = self.net(in_channels=len(self.dataset.use_bands), + nclasses=len(self.dataset.labels), + filters=self.filters, + skip=self.skip_connection, + **self.kwargs) + print(self.model) + print('--------------------------------------------------------------') + + # the optimizer used to update the model weights + self.optimizer = self.optimizer(self.model.parameters(), self.lr) + + # file to save model state to + # format: networkname_datasetname_t(tilesize)_b(batchsize)_bands.pt + bformat = ''.join([b[0] for b in self.bands]) if self.bands else 'all' + self.state_file = ('{}_{}_t{}_b{}_{}.pt' + .format(self.model.__class__.__name__, + self.dataset.__class__.__name__, + self.tile_size, + self.batch_size, + bformat)) + + # check whether a pretrained model was used and change state filename + # accordingly + if self.pretrained: + # add the configuration of the pretrained model to the state name + self.state_file = (self.state_file.replace('.pt', '_') + + 'pretrained_' + self.pretrained_model) + + # path to model state + self.state = os.path.join(self.state_path, self.state_file) + + # path to model loss/accuracy + self.loss_state = self.state.replace('.pt', '_loss.pt') + + def _save_loss(self, training_state, checkpoint=False, + checkpoint_state=None): + + # save losses and accuracy + if checkpoint and checkpoint_state is not None: + + # append values from checkpoint to current training + # state + torch.save({ + k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in + zip(checkpoint_state.items(), training_state.items()) + if k1 == k2}, + self.loss_state) + else: + torch.save(training_state, self.loss_state) + class EarlyStopping(object):