From 827ce6023957c93d53239a433b0d0974d8e2b476 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 4 Feb 2021 17:46:41 +0100
Subject: [PATCH] Check if torch Tensor is on gpu.

---
 pysegcnn/core/utils.py | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/pysegcnn/core/utils.py b/pysegcnn/core/utils.py
index 867abc3..24d905b 100644
--- a/pysegcnn/core/utils.py
+++ b/pysegcnn/core/utils.py
@@ -755,29 +755,32 @@ def reconstruct_scene(tiles):
         The reconstructed image, shape: `(bands, height, width)`.
 
     """
-    # convert to numpy array
-    tiles = np.asarray(tiles)
+    # check if tensor is on gpu and convert to numpy array
+    if isinstance(tiles, torch.Tensor):
+        tiles_cpu = np.asarray(tiles.cpu())
+    else:
+        tiles_cpu = np.asarray(tiles)
 
     # check the dimensions of the input array
-    if tiles.ndim > 3:
-        nbands = tiles.shape[1]
-        tile_size = tiles.shape[2]
+    if tiles_cpu.ndim > 3:
+        nbands = tiles_cpu.shape[1]
+        tile_size = tiles_cpu.shape[2]
     else:
         nbands = 1
-        tile_size = tiles.shape[1]
+        tile_size = tiles_cpu.shape[1]
 
     # calculate image size
-    img_size = 2 * (int(np.sqrt(tiles.shape[0]) * tile_size),)
+    img_size = 2 * (int(np.sqrt(tiles_cpu.shape[0]) * tile_size),)
 
     # calculate the topleft corners of the tiles
     topleft = tile_topleft_corner(img_size, tile_size)
 
     # iterate over the tiles
     scene = np.zeros(shape=(nbands,) + img_size)
-    for t in range(tiles.shape[0]):
+    for t in range(tiles_cpu.shape[0]):
         scene[...,
               topleft[t][0]: topleft[t][0] + tile_size,
-              topleft[t][1]: topleft[t][1] + tile_size] = tiles[t, ...]
+              topleft[t][1]: topleft[t][1] + tile_size] = tiles_cpu[t, ...]
 
     return scene.squeeze()
 
-- 
GitLab