From 6e3a7d7cd393177fa68ec5c8c92c3624e79a107e Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Fri, 5 Feb 2021 15:25:58 +0100
Subject: [PATCH] Fixed some bugs: dataset path replacement and confusion
 matrix initialization.

---
 pysegcnn/core/trainer.py | 81 ++++++++++++++++++++++++++++------------
 1 file changed, 58 insertions(+), 23 deletions(-)

diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 45d6ee2..457b524 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -52,7 +52,7 @@ from pysegcnn.core.logging import log_conf
 from pysegcnn.core.graphics import (plot_loss, plot_confusion_matrix,
                                     plot_sample)
 from pysegcnn.core.constants import map_labels
-from pysegcnn.main.train_config import HERE, DRIVE_PATH
+from pysegcnn.main.train_config import HERE
 
 # module level logger
 LOGGER = logging.getLogger(__name__)
@@ -2018,6 +2018,13 @@ class NetworkInference(BaseConfig):
         :py:class:`pysegcnn.core.trainer.SplitConfig` when evaluating on an
         explicitly defined dataset, i.e. ``implicit=False``. The default is
         `{}`.
+    drive_path : `str`
+        Path to the datasets on the current machine. By default, the path to
+        the datasets is assumed to be the same as during training, i.e. model
+        training and evaluation are performed on the same machine. Otherwise,
+        ``drive_path`` should be the path to the datasets, ending with
+        `'Datasets'`, on the machine on which the model is evaluated. The
+        default is `''`.
     map_labels : `bool`
         Whether to map the model labels from the model source domain to the
         defined ``domain`` in case the domain class labels differ. The default
@@ -2036,6 +2043,9 @@
         `['nir', 'red', 'green']`.
     cm : `bool`
        Whether to compute the confusion matrix. The default is `True`.
+    overwrite : `bool`
+        Whether to overwrite existing model evaluations. The default is
+        `False`.
     figsize : `tuple`
         The figure size in centimeters. The default is `(10, 10)`.
     alpha : `int`
@@ -2069,12 +2079,14 @@
     aggregate: bool = False
     ds: dict = dataclasses.field(default_factory={})
     ds_split: dict = dataclasses.field(default_factory={})
+    drive_path: str = ''
     map_labels: bool = False
     predict_scene: bool = True
     plot_scenes: bool = False
     plot_bands: list = dataclasses.field(
         default_factory=lambda: ['red', 'green', 'blue'])
     cm: bool = True
+    overwrite: bool = False
     figsize: tuple = (10, 10)
     alpha: int = 5
 
@@ -2135,8 +2147,6 @@
         This function assumes that the datasets are stored in a directory
         named "Datasets" on each machine.
 
-        See ``DRIVE_PATH`` in :py:mod:`pysegcnn.main.eval_config`.
-
         Parameters
         ----------
         ds : :py:class:`torch.utils.data.Subset`
@@ -2146,20 +2156,13 @@
             Base path to the datasets on the current machine. ``drive_path``
             should end with `'Datasets'`.
 
-        Raises
-        ------
-        TypeError
-            Raised if ``ds`` is not an instance of
-            :py:class:`torch.utils.data.Subset` build from an instance of
-            :py:class:`pysegcnn.core.dataset.ImageDataset`.
-
         """
         # iterate over the scenes of the dataset
         for scene in ds.dataset.scenes:
             for k, v in scene.items():
 
-                # do only look for paths
-                if isinstance(v, str) and k != 'id':
+                # only look for paths to the dataset bands and ground truth
+                if k in ds.dataset.use_bands + ['gt']:
 
                     # drive path: match path before "Datasets"
                     # dpath = re.search('^(.*)(?=(/.*Datasets))', v)
@@ -2229,7 +2232,16 @@
         ds.name = '_'.join([domain, ds_set])
 
         # check the dataset path: replace by path on current machine
-        # self.replace_dataset_path(ds, DRIVE_PATH)
+        if self.drive_path:
+            # check whether the specified path exists on the current machine
+            self.drive_path = pathlib.Path(self.drive_path)
+            if not self.drive_path.exists():
+                raise FileNotFoundError('Dataset path {} does not exist.'
+                                        .format(self.drive_path))
+
+            # replace dataset path of target dataset by path on the current
+            # machine
+            self.replace_dataset_path(ds, self.drive_path)
 
         return ds
 
@@ -2472,7 +2484,7 @@
             values are dictionaries with keys:
                 ``'x'``
                     Model input data of the sample (:py:class:`numpy.ndarray`).
-                ``'y'
+                ``'y_true'``
                     Ground truth class labels (:py:class:`numpy.ndarray`).
                 ``'y_pred'``
                     Model prediction class labels (:py:class:`numpy.ndarray`).
@@ -2527,7 +2539,7 @@
                 prdctn = self.map_to_target(prdctn)
 
             # save current batch to output dictionary
-            output[batch] = {'x': inputs, 'y': labels, 'y_pred': prdctn}
+            output[batch] = {'x': inputs, 'y_true': labels, 'y_pred': prdctn}
 
             # filename for the plot of the current batch
             batch_name = '_'.join([model.state_file.stem,
@@ -2575,7 +2587,7 @@
             values of the nested dictionaries are again dictionaries with keys:
                 ``'x'``
                     Model input data of the sample (:py:class:`numpy.ndarray`).
-                ``'y'
+                ``'y_true'``
                     Ground truth class labels (:py:class:`numpy.ndarray`).
                 ``'y_pred'``
                     Model prediction class labels (:py:class:`numpy.ndarray`).
@@ -2595,12 +2607,20 @@
 
             # check whether model was already evaluated
             if self.eval_file(state).exists():
-
-                # load existing model evaluation
                 LOGGER.info('Found existing model evaluation: {}.'
                             .format(self.eval_file(state)))
-                inference[state.stem] = torch.load(self.eval_file(state))
-                continue
+
+                # load existing model evaluation
+                if not self.overwrite:
+                    LOGGER.info('Using model evaluation: {}.'
+                                .format(self.eval_file(state)))
+                    inference[state.stem] = torch.load(self.eval_file(state))
+                    continue
+                else:
+                    # overwrite existing model evaluation
+                    LOGGER.info('Overwriting model evaluation: {}.'
+                                .format(self.eval_file(state)))
+                    self.eval_file(state).unlink()
 
             # plot loss and accuracy
             plot_loss(check_filename_length(state), outpath=self.perfmc_path)
@@ -2622,11 +2642,14 @@
             # check whether to calculate confusion matrix
             if self.cm:
 
-                # TODO: merge predictions for all scenes in output
+                # merge predictions for all samples
+                y_true = np.asarray([v['y_true'].numpy().flatten() for _, v
+                                     in output.items()]).flatten()
+                y_pred = np.asarray([v['y_pred'].numpy().flatten() for _, v
+                                     in output.items()]).flatten()
 
                 # calculate confusion matrix
-                conf_mat = confusion_matrix(output['y'].numpy().flatten(),
-                                            output['y_pred'].numpy().flatten())
+                conf_mat = confusion_matrix(y_true, y_pred)
 
                 # add confusion matrix to model output
                 output['cm'] = conf_mat
@@ -2646,6 +2669,9 @@
         # check whether to aggregate the results of the different model runs
         if self.aggregate:
 
+            # base name for all models
+            base_name = self.state_files[0].name
+
             # check whether to compute the aggregated confusion matrix
             if self.cm:
                 # initialize the aggregated confusion matrix
@@ -2659,4 +2685,13 @@
                 # save aggregated confusion matrix to dictionary
                 inference['cm_agg'] = cm_agg
 
+                # create file name for aggregated confusion matrix
+                fold_number = re.search('f[0-9]', base_name)[0]
+                cm_name = base_name.replace(fold_number, 'kfold')
+
+                # plot aggregated confusion matrix and save to file
+                plot_confusion_matrix(
+                    cm_agg, self.source_labels, state_file=cm_name,
+                    outpath=self.perfmc_path)
+
         return inference
--
GitLab
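As a brief illustration (not part of the commit itself): the sketch below mirrors the path re-rooting idea behind ``NetworkInference.replace_dataset_path`` and the new ``drive_path`` option, assuming the datasets are stored below a directory named "Datasets" on both the training and the evaluation machine. The helper name and the example paths are hypothetical.

import pathlib


def reroot_dataset_path(path, drive_path):
    """Re-anchor ``path`` below ``drive_path`` on the current machine.

    Hypothetical helper: keep the part of ``path`` after the "Datasets"
    directory and join it to ``drive_path``, which should end with
    'Datasets', analogous to what replace_dataset_path does for each
    scene of the target dataset.
    """
    parts = pathlib.Path(path).parts

    # position of the "Datasets" directory in the original path
    idx = parts.index('Datasets')

    # everything after "Datasets" is appended to the new drive path
    return pathlib.Path(drive_path).joinpath(*parts[idx + 1:])


# hypothetical example: a band stored under /mnt/Datasets on the training
# machine is resolved under D:/Datasets on the evaluation machine
print(reroot_dataset_path('/mnt/Datasets/Sparcs/scene_1/B02.tif',
                          'D:/Datasets'))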