diff --git a/pysegcnn/core/utils.py b/pysegcnn/core/utils.py index 1db0a9750199ae82abb0d392e047b4f0b92756b2..867abc35d2954dd31741b93865b1b06d343acb98 100644 --- a/pysegcnn/core/utils.py +++ b/pysegcnn/core/utils.py @@ -265,7 +265,7 @@ def img2np(path, tile_size=None, tile=None, pad=False, cval=0): def np2tif(array, filename, no_data=None, names=None, src_ds=None, epsg=None, - geotransform=None): + geotransform=None, overwrite=False): """Save a :py:class`numpy.ndarray` as a GeoTIFF. The spatial coordinate reference system can be specified in two ways: @@ -299,6 +299,8 @@ def np2tif(array, filename, no_data=None, names=None, src_ds=None, epsg=None, A tuple with six elements of the form, (x_top_left, x_res, x_shift, y_top_left, -y_res, y_shift), describing the spatial reference. + overwrite : `bool`, optional + Whether to overwrite ``filename`` if it exists. The default is `False`. .. _EPSG: https://epsg.io/ @@ -333,12 +335,20 @@ def np2tif(array, filename, no_data=None, names=None, src_ds=None, epsg=None, raise ValueError('{} is not a file.'.format(filename)) filename = pathlib.Path(str(filename).replace(filename.suffix, '.tif')) + # check if file exists + if filename.exists() and not overwrite: + LOGGER.info('{} already exists.'.format(filename)) + return + + # create temporary file + tmp_path = _tmp_path(filename) + # create output GeoTIFF - trg_ds = driver.Create(str(filename), width, height, bands, dtype) + tmp_ds = driver.Create(str(tmp_path), width, height, bands, dtype) # iterate over the number of bands and write to output file for b in range(bands): - trg_band = trg_ds.GetRasterBand(b + 1) + trg_band = tmp_ds.GetRasterBand(b + 1) trg_band.WriteArray(array[b, ...]) # set the band description, if specified @@ -353,8 +363,8 @@ def np2tif(array, filename, no_data=None, names=None, src_ds=None, epsg=None, # set spatial reference if src_ds is not None: # inherit spatial reference from source dataset - trg_ds.SetProjection(src_ds.GetProjection()) - trg_ds.SetGeoTransform(src_ds.GetGeoTransform()) + tmp_ds.SetProjection(src_ds.GetProjection()) + tmp_ds.SetGeoTransform(src_ds.GetGeoTransform()) else: # check whether both the epsg code and the geotransform tuple are # specified @@ -363,16 +373,17 @@ def np2tif(array, filename, no_data=None, names=None, src_ds=None, epsg=None, sr = osr.SpatialReference().ImportFromEPSG(epsg).ExporttoWkt() # set spatial reference from epsg - trg_ds.SetProjection(sr) - trg_ds.SetGeoTransform(geotransform) + tmp_ds.SetProjection(sr) + tmp_ds.SetGeoTransform(geotransform) else: raise ValueError('Both "epsg" and "geotransform" required to set ' 'spatial reference if "src_ds" is None.') # clear dataset - del trg_band, trg_ds + del trg_band, tmp_ds - return + # compress raster + compress_raster(tmp_path, filename) def read_hdf(path, **kwargs): @@ -431,6 +442,8 @@ def hdf2tifs(path, outpath=None, overwrite=False, create_stack=True, **kwargs): The default (``outpath=None``) is to save the GeoTIFFs in a directory named after the filename of ``path``, within the parent directory of ``path``. + The output GeoTIFFs are compressed by default. + Parameters ---------- path : `str` or py:class:`pathlib.Path` @@ -504,7 +517,8 @@ def hdf2tifs(path, outpath=None, overwrite=False, create_stack=True, **kwargs): # convert hdf subdataset to GeoTIFF LOGGER.info('Converting: {}'.format(tif_name.name)) - gdal.Translate(str(tif_name), gdal.Open(ds[0]), **kwargs) + gdal.Translate(str(tif_name), gdal.Open(ds[0]), creationOptions=[ + 'COMPRESS=DEFLATE', 'PREDICTOR=1', 'TILED=YES'], **kwargs) # check whether to create a GeoTIFF stack if create_stack: @@ -526,6 +540,8 @@ def hdf2tifs(path, outpath=None, overwrite=False, create_stack=True, **kwargs): def stack_tifs(filename, tifs, **kwargs): """Create a stacked GeoTIFF from a list of single-band GeoTIFFs. + The output GeoTIFF stack is compressed by default. + Parameters ---------- filename : `str` or py:class:`pathlib.Path` @@ -552,7 +568,8 @@ def stack_tifs(filename, tifs, **kwargs): # create GeoTIFF stack gdal.PushErrorHandler('CPLQuietErrorHandler') - gdal.Translate(str(filename), vrt_ds, **kwargs) + gdal.Translate(str(filename), vrt_ds, creationOptions=[ + 'COMPRESS=DEFLATE', 'PREDICTOR=1', 'TILED=YES'], **kwargs) del vrt_ds @@ -1917,7 +1934,8 @@ def reproject_raster(src_ds, trg_ds, ref_ds=None, epsg=None, resample='near', # check whether the source dataset exists if not src_path.exists(): - LOGGER.warning('{} does not exist.'.format(str(src_path))) + LOGGER.info('{} does not exist.'.format(str(src_path))) + return # check whether the output datasets exists trg_path = pathlib.Path(trg_ds) @@ -1925,9 +1943,13 @@ def reproject_raster(src_ds, trg_ds, ref_ds=None, epsg=None, resample='near', trg_path.parent.mkdir(parents=True, exist_ok=True) else: # check whether to overwrite existing files - if overwrite: - LOGGER.info('Overwrite {}'.format(str(trg_path))) - trg_path.unlink() + if not overwrite: + LOGGER.info('{} already exists.'.format(trg_path)) + return + + # overwrite + LOGGER.info('Overwrite {}'.format(str(trg_path))) + trg_path.unlink() # read the source dataset src_ds = gdal.Open(str(src_path)) @@ -1971,9 +1993,12 @@ def reproject_raster(src_ds, trg_ds, ref_ds=None, epsg=None, resample='near', ref_xres = pixel_size[0] ref_yres = pixel_size[1] + # create a temporary path + tmp_path = _tmp_path(trg_path) + # reproject source dataset to target projection - LOGGER.info('Reproject {}:'.format(src_path.name)) - gdal.Warp(str(trg_path), str(src_path), + LOGGER.info('Reproject: {}'.format(src_path.name)) + gdal.Warp(str(tmp_path), str(src_path), srcSRS=src_sr, dstSRS=ref_sr, outputType=out_type, @@ -1982,6 +2007,9 @@ def reproject_raster(src_ds, trg_ds, ref_ds=None, epsg=None, resample='near', yRes=ref_yres, resampleAlg=resample) + # compress raster + compress_raster(tmp_path, trg_path) + # clear gdal cache del src_ds, ref_ds @@ -2015,18 +2043,22 @@ def reproject_vector(src_ds, trg_ds, ref_ds=None, epsg=None, overwrite=False): # check whether the source dataset exists if not src_path.exists(): - LOGGER.warning('{} does not exist.'.format(str(src_path))) + LOGGER.info('{} does not exist.'.format(str(src_path))) + return # check whether the output datasets exists trg_path = pathlib.Path(trg_ds) if not trg_path.exists(): - LOGGER.info('mkdir {}'.format(str(trg_path.parent))) trg_path.parent.mkdir(parents=True, exist_ok=True) else: # check whether to overwrite existing files - if overwrite: - LOGGER.info('Overwrite {}'.format(str(trg_path))) - trg_path.unlink() + if not overwrite: + LOGGER.info('{} already exists.'.format(trg_path)) + return + + # overwrite + LOGGER.info('Overwrite {}'.format(str(trg_path))) + trg_path.unlink() # read the source vector dataset src_ds = ogr.Open(str(src_path)) @@ -2152,18 +2184,22 @@ def vector2raster(src_ds, trg_ds, pixel_size, out_type, attribute=None, # check whether the source dataset exists if not src_path.exists(): - LOGGER.warning('{} does not exist.'.format(str(src_path))) + LOGGER.info('{} does not exist.'.format(str(src_path))) + return # check whether the output datasets exists trg_path = pathlib.Path(trg_ds) if not trg_path.exists(): - LOGGER.info('mkdir {}'.format(str(trg_path.parent))) trg_path.parent.mkdir(parents=True, exist_ok=True) else: # check whether to overwrite existing files - if overwrite: - LOGGER.info('Overwrite {}'.format(str(trg_path))) - trg_path.unlink() + if not overwrite: + LOGGER.info('{} already exists.'.format(trg_path)) + return + + # overwrite + LOGGER.info('Overwrite {}'.format(str(trg_path))) + trg_path.unlink() # read the source vector dataset src_ds = ogr.Open(str(src_path)) @@ -2189,12 +2225,19 @@ def vector2raster(src_ds, trg_ds, pixel_size, out_type, attribute=None, # attribute burn_value = None + # create a temporary path + tmp_path = _tmp_path(trg_path) + # rasterize vector dataset to defined spatial resolution - gdal.Rasterize(str(trg_path), str(src_path), + LOGGER.info('Rasterizing: {}'.format(src_path)) + gdal.Rasterize(str(tmp_path), str(src_path), xRes=pixel_size[1], yRes=pixel_size[0], noData=nodata, outputType=out_type, attribute=attribute, outputSRS=src_lr.GetSpatialRef(), burnValues=burn_value) + # compress raster dataset + compress_raster(tmp_path, trg_path) + # clear source dataset del src_ds @@ -2261,7 +2304,7 @@ def dec2bin(number, nbits=8): def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False, src_no_data=None, trg_no_data=None): - """Extract raster values by mask. + """Extract raster values by a shapefile. Extract the extent of ``mask_ds`` from ``src_ds``. The masked values of ``src_ds`` are saved in ``trg_ds``. @@ -2274,8 +2317,7 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False, src_ds : `str` or :py:class:`pathlib.Path` The input raster to extract values from. mask_ds : `str` or :py:class:`pathlib.Path` - A mask defining the area of interest. Either a raster file or - a shapefile. + A shapefile defining the area of interest. trg_ds : `str` or :py:class:`pathlib.Path` The masked target dataset. overwrite : `bool`, optional @@ -2295,11 +2337,13 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False, # check whether the source dataset exists if not src_path.exists(): - LOGGER.warning('{} does not exist.'.format(str(src_path))) + LOGGER.info('{} does not exist.'.format(str(src_path))) + return # check whether the mask exists if not mask_path.exists(): - LOGGER.warning('{} does not exist.'.format(str(mask_path))) + LOGGER.info('{} does not exist.'.format(str(mask_path))) + return # check whether the output datasets exists trg_path = pathlib.Path(trg_ds) @@ -2307,65 +2351,138 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False, trg_path.parent.mkdir(parents=True, exist_ok=True) else: # check whether to overwrite existing files - if overwrite: - LOGGER.info('Overwrite {}'.format(str(trg_path))) - trg_path.unlink() + if not overwrite: + LOGGER.info('{} already exists.'.format(trg_path)) + return + + # overwrite + LOGGER.info('Overwrite {}'.format(str(trg_path))) + trg_path.unlink() + + # create a temporary file + tmp_path = _tmp_path(trg_path) # read the source dataset src_ds = gdal.Open(str(src_path)) + LOGGER.info('Extract: {}, {}'.format(src_path.name, mask_path.name)) - # source dataset spatial reference - src_sr = osr.SpatialReference() - src_sr.ImportFromWkt(src_ds.GetProjection()) + # extract raster values by shapefile + gdal.Warp(str(tmp_path), str(src_path), + cutlineDSName=str(mask_path), + cropToCutline=True, + xRes=src_ds.GetGeoTransform()[1], + yRes=src_ds.GetGeoTransform()[5], + srcNodata=src_no_data, + dstNodata=trg_no_data) + + # compress raster + compress_raster(tmp_path, trg_path) + + # clear source dataset + del src_ds - LOGGER.info('Extract: {}, {}'.format(src_path.name, mask_path.name)) - # checkt the type of the mask dataset - if mask_path.name.endswith('.shp'): - # clip raster values by shapefile - gdal.Warp(str(trg_path), str(src_path), - cutlineDSName=str(mask_path), - cropToCutline=True, - xRes=src_ds.GetGeoTransform()[1], - yRes=src_ds.GetGeoTransform()[5], - srcNodata=src_no_data, - dstNodata=trg_no_data) +def clip_raster(src_ds, mask_ds, trg_ds, overwrite=False, src_no_data=None, + trg_no_data=None): + """Clip raster to extent of another raster. + Clip the extent of ``src_ds`` to the extent of ``mask_ds``. The clipped + raster is saved in ``trg_ds``. + + Parameters + ---------- + src_ds : `str` or :py:class:`pathlib.Path` + The input raster to clip. + mask_ds : `str` or :py:class:`pathlib.Path` + The raster defining the extent of interest. + trg_ds : `str` or :py:class:`pathlib.Path` + The clipped raster dataset. + overwrite : `bool`, optional + Whether to overwrite ``trg_ds``, if it exists. The default is `False`. + src_no_data : `int` or `float`, optional + The value of NoData values in ``src_ds``. The default is `None`, which + means the value is read from ``src_ds``. If specified, values equal to + ``src_no_data`` are masked as ``trg_no_data`` in ``trg_ds``. + trg_no_data : `int` or `float`, optional + The value to assign to NoData values in ``trg_ds``. The default is + `None`, which means conserving the NoData value of ``src_ds``. + + """ + # convert path to source dataset and mask dataset to pathlib.Path object + src_path = pathlib.Path(src_ds) + mask_path = pathlib.Path(mask_ds) + + # check whether the source dataset exists + if not src_path.exists(): + LOGGER.info('{} does not exist.'.format(str(src_path))) + return + + # check whether the mask exists + if not mask_path.exists(): + LOGGER.info('{} does not exist.'.format(str(mask_path))) + return + + # check whether the output datasets exists + trg_path = pathlib.Path(trg_ds) + if not trg_path.exists(): + trg_path.parent.mkdir(parents=True, exist_ok=True) else: - # mask is a raster dataset - mask_ds = gdal.Open(str(mask_path)) + # check whether to overwrite existing files + if not overwrite: + LOGGER.info('{} already exists.'.format(trg_path)) + return + + # overwrite + LOGGER.info('Overwrite {}'.format(str(trg_path))) + trg_path.unlink() - # spatial extent of the mask: (x_min, x_max, y_min, y_max) - gt = mask_ds.GetGeoTransform() - extent = [gt[0], gt[0] + gt[1] * mask_ds.RasterXSize, - gt[3] + gt[5] * mask_ds.RasterYSize, gt[3]] + # mask is a raster dataset + mask_ds = gdal.Open(str(mask_path)) - # mask dataset spatial reference - mask_sr = osr.SpatialReference() - mask_sr.ImportFromWkt(mask_ds.GetProjection()) + # spatial extent of the mask: (x_min, x_max, y_min, y_max) + gt = mask_ds.GetGeoTransform() + extent = [gt[0], gt[0] + gt[1] * mask_ds.RasterXSize, + gt[3] + gt[5] * mask_ds.RasterYSize, gt[3]] - # coordinate transformation: from mask to source - crs_tr = osr.CoordinateTransformation(mask_sr, src_sr) + # mask dataset spatial reference + mask_sr = osr.SpatialReference() + mask_sr.ImportFromWkt(mask_ds.GetProjection()) + + # source dataset spatial reference + src_ds = gdal.Open(str(src_path)) + src_sr = osr.SpatialReference() + src_sr.ImportFromWkt(src_ds.GetProjection()) - # transform extent of mask to source coordinate system + # coordinate transformation: from mask to source + crs_tr = osr.CoordinateTransformation(mask_sr, src_sr) - # TransfromPoint expects input: - # - gdal >= 3.0: x, y, z = TransformPoint(y, x) - # - gdal < 3.0 : x, y, z = TransformPoint(x, y) - x_tl, y_tl, _ = crs_tr.TransformPoint(extent[0], extent[-1]) - x_br, y_br, _ = crs_tr.TransformPoint(extent[1], extent[2]) + # transform extent of mask to source coordinate system - # extent of the mask in the source reference coordinate system: - # (x_min, y_min, x_max, y_max) - extent = [x_tl, y_br, x_br, y_tl] + # TransfromPoint expects input: + # - gdal >= 3.0: x, y, z = TransformPoint(y, x) + # - gdal < 3.0 : x, y, z = TransformPoint(x, y) + x_tl, y_tl, _ = crs_tr.TransformPoint(extent[0], extent[-1]) + x_br, y_br, _ = crs_tr.TransformPoint(extent[1], extent[2]) - # extract values by mask - gdal.Warp(str(trg_path), str(src_path), - outputBounds=extent, - xRes=src_ds.GetGeoTransform()[1], - yRes=src_ds.GetGeoTransform()[5], - srcNodata=src_no_data, - dstNodata=trg_no_data) + # extent of the mask in the source reference coordinate system: + # (x_min, y_min, x_max, y_max) + extent = [x_tl, y_br, x_br, y_tl] + + # create a temporary file + tmp_path = _tmp_path(trg_path) + + # clip raster extent + LOGGER.info('Clipping: {}, Extent: (x_tl={:.2f}, y_br={:.2f}, x_br={:.2f},' + ' y_tl={:.2f})'.format(src_path.name, *extent)) + gdal.Warp(str(tmp_path), str(src_path), + outputBounds=extent, + xRes=src_ds.GetGeoTransform()[1], + yRes=src_ds.GetGeoTransform()[5], + srcNodata=src_no_data, + dstNodata=trg_no_data) + + # compress raster dataset + compress_raster(tmp_path, trg_path) # clear source and mask dataset del src_ds, mask_ds @@ -2466,5 +2583,55 @@ def merge_tifs(trg_ds, tifs): List of paths to the GeoTiffs to mosaic. """ - LOGGER.info('Creating mosaic: {}'.format(trg_ds)) - gdal.Warp(str(trg_ds), [str(tif) for tif in tifs]) + # create mosaic + tmp_path = _tmp_path(trg_ds) + LOGGER.info('Create mosaic: {}'.format(trg_ds)) + gdal.Warp(str(tmp_path), [str(tif) for tif in tifs]) + + # compress raster + compress_raster(tmp_path, trg_ds) + + +def compress_raster(src_ds, trg_ds): + """Compress a raster dataset using :py:func:`gdal.Translate`. + + Parameters + ---------- + src_ds : `str` or :py:class:`pathlib.Path` + Path to the raster dataset to compress. + trg_ds : `str` or :py:class:`pathlib.Path` + Path to save the compressed raster dataset. + + """ + # check if the raster dataset exists + src_ds = pathlib.Path(src_ds) + if not src_ds.exists(): + LOGGER.info('{} does not exist'.format(src_ds)) + return + + # compress raster dataset + LOGGER.info('Compressing: {}'.format(trg_ds)) + gdal.Translate(str(trg_ds), str(src_ds), creationOptions=[ + 'COMPRESS=DEFLATE', 'PREDICTOR=1', 'TILED=YES']) + + # remove uncompressed raster dataset + src_ds.unlink() + + +def _tmp_path(path): + """Create a temporary filename from a given path. + + Parameters + ---------- + path : `str` or :py:class:`pathlib.Path` + Path to a file for which a temporary path is required. + + Returns + ------- + tmp_path : `str` or :py:class:`pathlib.Path` + Path to the temporary file. + + """ + path = pathlib.Path(path) + return pathlib.Path(str(path).replace(path.suffix, + '_tmp{}'.format(path.suffix)))