From b8ae911eaea24609a62dd78b3da7e28d88176d4c Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Mon, 11 Jan 2021 17:27:34 +0100 Subject: [PATCH] Added a function to extract values from a raster by another raster mask. --- pysegcnn/core/utils.py | 60 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/pysegcnn/core/utils.py b/pysegcnn/core/utils.py index b49161d..79ba0e6 100644 --- a/pysegcnn/core/utils.py +++ b/pysegcnn/core/utils.py @@ -2211,3 +2211,63 @@ def dec2bin(number, nbits=8): binary += '0' return binary + + +def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False): + """Extract raster values by mask. + + Extract the extent of ``mask_ds`` from ``src_ds``. The masked values of + ``src_ds`` are saved in ``trg_ds``. + + Parameters + ---------- + src_ds : `str` or :py:class:`pathlib.Path` + The input raster to extract values from. + mask_ds : `str` or :py:class:`pathlib.Path` + A mask raster defining the area of interest. + trg_ds : `str` or :py:class:`pathlib.Path` + The masked target dataset. + overwrite : `bool`, optional + Whether to overwrite ``trg_ds``, if it exists. The default is `False`. + + """ + # convert path to source dataset and mask dataset to pathlib.Path object + src_path = pathlib.Path(src_ds) + mask_path = pathlib.Path(mask_ds) + + # check whether the source dataset exists + if not src_path.exists(): + LOGGER.warning('{} does not exist.'.format(str(src_path))) + + # check whether the mask exists + if not mask_path.exists(): + LOGGER.warning('{} does not exist.'.format(str(mask_path))) + + # check whether the output datasets exists + trg_path = pathlib.Path(trg_ds) + if not trg_path.exists(): + LOGGER.info('mkdir {}'.format(str(trg_path.parent))) + trg_path.parent.mkdir(parents=True, exist_ok=True) + else: + # check whether to overwrite existing files + if overwrite: + LOGGER.info('Overwrite {}'.format(str(trg_path))) + trg_path.unlink() + + # read the source and mask dataset + src_ds = gdal.Open(str(src_path)) + mask_ds = gdal.Open(str(mask_path)) + + # spatial extent of the mask + gt = mask_ds.GetGeoTransform() + extent = [gt[0], gt[3] + gt[5] * mask_ds.RasterYSize, + gt[0] + gt[1] * mask_ds.RasterXSize, gt[3]] + + # extract values by mask + gdal.Warp(str(trg_path), str(src_path), + outputBounds=extent, + xRes=src_ds.GetGeoTransform()[1], + yRes=src_ds.GetGeoTransform()[5]) + + # clear source and mask dataset + del src_ds, mask_ds -- GitLab