From aefec6d367345b8cea3b71e5bdd67cd07d2476c3 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Mon, 25 Jan 2021 15:22:52 +0100 Subject: [PATCH] Adjusted NetworkInference to accumulate statistics across different model runs. --- pysegcnn/core/trainer.py | 170 ++++++++++++++++++++++----------------- 1 file changed, 94 insertions(+), 76 deletions(-) diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 123be81..97ce7c1 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -51,7 +51,7 @@ from pysegcnn.core.logging import log_conf from pysegcnn.core.graphics import (plot_loss, plot_confusion_matrix, plot_sample) from pysegcnn.core.constants import map_labels -from pysegcnn.main.config import HERE, DRIVE_PATH +from pysegcnn.main.train_config import HERE, DRIVE_PATH # module level logger @@ -270,11 +270,11 @@ class SplitConfig(BaseConfig): Returns ------- - train_ds : :py:class:`pysegcnn.core.split.CustomSubset`. + train_ds : :py:class:`torch.utils.data.Subset`. The training set. - valid_ds : :py:class:`pysegcnn.core.split.CustomSubset`. + valid_ds : :py:class:`torch.utils.data.Subset`. The validation set. - test_ds : :py:class:`pysegcnn.core.split.CustomSubset`. + test_ds : :py:class:`torch.utils.data.Subset`. The test set. """ @@ -478,11 +478,10 @@ class ModelConfig(BaseConfig): """ # write an initialization string to the log file - LogConfig.init_log('{}: Initializing model run. ') + LogConfig.init_log('Initializing model: {} '.format(state_file.name)) # set the random seed for reproducibility torch.manual_seed(self.torch_seed) - LOGGER.info('Initializing model: {}'.format(state_file.name)) # initialize checkpoint state, i.e. no model checkpoint checkpoint_state = {} @@ -756,7 +755,7 @@ class StateConfig(BaseConfig): self.ds_state_ext = 't{}_{}.pt' def init_state(self, src_dc, src_sc, mc, trg_dc=None, trg_sc=None, tc=None, - fold=0): + fold=None): """Generate the model state filename. 
Parameters @@ -767,12 +766,13 @@ class StateConfig(BaseConfig): The source domain dataset split configuration. mc : :py:class:`pysegcnn.core.trainer.ModelConfig` The model configuration. - trg_dc : :py:class:`pysegcnn.core.trainer.DatasetConfig` - The target domain dataset configuration. - trg_sc : :py:class:`pysegcnn.core.trainer.SplitConfig` - The target domain dataset split configuration. - tc : :py:class:`pysegcnn.core.trainer.TransferLearningConfig` - The transfer learning configuration. + trg_dc : :py:class:`pysegcnn.core.trainer.DatasetConfig`, optional + The target domain dataset configuration. The default is `None`. + trg_sc : :py:class:`pysegcnn.core.trainer.SplitConfig`, optional + The target domain dataset split configuration. The default is + `None`. + tc : :py:class:`pysegcnn.core.trainer.TransferLearningConfig`, optional + The transfer learning configuration. The default is `None`. Returns ------- @@ -798,33 +798,34 @@ class StateConfig(BaseConfig): if trg_dc is None or trg_sc is None: raise ValueError('Target domain configurations required.') - # target domain dataset state filename - trg_ds_state, _ = self.format_ds_state( - trg_dc, trg_sc, fold) - - # check whether a pretrained model is used to fine-tune to the - # target domain - if tc.supervised: - # state file for models fine-tuned to target domain - # DatasetConfig_PretrainedModel.pt - - # TODO: Is this correct? 
Trainer is initialized with source - # dataloaders - state = '_'.join([tc.pretrained_model, - 'sda_{}'.format(trg_ds_state)]) - else: - # state file for models trained via unsupervised domain - # adaptation - state = '_'.join([state.replace( - src_ds_ext, 'uda_{}'.format(tc.uda_pos)), - trg_ds_state, src_ds_ext]) - - # check whether unsupervised domain adaptation is initialized - # from a pretrained model state - if tc.uda_from_pretrained: - state = '_'.join(state.replace('.pt', ''), - 'prt_{}'.format( - tc.pretrained_model)) + # check whether to apply transfer learning + if tc.transfer: + + # target domain dataset state filename + trg_ds_state, _ = self.format_ds_state( + trg_dc, trg_sc, fold) + + # check whether a pretrained model is used to fine-tune to the + # target domain + if tc.supervised: + # state file for models fine-tuned to target domain + # PretrainedModel_DatasetConfig.pt + state = '_'.join([ + tc.pretrained_model, 'sda_{}'.format(src_ds_state), + src_ds_ext]) + else: + # state file for models trained via unsupervised domain + # adaptation + state = '_'.join([ + state.replace(src_ds_ext, 'uda_{}'.format(tc.uda_pos)), + trg_ds_state, src_ds_ext]) + + # check whether unsupervised domain adaptation is + # initialized from a pretrained model state + if tc.uda_from_pretrained: + state = '_'.join( + [state.replace('.pt', ''), + 'prt_{}'.format(tc.pretrained_model)]) # path to model state state = mc.state_path.joinpath(state) @@ -954,7 +955,7 @@ class LogConfig(BaseConfig): """ LOGGER.info(80 * '-') - LOGGER.info(init_str.format(LogConfig.now())) + LOGGER.info('{}: '.format(LogConfig.now()) + init_str) LOGGER.info(80 * '-') @@ -984,15 +985,15 @@ class ClassificationNetworkTrainer(BaseConfig): src_train_dl : :py:class:`torch.utils.data.DataLoader` The source domain training :py:class:`torch.utils.data.DataLoader` instance build from an instance of - :py:class:`pysegcnn.core.split.CustomSubset`. + :py:class:`torch.utils.data.Subset`. 
src_valid_dl : :py:class:`torch.utils.data.DataLoader` The source domain validation :py:class:`torch.utils.data.DataLoader` instance build from an instance of - :py:class:`pysegcnn.core.split.CustomSubset`. + :py:class:`torch.utils.data.Subset`. src_test_dl : :py:class:`torch.utils.data.DataLoader` The source domain test :py:class:`torch.utils.data.DataLoader` instance build from an instance of - :py:class:`pysegcnn.core.split.CustomSubset`. + :py:class:`torch.utils.data.Subset`. epochs : `int` The maximum number of epochs to train. The default is `1`. nthreads : `int` @@ -1419,17 +1420,17 @@ class DomainAdaptationTrainer(ClassificationNetworkTrainer): trg_train_dl : `None` or :py:class:`torch.utils.data.DataLoader` The target domain training :py:class:`torch.utils.data.DataLoader` instance build from an instance of - :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty + :py:class:`torch.utils.data.Subset`. The default is an empty :py:class:`torch.utils.data.DataLoader`. trg_valid_dl : `None` or :py:class:`torch.utils.data.DataLoader` The target domain validation :py:class:`torch.utils.data.DataLoader` instance build from an instance of - :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty + :py:class:`torch.utils.data.Subset`. The default is an empty :py:class:`torch.utils.data.DataLoader`. trg_test_dl : :py:class:`torch.utils.data.DataLoader` The target domain test :py:class:`torch.utils.data.DataLoader` instance build from an instance of - :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty + :py:class:`torch.utils.data.Subset`. The default is an empty :py:class:`torch.utils.data.DataLoader`. uda_loss_function : :py:class:`torch.nn.Module` The domain adaptation loss function. An instance of @@ -2003,6 +2004,10 @@ class NetworkInference(BaseConfig): Whether to evaluate the model on the training (``test=None``), the validation (``test=False``) or the test set (``test=True``). The default is `False`. 
+ aggregate : `bool` + Whether to aggregate the statistics of the different models in + ``state_files``. Useful to aggregate the results of multiple model + runs in cross validation. The default is `False`. ds : `dict` The dataset configuration dictionary passed to :py:class:`pysegcnn.core.trainer.DatasetConfig` when evaluating on @@ -2036,10 +2041,6 @@ class NetworkInference(BaseConfig): alpha : `int` The level of the percentiles for contrast stretching of the false color compsite. The default is `0`, i.e. no stretching. - animate : `bool` - Whether to create an animation of (input, ground truth, prediction) for - the scenes in the train/validation/test dataset. Only works if - ``predict_scene=True`` and ``plot_scene=True``. device : `str` The device to evaluate the model on, i.e. `cpu` or `cuda`. base_path : :py:class:`pathlib.Path` @@ -2054,9 +2055,9 @@ Path to search for model state files ``state_files``. plot_kwargs : `dict` Keyword arguments for :py:func:`pysegcnn.core.graphics.plot_sample` - trg_ds : :py:class:`pysegcnn.core.split.CustomSubset` + trg_ds : :py:class:`torch.utils.data.Subset` The dataset to evaluate ``model`` on. - src_ds : :py:class:`pysegcnn.core.split.CustomSubset` + src_ds : :py:class:`torch.utils.data.Subset` The model source domain training dataset. """ @@ -2065,6 +2066,7 @@ implicit: bool = True domain: str = 'src' test: object = False + aggregate: bool = False ds: dict = dataclasses.field(default_factory={}) ds_split: dict = dataclasses.field(default_factory={}) map_labels: bool = False @@ -2133,11 +2135,11 @@ This function assumes that the datasets are stored in a directory named "Datasets" on each machine. - See ``DRIVE_PATH`` in :py:mod:`pysegcnn.main.config`. + See ``DRIVE_PATH`` in :py:mod:`pysegcnn.main.eval_config`. 
Parameters ---------- - ds : :py:class:`pysegcnn.core.split.CustomSubset` + ds : :py:class:`torch.utils.data.Subset` A subset of an instance of :py:class:`pysegcnn.core.dataset.ImageDataset`. drive_path : `str` @@ -2148,8 +2150,7 @@ class NetworkInference(BaseConfig): ------ TypeError Raised if ``ds`` is not an instance of - :py:class:`pysegcnn.core.split.CustomSubset` and if ``ds`` is not - a subset of an instance of + :py:class:`torch.utils.data.Subset` build from an instance of :py:class:`pysegcnn.core.dataset.ImageDataset`. """ @@ -2180,23 +2181,23 @@ class NetworkInference(BaseConfig): Returns ------- - ds : :py:class:`pysegcnn.core.split.CustomSubset` + ds : :py:class:`torch.utils.data.Subset` The dataset to evaluate the model on. """ # load model state model_state = Network.load(state) + # check whether to evaluate the model on the training, validation + # or test set + if test is None: + ds_set = 'train' + else: + ds_set = 'test' if test else 'valid' + # check whether to evaluate on the datasets defined at training time if implicit: - # check whether to evaluate the model on the training, validation - # or test set - if test is None: - ds_set = 'train' - else: - ds_set = 'test' if test else 'valid' - # the dataset to evaluate the model on ds = model_state[domain + '_{}_dl'.format(ds_set)].dataset if ds is None: @@ -2214,18 +2215,18 @@ class NetworkInference(BaseConfig): # split configuration sc = SplitConfig(**self.ds_split) - train_ds, valid_ds, test_ds = sc.train_val_test_split(ds) + folds = sc.train_val_test_split(ds)[0] # check whether to evaluate the model on the training, validation # or test set - if test is None: - ds = train_ds - else: - ds = test_ds if test else valid_ds + ds = folds[ds_set] # log dataset representation LOGGER.info('Evaluating on {} set of explicitly defined dataset: ' - '\n {}'.format(ds.name, repr(ds.dataset))) + '\n {}'.format(ds_set, repr(ds.dataset))) + + # name the current dataset + ds.name = '_'.join([domain, ds_set]) # 
check the dataset path: replace by path on current machine self.replace_dataset_path(ds, DRIVE_PATH) @@ -2464,11 +2465,11 @@ class NetworkInference(BaseConfig): the samples (``self.predict_scene=False``) or the name of the scenes of the target dataset (``self.predict_scene=True``). The values are dictionaries with keys: - ``'input'`` + ``'x'`` Model input data of the sample (:py:class:`numpy.ndarray`). - ``'labels' + ``'y' Ground truth class labels (:py:class:`numpy.ndarray`). - ``'prediction'`` + ``'y_pred'`` Model prediction class labels (:py:class:`numpy.ndarray`). """ @@ -2498,7 +2499,8 @@ if self.dataloader.batch_size > 1: # id of the current scene - batch = self.trg_ds.ids[batch] + current_scene = np.int(batch * self.dataloader.batch_size) + batch = self.trg_ds.dataset.scenes['id'][current_scene] # modify the progress string progress = progress.replace('Sample', 'Scene') @@ -2542,6 +2544,7 @@ state=batch_name, plot_path=self.scenes_path, **self.kwargs) + return output def eval_file(self, state_file): @@ -2577,7 +2580,7 @@ # initialize logging log = LogConfig(state) dictConfig(log_conf(log.log_file)) - log.init_log('{}: ' + 'Evaluating model: {}.'.format(state.name)) + log.init_log('Evaluating model: {}.'.format(state.name)) # check whether model was already evaluated if self.eval_file(state).exists(): @@ -2632,4 +2635,19 @@ # save model predictions to list inference[state.stem] = output + # check whether to aggregate the results of the different model runs + if self.aggregate: + + # check whether to compute the aggregated confusion matrix + if self.cm: + # initialize the aggregated confusion matrix + cm_agg = np.zeros(shape=2 * (len(self.src_ds.labels), )) + + # update aggregated confusion matrix + for _, output in inference.items(): + cm_agg += output['cm'] + + # save aggregated confusion matrix to dictionary + 
inference['cm_agg'] = cm_agg + return inference -- GitLab