diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index b64aa746ecf77563e2f7f994b8a0dd08c69c7909..7ebcd1e5b778dff6e2392c56c1f3a5c2aa20c9cb 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -58,6 +58,9 @@ from pysegcnn.main.train_config import HERE
 # module level logger
 LOGGER = logging.getLogger(__name__)
 
+# global variable: variable names of the model inference output
+INFERENCE_NAMES = ['x', 'y', 'y_pred']
+
 
 @dataclasses.dataclass
 class BaseConfig:
@@ -2397,8 +2400,7 @@ class NetworkInference(BaseConfig):
         Returns
         -------
         plot : `bool`
-            Save plots for each sample or for each scene of the target dataset,
-            depending on ``self.predict_scene``.
+            Save plots for each scene of the target dataset.
 
         """
         return self.plot_scenes if self.predict_scene else False
@@ -2424,13 +2426,10 @@ class NetworkInference(BaseConfig):
         Returns
         -------
         batch_size : `int`
-            The batch size of the dataloader used for model inference. Depends
-            on whether to predict each sample of the target dataset
-            individually or whether to reconstruct each scene in the target
-            dataset.
+            The batch size of the dataloader used for model inference.
 
         """
-        return self.trg_ds.dataset.tiles if self.predict_scene else 1
+        return 1
 
     @property
     def _original_source_labels(self):
@@ -2514,8 +2513,11 @@ class NetworkInference(BaseConfig):
         model.to(self.device)
         LOGGER.info('Device: {}'.format(self.device))
 
-        # iterate over the samples of the target dataset
+        # initialize dictionaries to store outputs in
         output = {}
+        scenes = {k: [] for k in INFERENCE_NAMES}
+
+        # iterate over the samples of the target dataset
         for batch, (inputs, labels) in enumerate(self.dataloader):
 
             # send inputs and labels to device
@@ -2527,68 +2529,94 @@ class NetworkInference(BaseConfig):
             prdctn = F.softmax(
                 model(inputs), dim=1).argmax(dim=1).squeeze()
 
+            # check whether the source and target domain labels differ
+            if self.apply_label_map:
+                prdctn = self.map_to_target(prdctn)
+
             # progress string to log
-            progress = 'Sample: {:d}/{:d}'.format(batch + 1,
-                                                  len(self.dataloader))
+            progress = 'Sample: {:d}/{:d}'.format(
+                batch + 1, len(self.dataloader))
 
             # check if tensor is on gpu and convert to numpy array
             inputs = inputs.cpu().numpy()
             labels = labels.cpu().numpy()
             prdctn = prdctn.cpu().numpy()
 
-            # check whether to reconstruct the scene
-            if self.dataloader.batch_size > 1:
-
-                # tiles of the current scene
-                current_tiles = self.trg_ds.indices[
-                    np.arange(batch * self.dataloader.batch_size,
-                              (batch + 1) * self.dataloader.batch_size)]
+            # check whether to reconstruct the scenes of a dataset
+            if self.predict_scene:
+
+                # append model predictions of current batch to scene dictionary
+                for k, v in zip(INFERENCE_NAMES, [inputs, labels, prdctn]):
+                    scenes[k].append(v)
+
+                # check if an entire scene is processed
+                if batch % self.trg_ds.dataset.tiles == 0 and batch != 0:
+
+                    # convert scene dictionary to numpy arrays
+                    inputs, labels, prdctn = [np.asarray(v) for _, v in
+                                              scenes.items()]
+
+                    # tiles of the current scene
+                    current_tiles = self.trg_ds.indices[
+                        np.arange(batch - self.trg_ds.dataset.tiles, batch)]
+
+                    # name of the current scene
+                    batch = np.unique([self.trg_ds.dataset.scenes[sid]['id']
+                                       for sid in current_tiles]).item()
+
+                    # modify the progress string
+                    progress = progress.replace('Sample', 'Scene')
+                    progress += ' Id: {}'.format(batch)
+
+                    # reconstruct the entire scene
+                    inputs = reconstruct_scene(inputs)
+                    labels = reconstruct_scene(labels)
+                    prdctn = reconstruct_scene(prdctn)
+
+                    # calculate the accuracy of the prediction on the current
+                    # scene
+                    progress += ', Accuracy: {:.2f}'.format(
+                        accuracy_function(prdctn, labels))
+                    LOGGER.info(progress)
+
+                    # save current scene to output dictionary
+                    output[batch] = {k: v for k, v in zip(
+                        INFERENCE_NAMES, [inputs, labels, prdctn])}
+
+                    # re-initialize scene dictionary
+                    scenes = {k: [] for k in INFERENCE_NAMES}
+
+                    # plot current scene
+                    if self.plot:
+                        # plot inputs, ground truth and model predictions
+                        fig = plot_sample(inputs.clip(0, 1),
+                                          self.bands,
+                                          self.use_labels,
+                                          y=labels,
+                                          y_pred={'Prediction': prdctn},
+                                          accuracy=True,
+                                          **self.plot_kwargs)
+
+                        # filename for the plot of the current batch
+                        batch_name = '_'.join(
+                            [model.state_file.stem,
+                             '{}_{}.pdf'.format(self.trg_ds.name, batch)])
+
+                        # save figure
+                        fig.savefig(check_filename_length(
+                            self.scenes_path.joinpath(batch_name)),
+                            bbox_inches='tight')
 
-                # name of the current scene
-                batch = np.unique([self.trg_ds.dataset.scenes[sid]['id'] for
-                                   sid in current_tiles]).item()
-
-                # modify the progress string
-                progress = progress.replace('Sample', 'Scene')
-                progress += ' Id: {}'.format(batch)
-
-                # reconstruct the entire scene
-                inputs = reconstruct_scene(inputs)
-                labels = reconstruct_scene(labels)
-                prdctn = reconstruct_scene(prdctn)
-
-                # check whether the source and target domain labels differ
-                if self.apply_label_map:
-                    prdctn = self.map_to_target(prdctn)
-
-                # save current batch to output dictionary
-                output[batch] = {'x': inputs, 'y_true': labels, 'y_pred': prdctn}
-
-                # filename for the plot of the current batch
-                batch_name = '_'.join([model.state_file.stem,
-                                       '{}_{}.pdf'.format(self.trg_ds.name, batch)])
-
-                # calculate the accuracy of the prediction
-                progress += ', Accuracy: {:.2f}'.format(
-                    accuracy_function(prdctn, labels))
-                LOGGER.info(progress)
-
-                # plot current scene
-                if self.plot:
-
-                    # plot inputs, ground truth and model predictions
-                    fig = plot_sample(inputs.clip(0, 1),
-                                      self.bands,
-                                      self.use_labels,
-                                      y=labels,
-                                      y_pred={'Prediction': prdctn},
-                                      accuracy=True,
-                                      **self.plot_kwargs)
-
-                    # save figure
-                    fig.savefig(check_filename_length(
-                        self.scenes_path.joinpath(batch_name)),
-                        bbox_inches='tight')
+            else:
+                # save current batch to output dictionary
+                output[batch] = {k: v for k, v in zip(INFERENCE_NAMES,
+                                                      [inputs, labels, prdctn])
+                                 }
+
+                # calculate the accuracy of the prediction on the current batch
+                progress += ', Accuracy: {:.2f}'.format(
+                    accuracy_function(prdctn, labels))
+                LOGGER.info(progress)
 
         return output
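
Note on the accumulate-and-flush pattern the last hunk introduces: with the dataloader batch size now fixed to 1, tiles are collected in the `scenes` dictionary until a full scene is available, then stitched together with `reconstruct_scene` and stored under the scene id. Below is a minimal, self-contained sketch of that pattern, not the pysegcnn implementation: `TILES`, `TILE_SIZE`, the random tile stream, the thresholding stand-in for the model forward pass, and the `reconstruct_scene` stub are all hypothetical. The sketch flushes on `(batch + 1) % TILES == 0`, i.e. as soon as exactly `TILES` tiles have accumulated; the hunk's guard, `batch % self.trg_ds.dataset.tiles == 0 and batch != 0`, fires one sample later and never fires for the last scene, which is worth double-checking against the intended tile ordering.

```python
import numpy as np

INFERENCE_NAMES = ['x', 'y', 'y_pred']
TILES = 4        # hypothetical: tiles per scene (a 2x2 mosaic)
TILE_SIZE = 8    # hypothetical: height/width of a square tile


def reconstruct_scene(tiles):
    """Stub: stitch a flat sequence of square tiles into one mosaic."""
    tiles = np.asarray(tiles)
    rows = int(np.sqrt(len(tiles)))
    grid = tiles.reshape(rows, rows, *tiles.shape[1:])
    return np.block([list(row) for row in grid])


# two hypothetical scenes streamed one (inputs, labels) tile at a time,
# mimicking a dataloader with batch_size=1
rng = np.random.default_rng(0)
stream = [(rng.random((TILE_SIZE, TILE_SIZE)),
           rng.integers(0, 2, (TILE_SIZE, TILE_SIZE)))
          for _ in range(2 * TILES)]

output = {}
scenes = {k: [] for k in INFERENCE_NAMES}

for batch, (inputs, labels) in enumerate(stream):
    # stand-in for the model forward pass + argmax
    prdctn = (inputs > 0.5).astype(int)

    # append the current tile to the scene dictionary
    for k, v in zip(INFERENCE_NAMES, [inputs, labels, prdctn]):
        scenes[k].append(v)

    # flush once exactly TILES tiles have accumulated
    if (batch + 1) % TILES == 0:
        scene_id = batch // TILES  # hypothetical scene name
        output[scene_id] = {k: reconstruct_scene(v)
                            for k, v in scenes.items()}
        scenes = {k: [] for k in INFERENCE_NAMES}  # re-initialize

print({sid: {k: v.shape for k, v in s.items()}
       for sid, s in output.items()})
# {0: {'x': (16, 16), 'y': (16, 16), 'y_pred': (16, 16)}, 1: {...}}
```

As in the diff, the result is a mapping from scene id to arrays keyed by `INFERENCE_NAMES` ('x', 'y', 'y_pred'), so downstream evaluation code can consume per-sample and per-scene outputs through the same structure.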