From 827ce6023957c93d53239a433b0d0974d8e2b476 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Thu, 4 Feb 2021 17:46:41 +0100 Subject: [PATCH] Check if torch Tensor is on gpu. --- pysegcnn/core/utils.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/pysegcnn/core/utils.py b/pysegcnn/core/utils.py index 867abc3..24d905b 100644 --- a/pysegcnn/core/utils.py +++ b/pysegcnn/core/utils.py @@ -755,29 +755,32 @@ def reconstruct_scene(tiles): The reconstructed image, shape: `(bands, height, width)`. """ - # convert to numpy array - tiles = np.asarray(tiles) + # check if tensor is on gpu and convert to numpy array + if isinstance(tiles, torch.Tensor): + tiles_cpu = np.asarray(tiles.cpu()) + else: + tiles_cpu = np.asarray(tiles) # check the dimensions of the input array - if tiles.ndim > 3: - nbands = tiles.shape[1] - tile_size = tiles.shape[2] + if tiles_cpu.ndim > 3: + nbands = tiles_cpu.shape[1] + tile_size = tiles_cpu.shape[2] else: nbands = 1 - tile_size = tiles.shape[1] + tile_size = tiles_cpu.shape[1] # calculate image size - img_size = 2 * (int(np.sqrt(tiles.shape[0]) * tile_size),) + img_size = 2 * (int(np.sqrt(tiles_cpu.shape[0]) * tile_size),) # calculate the topleft corners of the tiles topleft = tile_topleft_corner(img_size, tile_size) # iterate over the tiles scene = np.zeros(shape=(nbands,) + img_size) - for t in range(tiles.shape[0]): + for t in range(tiles_cpu.shape[0]): scene[..., topleft[t][0]: topleft[t][0] + tile_size, - topleft[t][1]: topleft[t][1] + tile_size] = tiles[t, ...] + topleft[t][1]: topleft[t][1] + tile_size] = tiles_cpu[t, ...] return scene.squeeze() -- GitLab