From 7d24bc885c318fac3362b234cc3ede906d9ff27c Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Tue, 26 Jan 2021 14:40:49 +0100
Subject: [PATCH] Improved extraction from raster: implemented clip to
 shapefile.

---
 pysegcnn/core/utils.py | 111 ++++++++++++++++++++++++++++++-----------
 1 file changed, 82 insertions(+), 29 deletions(-)

diff --git a/pysegcnn/core/utils.py b/pysegcnn/core/utils.py
index 7651b44..5556b62 100644
--- a/pysegcnn/core/utils.py
+++ b/pysegcnn/core/utils.py
@@ -2102,8 +2102,8 @@ def reproject_vector(src_ds, trg_ds, ref_ds=None, epsg=None, overwrite=False):
     del src_ds, trg_ds
 
 
-def vector2raster(src_ds, trg_ds, attribute, pixel_size, out_type, no_data=0,
-                  overwrite=False):
+def vector2raster(src_ds, trg_ds, pixel_size, out_type, attribute=None,
+                  burn_value=255, no_data=0, overwrite=False):
     """Convert a shapefile to a GeoTIFF.
 
     The vector data in the shapefile is converted to a GeoTIFF with a spatial
@@ -2117,9 +2117,6 @@ def vector2raster(src_ds, trg_ds, attribute, pixel_size, out_type, no_data=0,
         The shapefile to convert.
     trg_ds : `str` or :py:class:`pathlib.Path`
         The target raster dataset.
-    attribute : `str`
-        The shapefile attribute to use for the GeoTIFF values. Note that an
-        error is raised if ``attribute`` does not exist in ``src_ds``.
     pixel_size : `tuple` [`int`, `int`]
         The pixel size of the target dataset, (height, width). The default is
         `(None, None)`.
@@ -2127,6 +2124,14 @@ def vector2raster(src_ds, trg_ds, attribute, pixel_size, out_type, no_data=0,
         An integer describing the data type of the target raster dataset. See
         :py:func:`gdal.GetDataTypeName` for an enumeration of the data types
         corresponding to the different integers.
+    attribute : `str`, optional
+        The shapefile attribute to use for the GeoTIFF values. Note that an
+        error is raised if ``attribute`` does not exist in ``src_ds``. The
+        default is `None`, which means that a constant burn value of 255 is
+        used.
+    burn_value : `int`, optional
+        A fixed value to burn into each band for all objets in ``src_ds``. Used
+        if ``attribute=None``. The default is `255`.
     no_data : `int` or `float`
         The value to assign to NoData values in ``src_ds``. The default is `0`.
     overwrite : `bool`, optional
@@ -2163,23 +2168,28 @@ def vector2raster(src_ds, trg_ds, attribute, pixel_size, out_type, no_data=0,
     # the field names of the source vector dataset
     field_names = [field.name for field in src_lr.schema]
 
-    # check if the defined attribute name is in the field names
-    if attribute not in field_names:
-        raise ValueError('"{}" is not a valid attribute. {} has the following '
-                         'attributes: \n{}'.format(
-                             attribute, src_path.name, '\n'.join(field_names))
-                         )
-
-    # get the source spatial extent
-    x_min, x_max, y_min, y_max = src_lr.GetExtent()
-
     # encode the NoData value to the output data type
     nodata = getattr(Gdal2Numpy, gdal.GetDataTypeName(out_type)).value(no_data)
 
+    # check whether to use only a single attribute
+    if attribute is not None:
+
+        # check whether the attribute is valid
+        if attribute not in field_names:
+            raise ValueError('"{}" is not a valid attribute. {} has the '
+                             'following attributes: \n{}'.format(
+                                 attribute, src_path.name,
+                                 '\n'.join(field_names)))
+        else:
+            # do not use a constant burn value when retrieving a specific
+            # attribute
+            burn_value = None
+
     # rasterize vector dataset to defined spatial resolution
     gdal.Rasterize(str(trg_path), str(src_path),
-                   xRes=pixel_size[1], yRes=pixel_size[0], noData=nodata,
-                   outputType=out_type, attribute=attribute)
+                   xRes=pixel_size[1], yRes=pixel_size[0],
+                   noData=nodata, outputType=out_type, attribute=attribute,
+                   outputSRS=src_lr.GetSpatialRef(), burnValues=burn_value)
 
     # clear source dataset
     del src_ds
@@ -2251,12 +2261,16 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False):
     Extract the extent of ``mask_ds`` from ``src_ds``. The masked values of
     ``src_ds`` are saved in ``trg_ds``.
 
+    If ``mask_ds`` is a shapefile, it is expected to be in the same coordinate
+    reference system as ``src_ds``.
+
     Parameters
     ----------
     src_ds : `str` or :py:class:`pathlib.Path`
         The input raster to extract values from.
     mask_ds : `str` or :py:class:`pathlib.Path`
-        A mask raster defining the area of interest.
+        A mask defining the area of interest. Either a raster file or
+        a shapefile.
     trg_ds : `str` or :py:class:`pathlib.Path`
         The masked target dataset.
     overwrite : `bool`, optional
@@ -2286,20 +2300,59 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False):
             LOGGER.info('Overwrite {}'.format(str(trg_path)))
             trg_path.unlink()
 
-    # read the source and mask dataset
+    # read the source dataset
     src_ds = gdal.Open(str(src_path))
-    mask_ds = gdal.Open(str(mask_path))
 
-    # spatial extent of the mask
-    gt = mask_ds.GetGeoTransform()
-    extent = [gt[0], gt[3] + gt[5] * mask_ds.RasterYSize,
-              gt[0] + gt[1] * mask_ds.RasterXSize, gt[3]]
+    # the source dataset NoData value
+    no_data = src_ds.GetRasterBand(1).GetNoDataValue()
 
-    # extract values by mask
-    gdal.Warp(str(trg_path), str(src_path),
-              outputBounds=extent,
-              xRes=src_ds.GetGeoTransform()[1],
-              yRes=src_ds.GetGeoTransform()[5])
+    # source dataset spatial reference
+    src_sr = osr.SpatialReference()
+    src_sr.ImportFromWkt(src_ds.GetProjection())
+
+    # checkt the type of the mask dataset
+    if mask_path.name.endswith('.shp'):
+        # clip raster values by shapefile
+        gdal.Warp(str(trg_path), str(src_path),
+                  cutlineDSName=str(mask_path),
+                  cropToCutline=True,
+                  xRes=src_ds.GetGeoTransform()[1],
+                  yRes=src_ds.GetGeoTransform()[5],
+                  dstNodata=no_data)
+
+    else:
+        # mask is a raster dataset
+        mask_ds = gdal.Open(str(mask_path))
+
+        # spatial extent of the mask: (x_min, x_max, y_min, y_max)
+        gt = mask_ds.GetGeoTransform()
+        extent = [gt[0], gt[0] + gt[1] * mask_ds.RasterXSize,
+                  gt[3] + gt[5] * mask_ds.RasterYSize, gt[3]]
+
+        # mask dataset spatial reference
+        mask_sr = osr.SpatialReference()
+        mask_sr.ImportFromWkt(mask_ds.GetProjection())
+
+        # coordinate transformation: from mask to source
+        crs_tr = osr.CoordinateTransformation(mask_sr, src_sr)
+
+        # transform extent of mask to source coordinate system
+
+        # TransfromPoint expects input:
+        #   - gdal >= 3.0: x, y, z = TransformPoint(y, x)
+        #   - gdal < 3.0 : x, y, z = TransformPoint(x, y)
+        x_tl, y_tl, _ = crs_tr.TransformPoint(extent[-1], extent[0])
+        x_br, y_br, _ = crs_tr.TransformPoint(extent[2], extent[1])
+
+        # extent of the mask in the source reference coordinate system:
+        # (x_min, y_min, x_max, y_max)
+        extent = [x_tl, y_br, x_br, y_tl]
+
+        # extract values by mask
+        gdal.Warp(str(trg_path), str(src_path),
+                  outputBounds=extent,
+                  xRes=src_ds.GetGeoTransform()[1],
+                  yRes=src_ds.GetGeoTransform()[5])
 
     # clear source and mask dataset
     del src_ds, mask_ds
-- 
GitLab