From 533736943606b8558e2ec0a32d850fe820ed8eab Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 20 Jan 2021 18:18:59 +0100
Subject: [PATCH] Major refactor of NetworkInference: implemented label mapping.

---
Reviewer notes (text between the '---' line and the diff is not part of the
commit message):
 * target_label_map: fixed the attribute typo
   `_original_target_labels_labels` (AttributeError on first access) and
   pointed its "See" reference at `_original_target_labels`.
 * source_label_map: the docstring now states the direction the array
   encodes (model labels -> original labels), matching its use in
   map_to_target.
 * evaluate: the docstring now refers to ``self.state_files``, the attribute
   the loop iterates over.
 * Line breaks and indentation were rebuilt from a whitespace-collapsed copy
   of this patch. Hunk line counts were re-checked, but the columns of the
   continuation lines (`**self.kwargs)`, `outpath=self.perfmc_path)`) are a
   best guess; use `git apply --ignore-whitespace` if context matching fails.

 pysegcnn/core/trainer.py | 132 ++++++++++++++++++++++++++-------------
 1 file changed, 90 insertions(+), 42 deletions(-)

diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 77b0b25..457671b 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -2173,30 +2173,14 @@ class NetworkInference(BaseConfig):
         Path to store plots of model predictions for entire scenes.
     perfmc_path : :py:class:`pathlib.Path`
         Path to store plots of model performance, e.g. confusion matrix.
-    animtn_path : :py:class:`pathlib.Path`
-        Path to store animations.
     models_path : :py:class:`pathlib.Path`
         Path to search for model state files ``state_files``.
     plot_kwargs : `dict`
         Keyword arguments for :py:func:`pysegcnn.core.graphics.plot_sample`
-    basename : `str`
-        Base filename for each plot.
-    model : :py:class:`pysegcnn.core.models.Network`
-        The model to use for inference.
-    model_state : `dict`
-        A dictionary containing the model and optimizer state, as
-        constructed by :py:meth:`~pysegcnn.core.Network.save`.
     trg_ds : :py:class:`pysegcnn.core.split.CustomSubset`
         The dataset to evaluate ``model`` on.
     src_ds : :py:class:`pysegcnn.core.split.CustomSubset`
         The model source domain training dataset.
-    fig : :py:class:`matplotlib.figure.Figure`
-        A :py:class:`matplotlib.figure.Figure` instance to iteratively plot to.
-    anim : :py:class:`pysegcnn.core.graphics.Animate`
-        An instance :py:class:`pysegcnn.core.graphics.Animate` Used to create
-        animations if ``animate=True``.
-    conf_mat : :py:class:`numpy.ndarray`
-        The model confusion matrix.
 
     """
 
@@ -2214,7 +2198,6 @@ class NetworkInference(BaseConfig):
     cm: bool = True
     figsize: tuple = (10, 10)
     alpha: int = 5
-    animate: bool = False
 
     def __post_init__(self):
         """Check the type of each argument.
@@ -2250,7 +2233,6 @@ class NetworkInference(BaseConfig):
         self.sample_path = self.base_path.joinpath('_samples')
         self.scenes_path = self.base_path.joinpath('_scenes')
         self.perfmc_path = self.base_path.joinpath('_graphics')
-        self.animtn_path = self.base_path.joinpath('_animations')
 
         # input path for model state files
         self.models_path = self.base_path.joinpath('_models')
@@ -2397,6 +2379,40 @@ class NetworkInference(BaseConfig):
         """
         return self.trg_ds.labels
 
+    @property
+    def source_label_map(self):
+        """Mapping of the model source labels to the original source labels.
+
+        See
+        :py:meth:`pysegcnn.core.trainer.NetworkInference._original_source_labels`.
+
+        Returns
+        -------
+        source_labels : :py:class:`numpy.ndarray`
+            The mapping from the class identifiers used during training to
+            the original source class identifiers.
+
+        """
+        return np.array([list(self.source_labels.keys()),
+                         list(self._original_source_labels.keys())]).T
+
+    @property
+    def target_label_map(self):
+        """Mapping of the original target labels to the model target labels.
+
+        See
+        :py:meth:`pysegcnn.core.trainer.NetworkInference._original_target_labels`.
+
+        Returns
+        -------
+        target_labels : :py:class:`numpy.ndarray`
+            The mapping from the original target class identifiers to the
+            identifiers used for evaluation.
+
+        """
+        return np.array([list(self._original_target_labels.keys()),
+                         list(self.target_labels.keys())]).T
+
     @property
     def label_map(self):
         """Label mapping dictionary from the source to the target domain.
@@ -2508,6 +2524,34 @@ class NetworkInference(BaseConfig):
         """
         return self.trg_ds.dataset.tiles if self.predict_scene else 1
 
+    @property
+    def _original_source_labels(self):
+        """Original source domain labels.
+
+        Since PyTorch requires class labels to be an ascending sequence
+        starting from 0, the actual class labels in the ground truth may differ
+        from the class labels fed to the model.
+
+        Returns
+        -------
+        original_source_labels : `dict` [`int`, `dict`]
+            The original class labels of the source domain.
+
+        """
+        return self.src_ds._labels
+
+    @property
+    def _original_target_labels(self):
+        """Original target domain labels.
+
+        Returns
+        -------
+        original_target_labels : `dict` [`int`, `dict`]
+            The original class labels of the target domain.
+
+        """
+        return self.trg_ds.dataset._labels
+
     def map_to_target(self, prd):
         """Map source domain labels to target domain labels.
 
@@ -2523,16 +2567,13 @@ class NetworkInference(BaseConfig):
 
         """
         # map actual source labels to original source labels
-        # prd = array_replace(prd, np.array([self.src_ds.labels.keys(),
-        #                                    self.src_ds._labels.keys()]))
+        prd = array_replace(prd, self.source_label_map)
 
         # apply the label mapping
-        # prd = array_replace(prd, self.label_map.to_numpy())
+        prd = array_replace(prd, self.label_map.to_numpy())
 
         # map original target labels to actual target labels
-        # for oid, aid in zip(self.trg_ds._labels.keys(),
-        #                     self.target_labels.keys()):
-        #     prd[torch.where(prd == oid)] = aid
+        prd = array_replace(prd, self.target_label_map)
 
         return prd
 
@@ -2626,26 +2667,34 @@ class NetworkInference(BaseConfig):
                                 **self.kwargs)
         return output
 
+    def eval_file(self, state_file):
+        return pathlib.Path(str(state_file).replace('.pt', '_eval.pt'))
+
     def evaluate(self):
         """Evaluate the models on a defined dataset.
 
         Returns
         -------
-        output : `dict` [`str`, `dict`]
-            The inference output dictionary. The keys are either the number of
-            the samples (``self.predict_scene=False``) or the name of the
+        inference : `dict` [`str`, `dict`]
+            The inference output dictionary. The keys are the names of the
+            models in ``self.state_files`` and the values are dictionaries
+            where the keys are either the number of the batches
+            (``self.predict_scene=False``) or the name of the
             scenes of the target dataset (``self.predict_scene=True``). The
-            values are dictionaries with keys:
+            values of the nested dictionaries are again dictionaries with keys:
                 ``'x'``
                     Model input data of the sample (:py:class:`numpy.ndarray`).
                 ``'y'
                     Ground truth class labels (:py:class:`numpy.ndarray`).
                 ``'y_pred'``
                     Model prediction class labels (:py:class:`numpy.ndarray`).
+                ``'cm'``
+                    The confusion matrix of the model, which is only present if
+                    ``self.cm=True`` (:py:class:`numpy.ndarray`).
 
         """
         # iterate over the models to evaluate
-        inference = []
+        inference = {}
         for state in self.state_files:
 
             # initialize logging
@@ -2653,6 +2702,15 @@ class NetworkInference(BaseConfig):
             dictConfig(log_conf(log.log_file))
             log.init_log('{}: ' + 'Evaluating model: {}.'.format(state.name))
 
+            # check whether model was already evaluated
+            if self.eval_file(state).exists():
+
+                # load existing model evaluation
+                LOGGER.info('Found existing model evaluation: {}.'
+                            .format(self.eval_file(state)))
+                inference[state.stem] = torch.load(self.eval_file(state))
+                continue
+
             # plot loss and accuracy
             plot_loss(check_filename_length(state), outpath=self.perfmc_path)
 
@@ -2692,19 +2750,9 @@ class NetworkInference(BaseConfig):
                                       outpath=self.perfmc_path)
 
             # save model predictions to file
-            torch.save(output, str(state).replace('.pt', '_eval.pt'))
+            torch.save(output, self.eval_file(state))
 
             # save model predictions to list
-            inference.append(output)
-
-        # check whether to compute an aggregated confusion matrix
-        if self.aggregate and self.cm:
-
-            # initialize aggregated confusion matrix
-            cm = np.zeros(shape=2 * (len(self.src_ds.labels), ))
-
-            # iterate over the different model runs
-            for out in inference:
-                cm += out['cm']
+            inference[state.stem] = output
 
-            # plot aggregated cm
+        return inference
-- 
GitLab