diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index b2535ff67626c601a3d133ced7c7700825d9f011..96bb55e7ebabae189aaf06ee9215b84fdf079162 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -18,7 +18,9 @@ from torch.utils.data import DataLoader from pysegcnn.core.dataset import SupportedDatasets from pysegcnn.core.layers import Conv2dSame from pysegcnn.core.utils import img2np, accuracy_function -from pysegcnn.core.split import Split +from pysegcnn.core.split import (RandomTileSplit, RandomSceneSplit, DateSplit, + VALID_SPLIT_MODES) + class NetworkTrainer(object): @@ -340,13 +342,27 @@ class NetworkTrainer(object): '\n'.join(name for name, _ in SupportedDatasets.__members__.items())) - # instanciate the Split class handling the dataset split - self.subset = Split(self.dataset, self.split_mode, - tvratio=self.tvratio, - ttratio=self.ttratio, - seed=self.seed, - date=self.date, - dateformat=self.dateformat) + + # the mode to split + if self.split_mode not in VALID_SPLIT_MODES: + raise ValueError('{} is not supported. Valid modes are {}, see ' + 'pysegcnn.main.config.py for a description of ' + 'each mode.'.format(self.split_mode, + VALID_SPLIT_MODES)) + if self.split_mode == 'random': + self.subset = RandomTileSplit(self.dataset, + self.ttratio, + self.tvratio, + self.seed) + if self.split_mode == 'scene': + self.subset = RandomSceneSplit(self.dataset, + self.ttratio, + self.tvratio, + self.seed) + if self.split_mode == 'date': + self.subset = DateSplit(self.dataset, + self.date, + self.dateformat) # the training, validation and dataset self.train_ds, self.valid_ds, self.test_ds = self.subset.split() @@ -442,16 +458,10 @@ class NetworkTrainer(object): # representation string to print fs = self.__class__.__name__ + '(\n' - fs += ' (bands):\n ' - - # bands used for the segmentation - fs += '\n '.join('- Band {}: {}'.format(i, b) for i, b in - enumerate(self.dataset.use_bands)) - # classes of interest - fs += '\n (classes):\n ' - fs += '\n '.join('- Class {}: {}'.format(k, v['label']) for - k, v in self.dataset.labels.items()) + # dataset + fs += ' (dataset):\n ' + fs += ''.join(self.dataset.__repr__()).replace('\n', '\n ') # batch size fs += '\n (batch):\n ' @@ -459,19 +469,18 @@ class NetworkTrainer(object): fs += '- batch shape (b, h, w): {}'.format(self.batch_shape) # dataset split - fs += '\n (dataset):\n ' - fs += '\n '.join( - '- {}: {:d} batches ({:.2f}%)' - .format(k, v[0], v[1] * 100) for k, v in - {'Training': (len(self.train_ds), self.ttratio * self.tvratio), - 'Validation': (len(self.valid_ds), - self.ttratio * (1 - self.tvratio)), - 'Test': (len(self.test_ds), 1 - self.ttratio)}.items()) + fs += '\n (split):\n ' + fs += ''.join(self.subset.__repr__()).replace('\n', '\n ') # model fs += '\n (model):\n ' fs += ''.join(self.model.__repr__()).replace('\n', '\n ') + + # optimizer + fs += '\n (optimizer):\n ' + fs += ''.join(self.optimizer.__repr__()).replace('\n', '\n ') fs += '\n)' + return fs