diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index cc805cc42aae2bc66807c44a923363450036b8e0..45d6ee2fbf056c3ade12f44fa61e9cdbfe34c883 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -54,7 +54,6 @@ from pysegcnn.core.graphics import (plot_loss, plot_confusion_matrix, from pysegcnn.core.constants import map_labels from pysegcnn.main.train_config import HERE, DRIVE_PATH - # module level logger LOGGER = logging.getLogger(__name__) @@ -2505,9 +2504,14 @@ class NetworkInference(BaseConfig): # check whether to reconstruct the scene if self.dataloader.batch_size > 1: - # id of the current scene - current_scene = np.int(batch * self.dataloader.batch_size) - batch = self.trg_ds.dataset.scenes[current_scene]['id'] + # tiles of the current scene + current_tiles = self.trg_ds.indices[ + np.arange(batch * self.dataloader.batch_size, + (batch + 1) * self.dataloader.batch_size)] + + # name of the current scene + batch = np.unique([sid for sid in self.trg_ds.dataset.scenes[ + current_tiles]['id']]).item() # modify the progress string progress = progress.replace('Sample', 'Scene')