diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index e663a336f417f73fa9649dae8141503159899859..2640ca953ad6796bd58ff0e5096007ae6afab2ae 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -2504,8 +2504,6 @@ class NetworkInference(BaseConfig): the samples (``self.predict_scene=False``) or the name of the scenes of the target dataset (``self.predict_scene=True``). The values are dictionaries with keys: - ``'x'`` - Model input data of the sample (:py:class:`numpy.ndarray`). ``'y_true' Ground truth class labels (:py:class:`numpy.ndarray`). ``'y_pred'`` @@ -2586,19 +2584,25 @@ class NetworkInference(BaseConfig): # save current scene to output dictionary output[batch] = {k: v for k, v in zip( - INFERENCE_NAMES, [inputs, labels, prdctn])} + INFERENCE_NAMES[1:], [labels, prdctn])} # re-initialize scene dictionary scenes = {k: [] for k in INFERENCE_NAMES} # plot current scene if self.plot: + + # title for prediction + title = ''.join([(v[0] + str(k)) for k, v in + self.src_ds.sensor.band_dict().items() + if v in self.bands]) + # plot inputs, ground truth and model predictions fig = plot_sample(inputs.clip(0, 1), self.bands, self.use_labels, y=labels, - y_pred={'Prediction': prdctn}, + y_pred={title: prdctn}, accuracy=True, **self.plot_kwargs) @@ -2618,8 +2622,8 @@ class NetworkInference(BaseConfig): else: # save current batch to output dictionary - output[batch] = {k: v for k, v in zip(INFERENCE_NAMES, - [inputs, labels, prdctn]) + output[batch] = {k: v for k, v in zip(INFERENCE_NAMES[1:], + [labels, prdctn]) } # calculate the accuracy of the prediction on the current batch