From b8ae911eaea24609a62dd78b3da7e28d88176d4c Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Mon, 11 Jan 2021 17:27:34 +0100
Subject: [PATCH] Added a function to extract values from a raster by another
 raster mask.

---
 pysegcnn/core/utils.py | 60 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 60 insertions(+)

diff --git a/pysegcnn/core/utils.py b/pysegcnn/core/utils.py
index b49161d..79ba0e6 100644
--- a/pysegcnn/core/utils.py
+++ b/pysegcnn/core/utils.py
@@ -2211,3 +2211,63 @@ def dec2bin(number, nbits=8):
             binary += '0'
 
     return binary
+
+
+def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False):
+    """Extract raster values by mask.
+
+    Extract the extent of ``mask_ds`` from ``src_ds``. The masked values of
+    ``src_ds`` are saved in ``trg_ds``.
+
+    Parameters
+    ----------
+    src_ds : `str` or :py:class:`pathlib.Path`
+        The input raster to extract values from.
+    mask_ds : `str` or :py:class:`pathlib.Path`
+        A mask raster defining the area of interest.
+    trg_ds : `str` or :py:class:`pathlib.Path`
+        The masked target dataset.
+    overwrite : `bool`, optional
+        Whether to overwrite ``trg_ds``, if it exists. The default is `False`.
+
+    """
+    # convert path to source dataset and mask dataset to pathlib.Path object
+    src_path = pathlib.Path(src_ds)
+    mask_path = pathlib.Path(mask_ds)
+
+    # check whether the source dataset exists
+    if not src_path.exists():
+        LOGGER.warning('{} does not exist.'.format(str(src_path)))
+
+    # check whether the mask exists
+    if not mask_path.exists():
+        LOGGER.warning('{} does not exist.'.format(str(mask_path)))
+
+    # check whether the output datasets exists
+    trg_path = pathlib.Path(trg_ds)
+    if not trg_path.exists():
+        LOGGER.info('mkdir {}'.format(str(trg_path.parent)))
+        trg_path.parent.mkdir(parents=True, exist_ok=True)
+    else:
+        # check whether to overwrite existing files
+        if overwrite:
+            LOGGER.info('Overwrite {}'.format(str(trg_path)))
+            trg_path.unlink()
+
+    # read the source and mask dataset
+    src_ds = gdal.Open(str(src_path))
+    mask_ds = gdal.Open(str(mask_path))
+
+    # spatial extent of the mask
+    gt = mask_ds.GetGeoTransform()
+    extent = [gt[0], gt[3] + gt[5] * mask_ds.RasterYSize,
+              gt[0] + gt[1] * mask_ds.RasterXSize, gt[3]]
+
+    # extract values by mask
+    gdal.Warp(str(trg_path), str(src_path),
+              outputBounds=extent,
+              xRes=src_ds.GetGeoTransform()[1],
+              yRes=src_ds.GetGeoTransform()[5])
+
+    # clear source and mask dataset
+    del src_ds, mask_ds
-- 
GitLab