diff --git a/pysegcnn/core/utils.py b/pysegcnn/core/utils.py index d0c39d6f8f895323578804d65e90a4d64f133325..3529dc4dbcac4e90b7bc200387a0c241149ab31b 100644 --- a/pysegcnn/core/utils.py +++ b/pysegcnn/core/utils.py @@ -2255,7 +2255,8 @@ def dec2bin(number, nbits=8): return binary -def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False): +def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False, + src_no_data=None, trg_no_data=0): """Extract raster values by mask. Extract the extent of ``mask_ds`` from ``src_ds``. The masked values of @@ -2275,6 +2276,12 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False): The masked target dataset. overwrite : `bool`, optional Whether to overwrite ``trg_ds``, if it exists. The default is `False`. + src_no_data : `int` or `float`, optional + The value of NoData values in ``src_ds``. The default is `None`, which + means the value is read from ``src_ds``. If specified, values equal to + ``src_no_data`` are masked as ``trg_no_data`` in ``trg_ds``. + trg_no_data : `int` or `float`, optional + The value to assign to NoData values in ``trg_ds``. The default is `0`. """ # convert path to source dataset and mask dataset to pathlib.Path object @@ -2303,9 +2310,6 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False): # read the source dataset src_ds = gdal.Open(str(src_path)) - # the source dataset NoData value - no_data = src_ds.GetRasterBand(1).GetNoDataValue() - # source dataset spatial reference src_sr = osr.SpatialReference() src_sr.ImportFromWkt(src_ds.GetProjection()) @@ -2318,7 +2322,8 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False): cropToCutline=True, xRes=src_ds.GetGeoTransform()[1], yRes=src_ds.GetGeoTransform()[5], - dstNodata=no_data) + srcNodata=src_no_data, + dstNodata=trg_no_data) else: # mask is a raster dataset @@ -2341,8 +2346,8 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False): # TransfromPoint expects input: # - gdal >= 3.0: x, y, z = TransformPoint(y, x) # - gdal < 3.0 : x, y, z = TransformPoint(x, y) - x_tl, y_tl, _ = crs_tr.TransformPoint(extent[-1], extent[0]) - x_br, y_br, _ = crs_tr.TransformPoint(extent[2], extent[1]) + x_tl, y_tl, _ = crs_tr.TransformPoint(extent[0], extent[-1]) + x_br, y_br, _ = crs_tr.TransformPoint(extent[1], extent[2]) # extent of the mask in the source reference coordinate system: # (x_min, y_min, x_max, y_max) @@ -2352,7 +2357,9 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False): gdal.Warp(str(trg_path), str(src_path), outputBounds=extent, xRes=src_ds.GetGeoTransform()[1], - yRes=src_ds.GetGeoTransform()[5]) + yRes=src_ds.GetGeoTransform()[5], + srcNodata=src_no_data, + dstNodata=trg_no_data) # clear source and mask dataset del src_ds, mask_ds