From 7d24bc885c318fac3362b234cc3ede906d9ff27c Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Tue, 26 Jan 2021 14:40:49 +0100 Subject: [PATCH] Improved extraction from raster: implemented clip to shapefile. --- pysegcnn/core/utils.py | 111 ++++++++++++++++++++++++++++++----------- 1 file changed, 82 insertions(+), 29 deletions(-) diff --git a/pysegcnn/core/utils.py b/pysegcnn/core/utils.py index 7651b44..5556b62 100644 --- a/pysegcnn/core/utils.py +++ b/pysegcnn/core/utils.py @@ -2102,8 +2102,8 @@ def reproject_vector(src_ds, trg_ds, ref_ds=None, epsg=None, overwrite=False): del src_ds, trg_ds -def vector2raster(src_ds, trg_ds, attribute, pixel_size, out_type, no_data=0, - overwrite=False): +def vector2raster(src_ds, trg_ds, pixel_size, out_type, attribute=None, + burn_value=255, no_data=0, overwrite=False): """Convert a shapefile to a GeoTIFF. The vector data in the shapefile is converted to a GeoTIFF with a spatial @@ -2117,9 +2117,6 @@ def vector2raster(src_ds, trg_ds, attribute, pixel_size, out_type, no_data=0, The shapefile to convert. trg_ds : `str` or :py:class:`pathlib.Path` The target raster dataset. - attribute : `str` - The shapefile attribute to use for the GeoTIFF values. Note that an - error is raised if ``attribute`` does not exist in ``src_ds``. pixel_size : `tuple` [`int`, `int`] The pixel size of the target dataset, (height, width). The default is `(None, None)`. @@ -2127,6 +2124,14 @@ def vector2raster(src_ds, trg_ds, attribute, pixel_size, out_type, no_data=0, An integer describing the data type of the target raster dataset. See :py:func:`gdal.GetDataTypeName` for an enumeration of the data types corresponding to the different integers. + attribute : `str`, optional + The shapefile attribute to use for the GeoTIFF values. Note that an + error is raised if ``attribute`` does not exist in ``src_ds``. The + default is `None`, which means that a constant burn value of 255 is + used. + burn_value : `int`, optional + A fixed value to burn into each band for all objets in ``src_ds``. Used + if ``attribute=None``. The default is `255`. no_data : `int` or `float` The value to assign to NoData values in ``src_ds``. The default is `0`. overwrite : `bool`, optional @@ -2163,23 +2168,28 @@ def vector2raster(src_ds, trg_ds, attribute, pixel_size, out_type, no_data=0, # the field names of the source vector dataset field_names = [field.name for field in src_lr.schema] - # check if the defined attribute name is in the field names - if attribute not in field_names: - raise ValueError('"{}" is not a valid attribute. {} has the following ' - 'attributes: \n{}'.format( - attribute, src_path.name, '\n'.join(field_names)) - ) - - # get the source spatial extent - x_min, x_max, y_min, y_max = src_lr.GetExtent() - # encode the NoData value to the output data type nodata = getattr(Gdal2Numpy, gdal.GetDataTypeName(out_type)).value(no_data) + # check whether to use only a single attribute + if attribute is not None: + + # check whether the attribute is valid + if attribute not in field_names: + raise ValueError('"{}" is not a valid attribute. {} has the ' + 'following attributes: \n{}'.format( + attribute, src_path.name, + '\n'.join(field_names))) + else: + # do not use a constant burn value when retrieving a specific + # attribute + burn_value = None + # rasterize vector dataset to defined spatial resolution gdal.Rasterize(str(trg_path), str(src_path), - xRes=pixel_size[1], yRes=pixel_size[0], noData=nodata, - outputType=out_type, attribute=attribute) + xRes=pixel_size[1], yRes=pixel_size[0], + noData=nodata, outputType=out_type, attribute=attribute, + outputSRS=src_lr.GetSpatialRef(), burnValues=burn_value) # clear source dataset del src_ds @@ -2251,12 +2261,16 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False): Extract the extent of ``mask_ds`` from ``src_ds``. The masked values of ``src_ds`` are saved in ``trg_ds``. + If ``mask_ds`` is a shapefile, it is expected to be in the same coordinate + reference system as ``src_ds``. + Parameters ---------- src_ds : `str` or :py:class:`pathlib.Path` The input raster to extract values from. mask_ds : `str` or :py:class:`pathlib.Path` - A mask raster defining the area of interest. + A mask defining the area of interest. Either a raster file or + a shapefile. trg_ds : `str` or :py:class:`pathlib.Path` The masked target dataset. overwrite : `bool`, optional @@ -2286,20 +2300,59 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False): LOGGER.info('Overwrite {}'.format(str(trg_path))) trg_path.unlink() - # read the source and mask dataset + # read the source dataset src_ds = gdal.Open(str(src_path)) - mask_ds = gdal.Open(str(mask_path)) - # spatial extent of the mask - gt = mask_ds.GetGeoTransform() - extent = [gt[0], gt[3] + gt[5] * mask_ds.RasterYSize, - gt[0] + gt[1] * mask_ds.RasterXSize, gt[3]] + # the source dataset NoData value + no_data = src_ds.GetRasterBand(1).GetNoDataValue() - # extract values by mask - gdal.Warp(str(trg_path), str(src_path), - outputBounds=extent, - xRes=src_ds.GetGeoTransform()[1], - yRes=src_ds.GetGeoTransform()[5]) + # source dataset spatial reference + src_sr = osr.SpatialReference() + src_sr.ImportFromWkt(src_ds.GetProjection()) + + # checkt the type of the mask dataset + if mask_path.name.endswith('.shp'): + # clip raster values by shapefile + gdal.Warp(str(trg_path), str(src_path), + cutlineDSName=str(mask_path), + cropToCutline=True, + xRes=src_ds.GetGeoTransform()[1], + yRes=src_ds.GetGeoTransform()[5], + dstNodata=no_data) + + else: + # mask is a raster dataset + mask_ds = gdal.Open(str(mask_path)) + + # spatial extent of the mask: (x_min, x_max, y_min, y_max) + gt = mask_ds.GetGeoTransform() + extent = [gt[0], gt[0] + gt[1] * mask_ds.RasterXSize, + gt[3] + gt[5] * mask_ds.RasterYSize, gt[3]] + + # mask dataset spatial reference + mask_sr = osr.SpatialReference() + mask_sr.ImportFromWkt(mask_ds.GetProjection()) + + # coordinate transformation: from mask to source + crs_tr = osr.CoordinateTransformation(mask_sr, src_sr) + + # transform extent of mask to source coordinate system + + # TransfromPoint expects input: + # - gdal >= 3.0: x, y, z = TransformPoint(y, x) + # - gdal < 3.0 : x, y, z = TransformPoint(x, y) + x_tl, y_tl, _ = crs_tr.TransformPoint(extent[-1], extent[0]) + x_br, y_br, _ = crs_tr.TransformPoint(extent[2], extent[1]) + + # extent of the mask in the source reference coordinate system: + # (x_min, y_min, x_max, y_max) + extent = [x_tl, y_br, x_br, y_tl] + + # extract values by mask + gdal.Warp(str(trg_path), str(src_path), + outputBounds=extent, + xRes=src_ds.GetGeoTransform()[1], + yRes=src_ds.GetGeoTransform()[5]) # clear source and mask dataset del src_ds, mask_ds -- GitLab