Skip to content
Snippets Groups Projects
Commit 4a56331e authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Improved extract by mask utility.

parent 13d2f989
No related branches found
No related tags found
No related merge requests found
......@@ -2255,7 +2255,8 @@ def dec2bin(number, nbits=8):
return binary
def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False):
def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False,
src_no_data=None, trg_no_data=0):
"""Extract raster values by mask.
Extract the extent of ``mask_ds`` from ``src_ds``. The masked values of
......@@ -2275,6 +2276,12 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False):
The masked target dataset.
overwrite : `bool`, optional
Whether to overwrite ``trg_ds``, if it exists. The default is `False`.
src_no_data : `int` or `float`, optional
The value of NoData values in ``src_ds``. The default is `None`, which
means the value is read from ``src_ds``. If specified, values equal to
``src_no_data`` are masked as ``trg_no_data`` in ``trg_ds``.
trg_no_data : `int` or `float`, optional
The value to assign to NoData values in ``trg_ds``. The default is `0`.
"""
# convert path to source dataset and mask dataset to pathlib.Path object
......@@ -2303,9 +2310,6 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False):
# read the source dataset
src_ds = gdal.Open(str(src_path))
# the source dataset NoData value
no_data = src_ds.GetRasterBand(1).GetNoDataValue()
# source dataset spatial reference
src_sr = osr.SpatialReference()
src_sr.ImportFromWkt(src_ds.GetProjection())
......@@ -2318,7 +2322,8 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False):
cropToCutline=True,
xRes=src_ds.GetGeoTransform()[1],
yRes=src_ds.GetGeoTransform()[5],
dstNodata=no_data)
srcNodata=src_no_data,
dstNodata=trg_no_data)
else:
# mask is a raster dataset
......@@ -2341,8 +2346,8 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False):
# TransfromPoint expects input:
# - gdal >= 3.0: x, y, z = TransformPoint(y, x)
# - gdal < 3.0 : x, y, z = TransformPoint(x, y)
x_tl, y_tl, _ = crs_tr.TransformPoint(extent[-1], extent[0])
x_br, y_br, _ = crs_tr.TransformPoint(extent[2], extent[1])
x_tl, y_tl, _ = crs_tr.TransformPoint(extent[0], extent[-1])
x_br, y_br, _ = crs_tr.TransformPoint(extent[1], extent[2])
# extent of the mask in the source reference coordinate system:
# (x_min, y_min, x_max, y_max)
......@@ -2352,7 +2357,9 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False):
gdal.Warp(str(trg_path), str(src_path),
outputBounds=extent,
xRes=src_ds.GetGeoTransform()[1],
yRes=src_ds.GetGeoTransform()[5])
yRes=src_ds.GetGeoTransform()[5],
srcNodata=src_no_data,
dstNodata=trg_no_data)
# clear source and mask dataset
del src_ds, mask_ds
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment