From 04830e41ef0feb1c539d101fc14a52c80e9bedea Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Fri, 22 Jan 2021 17:42:19 +0100 Subject: [PATCH] Working on a more stable generation of model state files. --- pysegcnn/core/trainer.py | 347 +++++++++++++-------------------------- 1 file changed, 112 insertions(+), 235 deletions(-) diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 457671b..123be81 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -37,9 +37,9 @@ from torch.utils.data import DataLoader from torch.optim import Optimizer # locals -from pysegcnn.core.dataset import SupportedDatasets, ImageDataset +from pysegcnn.core.dataset import SupportedDatasets from pysegcnn.core.transforms import Augment -from pysegcnn.core.utils import (img2np, item_in_enum, accuracy_function, +from pysegcnn.core.utils import (item_in_enum, accuracy_function, reconstruct_scene, check_filename_length, array_replace) from pysegcnn.core.split import SupportedSplits @@ -110,9 +110,6 @@ class DatasetConfig(BaseConfig): A regural expression to match the ground truth naming convention. All directories and subdirectories in ``root_dir`` are searched for files matching ``gt_pattern``. - seed : `int` - The random seed. Used to split the dataset into training, - validation and test set. Useful for reproducibility. sort : `bool` Whether to chronologically sort the samples. Useful for time series data. The default is `False`. 
@@ -144,7 +141,6 @@ class DatasetConfig(BaseConfig): bands: list tile_size: object gt_pattern: str - seed: int sort: bool = False transforms: list = dataclasses.field(default_factory=[]) pad: bool = False @@ -196,7 +192,6 @@ class DatasetConfig(BaseConfig): root_dir=str(self.root_dir), use_bands=self.bands, tile_size=self.tile_size, - seed=self.seed, sort=self.sort, transforms=self.transforms, pad=self.pad, @@ -217,36 +212,32 @@ class SplitConfig(BaseConfig): ---------- split_mode : `str` The mode to split the dataset. - ttratio : `float` - The ratio of training and validation data to test data, e.g. - ``ttratio= 0.6`` means 60% for training and validation, 40% for - testing. - tvratio : `float` + k_folds: `int`, optional + The number of folds. + seed : `int`, optional + The random seed for reproducibility. The default is `0`. + shuffle : `bool`, optional + Whether to shuffle the data before splitting into batches. The + default is `True`. + tvratio : `float`, optional The ratio of training data to validation data, e.g. ``tvratio=0.8`` - means 80% training, 20% validation. - date : `str` - A date. Used if ``split_mode='date'``. The default is `yyyymmdd`. - dateformat : `str` - The format of ``date``. ``dateformat`` is used by - :py:func:`datetime.datetime.strptime' to parse ``date`` to a - :py:class:`datetime.datetime` object. The default is `'%Y%m%d'`. - drop : `float` - Whether to drop samples (during training only) with a fraction of - pixels equal to the constant padding value >= ``drop``. ``drop=0`` - means, do not drop any samples. The default is `0`. - split_class : :py:class:`pysegcnn.core.split.Split` - A subclass of :py:class:`pysegcnn.core.split.Split`. - dropped : `list` [`dict`] - List of the dropped samples. + means 80% training, 20% validation. The default is `0.8`. Used if + ``k_folds=1``. + ttratio : `float`, optional + The ratio of training and validation data to test data, e.g. 
+ ``ttratio=0.6`` means 60% for training and validation, 40% for + testing. The default is `1`. Used if ``k_folds=1``. + split_class : :py:class:`pysegcnn.core.split.RandomSplit` + A subclass of :py:class:`pysegcnn.core.split.RandomSplit`. """ split_mode: str - ttratio: float - tvratio: float - date: str = 'yyyymmdd' - dateformat: str = '%Y%m%d' - drop: float = 0 + k_folds: int = 1 + seed: int = 0 + shuffle: bool = True + tvratio: float = 0.8 + ttratio: float = 1 def __post_init__(self): """Check the type of each argument. @@ -263,52 +254,6 @@ class SplitConfig(BaseConfig): # check if the split mode is valid self.split_class = item_in_enum(self.split_mode, SupportedSplits) - # list of dropped samples - self.dropped = [] - - @staticmethod - def drop_samples(ds, drop_threshold=1): - """Drop samples with a fraction of pixels equal to the padding value. - - Parameters - ---------- - ds : :py:class:`pysegcnn.core.split.CustomSubset` - An instance of :py:class:`pysegcnn.core.split.CustomSubset`. - drop_threshold : `float`, optional - The threshold above which samples are dropped. ``drop_threshold=1`` - means a sample is dropped, if all pixels are equal to the padding - value. ``drop_threshold=0.8`` means, drop a sample if 80% of the - pixels are equal to the padding value, etc. The default is `1`. - - Returns - ------- - dropped : `list` [`dict`] - List of the dropped samples. 
- - """ - # iterate over the scenes returned by self.compose_scenes() - dropped = [] - for pos, i in enumerate(ds.indices): - - # the current scene - s = ds.dataset.scenes[i] - - # the current tile in the ground truth - tile_gt = img2np(s['gt'], ds.dataset.tile_size, s['tile'], - ds.dataset.pad, ds.dataset.cval) - - # percent of pixels equal to the constant padding value - npixels = (tile_gt[tile_gt == ds.dataset.cval].size / tile_gt.size) - - # drop samples where npixels >= self.drop - if npixels >= drop_threshold: - LOGGER.info('Skipping scene {}, tile {}: {:.2f}% padded pixels' - ' ...'.format(s['id'], s['tile'], npixels * 100)) - dropped.append(s) - _ = ds.indices.pop(pos) - - return dropped - def train_val_test_split(self, ds): """Split ``ds`` into training, validation and test set. @@ -333,29 +278,14 @@ class SplitConfig(BaseConfig): The test set. """ - if not isinstance(ds, ImageDataset): - raise TypeError('Expected "ds" to be {}.' - .format('.'.join([ImageDataset.__module__, - ImageDataset.__name__]))) - - if self.split_mode == 'random' or self.split_mode == 'scene': - subset = self.split_class(ds, - self.ttratio, - self.tvratio, - ds.seed) - - else: - subset = self.split_class(ds, self.date, self.dateformat) + # instanciate the split class + split = self.split_class(ds, self.k_folds, self.seed, self.shuffle, + self.tvratio, self.ttratio) # the training, validation and test dataset - train_ds, valid_ds, test_ds = subset.split() - - # whether to drop training samples with a fraction of pixels equal to - # the constant padding value cval >= drop - if ds.pad and self.drop > 0: - self.dropped = self.drop_samples(train_ds, self.drop) + subsets = split.split() - return train_ds, valid_ds, test_ds + return subsets @staticmethod def dataloaders(*args, **kwargs): @@ -787,6 +717,7 @@ class StateConfig(BaseConfig): Generate the model state filename according to the following naming conventions: + - For source domain without domain adaptation: 
Model_Optim_SourceDataset_ModelParams.pt @@ -801,30 +732,8 @@ class StateConfig(BaseConfig): Model_Optim_SourceDataset_TargetDataset_ModelParams_prt_ NameOfPretrainedModel.pt - Attributes - ---------- - src_dc : :py:class:`pysegcnn.core.trainer.DatasetConfig` - The source domain dataset configuration. - src_sc : :py:class:`pysegcnn.core.trainer.SplitConfig` - The source domain dataset split configuration. - trg_dc : :py:class:`pysegcnn.core.trainer.DatasetConfig` - The target domain dataset configuration. - trg_sc : :py:class:`pysegcnn.core.trainer.SplitConfig` - The target domain dataset split configuration. - mc : :py:class:`pysegcnn.core.trainer.ModelConfig` - The model configuration. - tc : :py:class:`pysegcnn.core.trainer.TransferLearningConfig` - The transfer learning configuration. - """ - src_dc: DatasetConfig - src_sc: SplitConfig - trg_dc: DatasetConfig - trg_sc: SplitConfig - mc: ModelConfig - tc: TransferLearningConfig - def __post_init__(self): """Check the type of each argument. @@ -840,85 +749,112 @@ class StateConfig(BaseConfig): # base dataset state filename: Dataset_SplitMode_SplitParams self.ds_state_file = '{}_{}_{}' - # base model state filename: Model_Optim - self.ml_state_file = '{}_{}' + # base model state filename: Model_Optim_BatchSize + self.ml_state_file = '{}_{}_b{}' - # base model state filename extentsion: TileSize_BatchSize_Bands - self.ml_state_ext = 't{}_b{}_{}.pt' - - # check that the spectral bands are the same for both source and target - # domains - if self.src_dc.bands != self.trg_dc.bands: - raise ValueError('Spectral bands of the source and target domain ' - 'have to be equal: \n source: {} \n target: {}' - .format(', '.join(self.src_dc.bands), - ', '.join(self.trg_dc.bands)) - ) + # base dataset state filename extension: TileSize_Bands + self.ds_state_ext = 't{}_{}.pt' - def init_state(self): + def init_state(self, src_dc, src_sc, mc, trg_dc=None, trg_sc=None, tc=None, + fold=0): """Generate the model state filename. 
+ Parameters + ---------- + src_dc : :py:class:`pysegcnn.core.trainer.DatasetConfig` + The source domain dataset configuration. + src_sc : :py:class:`pysegcnn.core.trainer.SplitConfig` + The source domain dataset split configuration. + mc : :py:class:`pysegcnn.core.trainer.ModelConfig` + The model configuration. + trg_dc : :py:class:`pysegcnn.core.trainer.DatasetConfig` + The target domain dataset configuration. + trg_sc : :py:class:`pysegcnn.core.trainer.SplitConfig` + The target domain dataset split configuration. + tc : :py:class:`pysegcnn.core.trainer.TransferLearningConfig` + The transfer learning configuration. + Returns ------- state : :py:class:`pathlib.Path` The path to the model state file. """ - # source domain dataset state filename - src_ds_state = self._format_ds_state( - self.ds_state_file, self.src_dc, self.src_sc) + # source domain dataset state filename and extension + src_ds_state, src_ds_ext = self.format_ds_state( + src_dc, src_sc, fold) - # model state file name and extension: common to both source and target - # domain - ml_state, ml_ext = self._format_ml_state(self.src_dc, self.mc) + # model state file name: common to both source and target domain + ml_state = self.format_model_state(mc) # state file for models trained only on the source domain - state = '_'.join([ml_state, src_ds_state, ml_ext]) + state = '_'.join([ml_state, src_ds_state, src_ds_ext]) # check whether the model is trained on the source domain only - if self.tc.transfer: + if tc is not None: + + # check whether the target domain configurations are correctly + # specified + if trg_dc is None or trg_sc is None: + raise ValueError('Target domain configurations required.') # target domain dataset state filename - trg_ds_state = self._format_ds_state( - self.ds_state_file, self.trg_dc, self.trg_sc) + trg_ds_state, _ = self.format_ds_state( + trg_dc, trg_sc, fold) # check whether a pretrained model is used to fine-tune to the # target domain - if self.tc.supervised: + if 
tc.supervised: # state file for models fine-tuned to target domain # DatasetConfig_PretrainedModel.pt - state = '_'.join([self.tc.pretrained_model, + + # TODO: Is this correct? Trainer is initialized with source + # dataloaders + state = '_'.join([tc.pretrained_model, 'sda_{}'.format(trg_ds_state)]) else: # state file for models trained via unsupervised domain # adaptation state = '_'.join([state.replace( - ml_ext, 'uda_{}'.format(self.tc.uda_pos)), - trg_ds_state, ml_ext]) + src_ds_ext, 'uda_{}'.format(tc.uda_pos)), + trg_ds_state, src_ds_ext]) # check whether unsupervised domain adaptation is initialized # from a pretrained model state - if self.tc.uda_from_pretrained: + if tc.uda_from_pretrained: state = '_'.join(state.replace('.pt', ''), 'prt_{}'.format( - self.tc.pretrained_model)) + tc.pretrained_model)) # path to model state - state = self.tc.state_path.joinpath(state) + state = mc.state_path.joinpath(state) return state - def _format_ds_state(self, state_file, dc, sc): + def format_model_state(self, mc): + """Format base model state filename. + + Parameters + ---------- + mc : :py:class:`pysegcnn.core.trainer.ModelConfig` + The model configuration. + + """ + return self.ml_state_file.format(mc.model_name, mc.optim_name, + mc.batch_size) + + def format_ds_state(self, dc, sc, fold=None): """Format base dataset state filename. Parameters ---------- - state_file : `str` - The base dataset state filename. dc : :py:class:`pysegcnn.core.trainer.DatasetConfig` The dataset configuration. sc : :py:class:`pysegcnn.core.trainer.SplitConfig` The dataset split configuration. + fold : `int` or `None`, optional + The number of the current fold. The default is `None`, which means + the fold is not reported in the model state filename. Returns ------- @@ -926,52 +862,39 @@ class StateConfig(BaseConfig): The formatted dataset state filename. 
""" + # store the random seed for reproducibility + split_params = 's{}'.format(sc.seed) - # check which split mode was used - if sc.split_mode == 'date': - # store the date of the split - split_params = sc.date + # check whether the model is trained via cross validation + if sc.k_folds > 1 and fold is not None: + split_params += 'f{}'.format(fold) else: - # store the random split parameters - split_params = 's{}t{}v{}'.format( - dc.seed, str(sc.ttratio).replace('.', ''), - str(sc.tvratio).replace('.', '')) - - # model state filename - file = state_file.format(dc.dataset_class.__name__ + - '_m{}'.format(len(dc.merge_labels)), - sc.split_mode.capitalize(), - split_params, - ) - - return file + # construct dataset split parameters + split_params += 't{}v{}'.format(str(sc.ttratio).replace('.', ''), + str(sc.tvratio).replace('.', '')) - def _format_ml_state(self, dc, mc): - """Format base model state filename. - - Parameters - ---------- - dc : :py:class:`pysegcnn.core.trainer.DatasetConfig` - The source domain dataset configuration. - mc : :py:class:`pysegcnn.core.trainer.ModelConfig` - The model configuration. - - Returns - ------- - extention : `str` - The formatted model state filename. 
- - """ # get the band numbers if dc.bands: + # the spectral bands used to train the model bands = dc.dataset_class.get_sensor().band_dict() bformat = ''.join([(v[0] + str(k)) for k, v in bands.items() if v in dc.bands]) else: + # all available spectral bands are used to train the model bformat = 'all' - return (self.ml_state_file.format(mc.model_name, mc.optim_name), - self.ml_state_ext.format(dc.tile_size, mc.batch_size, bformat)) + # dataset state filename + file = self.ds_state_file.format(dc.dataset_class.__name__ + + '_m{}'.format(len(dc.merge_labels)), + sc.split_mode.capitalize(), + split_params, + ) + + # dataset state filename extension: common to both source and target + # domain + ext = self.ds_state_ext.format(dc.tile_size, bformat) + + return file, ext @dataclasses.dataclass @@ -1743,38 +1666,6 @@ class DomainAdaptationTrainer(ClassificationNetworkTrainer): 'uda_pos': self.uda_pos, 'uda_lambda': self.uda_lambda} - def _build_ds_repr(self, train_dl, valid_dl, test_dl): - """Build the dataset representation. - - Returns - ------- - fs : `str` - Representation string. - - """ - # dataset tile size - tile_size = (2 * (train_dl.dataset.dataset.tile_size,) if - train_dl.dataset.dataset.tile_size is not None else - train_dl.dataset.dataset.get_size()) - - # dataset configuration - fs = ' (dataset):\n ' - fs += ''.join(repr(train_dl.dataset.dataset)).replace('\n', - '\n' + 8 * ' ') - fs += '\n (batch):\n ' - fs += '- batch size: {}\n '.format(train_dl.batch_size) - fs += '- mini-batch shape (b, c, h, w): {}'.format( - ((min(train_dl.batch_size, len(train_dl.dataset)), - len(train_dl.dataset.dataset.use_bands),) + tile_size)) - - # dataset split - fs += '\n (split):' - for dl in [train_dl, valid_dl, test_dl]: - if dl.dataset is not None: - fs += '\n' + 8 * ' ' + repr(dl.dataset) - - return fs - def __repr__(self): """Representation. 
@@ -1787,20 +1678,6 @@ class DomainAdaptationTrainer(ClassificationNetworkTrainer): # representation string to print fs = self.__class__.__name__ + '(\n' - # source domain - fs += ' (source domain)\n ' - fs += self._build_ds_repr( - self.src_train_dl, self.src_valid_dl, self.src_test_dl).replace( - '\n', '\n' + 4 * ' ') - - # target domain - if not self.supervised: - fs += '\n (target domain)\n ' - fs += self._build_ds_repr( - self.trg_train_dl, - self.trg_valid_dl, - self.trg_test_dl).replace('\n', '\n' + 4 * ' ') - # model configuration fs += self._build_model_repr_() -- GitLab