diff --git a/pysegcnn/core/utils.py b/pysegcnn/core/utils.py index 867abc35d2954dd31741b93865b1b06d343acb98..24d905b09e102c60b3306db4115ca0bcdbc79dba 100644 --- a/pysegcnn/core/utils.py +++ b/pysegcnn/core/utils.py @@ -755,29 +755,32 @@ def reconstruct_scene(tiles): The reconstructed image, shape: `(bands, height, width)`. """ - # convert to numpy array - tiles = np.asarray(tiles) + # check if tensor is on gpu and convert to numpy array + if isinstance(tiles, torch.Tensor): + tiles_cpu = np.asarray(tiles.cpu()) + else: + tiles_cpu = np.asarray(tiles) # check the dimensions of the input array - if tiles.ndim > 3: - nbands = tiles.shape[1] - tile_size = tiles.shape[2] + if tiles_cpu.ndim > 3: + nbands = tiles_cpu.shape[1] + tile_size = tiles_cpu.shape[2] else: nbands = 1 - tile_size = tiles.shape[1] + tile_size = tiles_cpu.shape[1] # calculate image size - img_size = 2 * (int(np.sqrt(tiles.shape[0]) * tile_size),) + img_size = 2 * (int(np.sqrt(tiles_cpu.shape[0]) * tile_size),) # calculate the topleft corners of the tiles topleft = tile_topleft_corner(img_size, tile_size) # iterate over the tiles scene = np.zeros(shape=(nbands,) + img_size) - for t in range(tiles.shape[0]): + for t in range(tiles_cpu.shape[0]): scene[..., topleft[t][0]: topleft[t][0] + tile_size, - topleft[t][1]: topleft[t][1] + tile_size] = tiles[t, ...] + topleft[t][1]: topleft[t][1] + tile_size] = tiles_cpu[t, ...] return scene.squeeze()