Skip to content
Snippets Groups Projects
Commit 836942cd authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Implemented raster compression in all relevant functions.

parent b075a5e7
No related branches found
No related tags found
No related merge requests found
......@@ -265,7 +265,7 @@ def img2np(path, tile_size=None, tile=None, pad=False, cval=0):
def np2tif(array, filename, no_data=None, names=None, src_ds=None, epsg=None,
geotransform=None):
geotransform=None, overwrite=False):
"""Save a :py:class:`numpy.ndarray` as a GeoTIFF.
The spatial coordinate reference system can be specified in two ways:
......@@ -299,6 +299,8 @@ def np2tif(array, filename, no_data=None, names=None, src_ds=None, epsg=None,
A tuple with six elements of the form,
(x_top_left, x_res, x_shift, y_top_left, -y_res, y_shift), describing
the spatial reference.
overwrite : `bool`, optional
Whether to overwrite ``filename`` if it exists. The default is `False`.
.. _EPSG:
https://epsg.io/
......@@ -333,12 +335,20 @@ def np2tif(array, filename, no_data=None, names=None, src_ds=None, epsg=None,
raise ValueError('{} is not a file.'.format(filename))
filename = pathlib.Path(str(filename).replace(filename.suffix, '.tif'))
# check if file exists
if filename.exists() and not overwrite:
LOGGER.info('{} already exists.'.format(filename))
return
# create temporary file
tmp_path = _tmp_path(filename)
# create output GeoTIFF
trg_ds = driver.Create(str(filename), width, height, bands, dtype)
tmp_ds = driver.Create(str(tmp_path), width, height, bands, dtype)
# iterate over the number of bands and write to output file
for b in range(bands):
trg_band = trg_ds.GetRasterBand(b + 1)
trg_band = tmp_ds.GetRasterBand(b + 1)
trg_band.WriteArray(array[b, ...])
# set the band description, if specified
......@@ -353,8 +363,8 @@ def np2tif(array, filename, no_data=None, names=None, src_ds=None, epsg=None,
# set spatial reference
if src_ds is not None:
# inherit spatial reference from source dataset
trg_ds.SetProjection(src_ds.GetProjection())
trg_ds.SetGeoTransform(src_ds.GetGeoTransform())
tmp_ds.SetProjection(src_ds.GetProjection())
tmp_ds.SetGeoTransform(src_ds.GetGeoTransform())
else:
# check whether both the epsg code and the geotransform tuple are
# specified
......@@ -363,16 +373,17 @@ def np2tif(array, filename, no_data=None, names=None, src_ds=None, epsg=None,
sr = osr.SpatialReference().ImportFromEPSG(epsg).ExporttoWkt()
# set spatial reference from epsg
trg_ds.SetProjection(sr)
trg_ds.SetGeoTransform(geotransform)
tmp_ds.SetProjection(sr)
tmp_ds.SetGeoTransform(geotransform)
else:
raise ValueError('Both "epsg" and "geotransform" required to set '
'spatial reference if "src_ds" is None.')
# clear dataset
del trg_band, trg_ds
del trg_band, tmp_ds
return
# compress raster
compress_raster(tmp_path, filename)
def read_hdf(path, **kwargs):
......@@ -431,6 +442,8 @@ def hdf2tifs(path, outpath=None, overwrite=False, create_stack=True, **kwargs):
The default (``outpath=None``) is to save the GeoTIFFs in a directory named
after the filename of ``path``, within the parent directory of ``path``.
The output GeoTIFFs are compressed by default.
Parameters
----------
path : `str` or :py:class:`pathlib.Path`
......@@ -504,7 +517,8 @@ def hdf2tifs(path, outpath=None, overwrite=False, create_stack=True, **kwargs):
# convert hdf subdataset to GeoTIFF
LOGGER.info('Converting: {}'.format(tif_name.name))
gdal.Translate(str(tif_name), gdal.Open(ds[0]), **kwargs)
gdal.Translate(str(tif_name), gdal.Open(ds[0]), creationOptions=[
'COMPRESS=DEFLATE', 'PREDICTOR=1', 'TILED=YES'], **kwargs)
# check whether to create a GeoTIFF stack
if create_stack:
......@@ -526,6 +540,8 @@ def hdf2tifs(path, outpath=None, overwrite=False, create_stack=True, **kwargs):
def stack_tifs(filename, tifs, **kwargs):
"""Create a stacked GeoTIFF from a list of single-band GeoTIFFs.
The output GeoTIFF stack is compressed by default.
Parameters
----------
filename : `str` or :py:class:`pathlib.Path`
......@@ -552,7 +568,8 @@ def stack_tifs(filename, tifs, **kwargs):
# create GeoTIFF stack
gdal.PushErrorHandler('CPLQuietErrorHandler')
gdal.Translate(str(filename), vrt_ds, **kwargs)
gdal.Translate(str(filename), vrt_ds, creationOptions=[
'COMPRESS=DEFLATE', 'PREDICTOR=1', 'TILED=YES'], **kwargs)
del vrt_ds
......@@ -1917,7 +1934,8 @@ def reproject_raster(src_ds, trg_ds, ref_ds=None, epsg=None, resample='near',
# check whether the source dataset exists
if not src_path.exists():
LOGGER.warning('{} does not exist.'.format(str(src_path)))
LOGGER.info('{} does not exist.'.format(str(src_path)))
return
# check whether the output datasets exists
trg_path = pathlib.Path(trg_ds)
......@@ -1925,9 +1943,13 @@ def reproject_raster(src_ds, trg_ds, ref_ds=None, epsg=None, resample='near',
trg_path.parent.mkdir(parents=True, exist_ok=True)
else:
# check whether to overwrite existing files
if overwrite:
LOGGER.info('Overwrite {}'.format(str(trg_path)))
trg_path.unlink()
if not overwrite:
LOGGER.info('{} already exists.'.format(trg_path))
return
# overwrite
LOGGER.info('Overwrite {}'.format(str(trg_path)))
trg_path.unlink()
# read the source dataset
src_ds = gdal.Open(str(src_path))
......@@ -1971,9 +1993,12 @@ def reproject_raster(src_ds, trg_ds, ref_ds=None, epsg=None, resample='near',
ref_xres = pixel_size[0]
ref_yres = pixel_size[1]
# create a temporary path
tmp_path = _tmp_path(trg_path)
# reproject source dataset to target projection
LOGGER.info('Reproject {}:'.format(src_path.name))
gdal.Warp(str(trg_path), str(src_path),
LOGGER.info('Reproject: {}'.format(src_path.name))
gdal.Warp(str(tmp_path), str(src_path),
srcSRS=src_sr,
dstSRS=ref_sr,
outputType=out_type,
......@@ -1982,6 +2007,9 @@ def reproject_raster(src_ds, trg_ds, ref_ds=None, epsg=None, resample='near',
yRes=ref_yres,
resampleAlg=resample)
# compress raster
compress_raster(tmp_path, trg_path)
# clear gdal cache
del src_ds, ref_ds
......@@ -2015,18 +2043,22 @@ def reproject_vector(src_ds, trg_ds, ref_ds=None, epsg=None, overwrite=False):
# check whether the source dataset exists
if not src_path.exists():
LOGGER.warning('{} does not exist.'.format(str(src_path)))
LOGGER.info('{} does not exist.'.format(str(src_path)))
return
# check whether the output datasets exists
trg_path = pathlib.Path(trg_ds)
if not trg_path.exists():
LOGGER.info('mkdir {}'.format(str(trg_path.parent)))
trg_path.parent.mkdir(parents=True, exist_ok=True)
else:
# check whether to overwrite existing files
if overwrite:
LOGGER.info('Overwrite {}'.format(str(trg_path)))
trg_path.unlink()
if not overwrite:
LOGGER.info('{} already exists.'.format(trg_path))
return
# overwrite
LOGGER.info('Overwrite {}'.format(str(trg_path)))
trg_path.unlink()
# read the source vector dataset
src_ds = ogr.Open(str(src_path))
......@@ -2152,18 +2184,22 @@ def vector2raster(src_ds, trg_ds, pixel_size, out_type, attribute=None,
# check whether the source dataset exists
if not src_path.exists():
LOGGER.warning('{} does not exist.'.format(str(src_path)))
LOGGER.info('{} does not exist.'.format(str(src_path)))
return
# check whether the output datasets exists
trg_path = pathlib.Path(trg_ds)
if not trg_path.exists():
LOGGER.info('mkdir {}'.format(str(trg_path.parent)))
trg_path.parent.mkdir(parents=True, exist_ok=True)
else:
# check whether to overwrite existing files
if overwrite:
LOGGER.info('Overwrite {}'.format(str(trg_path)))
trg_path.unlink()
if not overwrite:
LOGGER.info('{} already exists.'.format(trg_path))
return
# overwrite
LOGGER.info('Overwrite {}'.format(str(trg_path)))
trg_path.unlink()
# read the source vector dataset
src_ds = ogr.Open(str(src_path))
......@@ -2189,12 +2225,19 @@ def vector2raster(src_ds, trg_ds, pixel_size, out_type, attribute=None,
# attribute
burn_value = None
# create a temporary path
tmp_path = _tmp_path(trg_path)
# rasterize vector dataset to defined spatial resolution
gdal.Rasterize(str(trg_path), str(src_path),
LOGGER.info('Rasterizing: {}'.format(src_path))
gdal.Rasterize(str(tmp_path), str(src_path),
xRes=pixel_size[1], yRes=pixel_size[0],
noData=nodata, outputType=out_type, attribute=attribute,
outputSRS=src_lr.GetSpatialRef(), burnValues=burn_value)
# compress raster dataset
compress_raster(tmp_path, trg_path)
# clear source dataset
del src_ds
......@@ -2261,7 +2304,7 @@ def dec2bin(number, nbits=8):
def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False,
src_no_data=None, trg_no_data=None):
"""Extract raster values by mask.
"""Extract raster values by a shapefile.
Extract the extent of ``mask_ds`` from ``src_ds``. The masked values of
``src_ds`` are saved in ``trg_ds``.
......@@ -2274,8 +2317,7 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False,
src_ds : `str` or :py:class:`pathlib.Path`
The input raster to extract values from.
mask_ds : `str` or :py:class:`pathlib.Path`
A mask defining the area of interest. Either a raster file or
a shapefile.
A shapefile defining the area of interest.
trg_ds : `str` or :py:class:`pathlib.Path`
The masked target dataset.
overwrite : `bool`, optional
......@@ -2295,11 +2337,13 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False,
# check whether the source dataset exists
if not src_path.exists():
LOGGER.warning('{} does not exist.'.format(str(src_path)))
LOGGER.info('{} does not exist.'.format(str(src_path)))
return
# check whether the mask exists
if not mask_path.exists():
LOGGER.warning('{} does not exist.'.format(str(mask_path)))
LOGGER.info('{} does not exist.'.format(str(mask_path)))
return
# check whether the output datasets exists
trg_path = pathlib.Path(trg_ds)
......@@ -2307,65 +2351,138 @@ def extract_by_mask(src_ds, mask_ds, trg_ds, overwrite=False,
trg_path.parent.mkdir(parents=True, exist_ok=True)
else:
# check whether to overwrite existing files
if overwrite:
LOGGER.info('Overwrite {}'.format(str(trg_path)))
trg_path.unlink()
if not overwrite:
LOGGER.info('{} already exists.'.format(trg_path))
return
# overwrite
LOGGER.info('Overwrite {}'.format(str(trg_path)))
trg_path.unlink()
# create a temporary file
tmp_path = _tmp_path(trg_path)
# read the source dataset
src_ds = gdal.Open(str(src_path))
LOGGER.info('Extract: {}, {}'.format(src_path.name, mask_path.name))
# source dataset spatial reference
src_sr = osr.SpatialReference()
src_sr.ImportFromWkt(src_ds.GetProjection())
# extract raster values by shapefile
gdal.Warp(str(tmp_path), str(src_path),
cutlineDSName=str(mask_path),
cropToCutline=True,
xRes=src_ds.GetGeoTransform()[1],
yRes=src_ds.GetGeoTransform()[5],
srcNodata=src_no_data,
dstNodata=trg_no_data)
# compress raster
compress_raster(tmp_path, trg_path)
# clear source dataset
del src_ds
LOGGER.info('Extract: {}, {}'.format(src_path.name, mask_path.name))
# checkt the type of the mask dataset
if mask_path.name.endswith('.shp'):
# clip raster values by shapefile
gdal.Warp(str(trg_path), str(src_path),
cutlineDSName=str(mask_path),
cropToCutline=True,
xRes=src_ds.GetGeoTransform()[1],
yRes=src_ds.GetGeoTransform()[5],
srcNodata=src_no_data,
dstNodata=trg_no_data)
def clip_raster(src_ds, mask_ds, trg_ds, overwrite=False, src_no_data=None,
trg_no_data=None):
"""Clip raster to extent of another raster.
Clip the extent of ``src_ds`` to the extent of ``mask_ds``. The clipped
raster is saved in ``trg_ds``.
Parameters
----------
src_ds : `str` or :py:class:`pathlib.Path`
The input raster to clip.
mask_ds : `str` or :py:class:`pathlib.Path`
The raster defining the extent of interest.
trg_ds : `str` or :py:class:`pathlib.Path`
The clipped raster dataset.
overwrite : `bool`, optional
Whether to overwrite ``trg_ds``, if it exists. The default is `False`.
src_no_data : `int` or `float`, optional
The value of NoData values in ``src_ds``. The default is `None`, which
means the value is read from ``src_ds``. If specified, values equal to
``src_no_data`` are masked as ``trg_no_data`` in ``trg_ds``.
trg_no_data : `int` or `float`, optional
The value to assign to NoData values in ``trg_ds``. The default is
`None`, which means conserving the NoData value of ``src_ds``.
"""
# convert path to source dataset and mask dataset to pathlib.Path object
src_path = pathlib.Path(src_ds)
mask_path = pathlib.Path(mask_ds)
# check whether the source dataset exists
if not src_path.exists():
LOGGER.info('{} does not exist.'.format(str(src_path)))
return
# check whether the mask exists
if not mask_path.exists():
LOGGER.info('{} does not exist.'.format(str(mask_path)))
return
# check whether the output datasets exists
trg_path = pathlib.Path(trg_ds)
if not trg_path.exists():
trg_path.parent.mkdir(parents=True, exist_ok=True)
else:
# mask is a raster dataset
mask_ds = gdal.Open(str(mask_path))
# check whether to overwrite existing files
if not overwrite:
LOGGER.info('{} already exists.'.format(trg_path))
return
# overwrite
LOGGER.info('Overwrite {}'.format(str(trg_path)))
trg_path.unlink()
# spatial extent of the mask: (x_min, x_max, y_min, y_max)
gt = mask_ds.GetGeoTransform()
extent = [gt[0], gt[0] + gt[1] * mask_ds.RasterXSize,
gt[3] + gt[5] * mask_ds.RasterYSize, gt[3]]
# mask is a raster dataset
mask_ds = gdal.Open(str(mask_path))
# mask dataset spatial reference
mask_sr = osr.SpatialReference()
mask_sr.ImportFromWkt(mask_ds.GetProjection())
# spatial extent of the mask: (x_min, x_max, y_min, y_max)
gt = mask_ds.GetGeoTransform()
extent = [gt[0], gt[0] + gt[1] * mask_ds.RasterXSize,
gt[3] + gt[5] * mask_ds.RasterYSize, gt[3]]
# coordinate transformation: from mask to source
crs_tr = osr.CoordinateTransformation(mask_sr, src_sr)
# mask dataset spatial reference
mask_sr = osr.SpatialReference()
mask_sr.ImportFromWkt(mask_ds.GetProjection())
# source dataset spatial reference
src_ds = gdal.Open(str(src_path))
src_sr = osr.SpatialReference()
src_sr.ImportFromWkt(src_ds.GetProjection())
# transform extent of mask to source coordinate system
# coordinate transformation: from mask to source
crs_tr = osr.CoordinateTransformation(mask_sr, src_sr)
# TransfromPoint expects input:
# - gdal >= 3.0: x, y, z = TransformPoint(y, x)
# - gdal < 3.0 : x, y, z = TransformPoint(x, y)
x_tl, y_tl, _ = crs_tr.TransformPoint(extent[0], extent[-1])
x_br, y_br, _ = crs_tr.TransformPoint(extent[1], extent[2])
# transform extent of mask to source coordinate system
# extent of the mask in the source reference coordinate system:
# (x_min, y_min, x_max, y_max)
extent = [x_tl, y_br, x_br, y_tl]
# TransformPoint expects input:
# - gdal >= 3.0: x, y, z = TransformPoint(y, x)
# - gdal < 3.0 : x, y, z = TransformPoint(x, y)
x_tl, y_tl, _ = crs_tr.TransformPoint(extent[0], extent[-1])
x_br, y_br, _ = crs_tr.TransformPoint(extent[1], extent[2])
# extract values by mask
gdal.Warp(str(trg_path), str(src_path),
outputBounds=extent,
xRes=src_ds.GetGeoTransform()[1],
yRes=src_ds.GetGeoTransform()[5],
srcNodata=src_no_data,
dstNodata=trg_no_data)
# extent of the mask in the source reference coordinate system:
# (x_min, y_min, x_max, y_max)
extent = [x_tl, y_br, x_br, y_tl]
# create a temporary file
tmp_path = _tmp_path(trg_path)
# clip raster extent
LOGGER.info('Clipping: {}, Extent: (x_tl={:.2f}, y_br={:.2f}, x_br={:.2f},'
' y_tl={:.2f})'.format(src_path.name, *extent))
gdal.Warp(str(tmp_path), str(src_path),
outputBounds=extent,
xRes=src_ds.GetGeoTransform()[1],
yRes=src_ds.GetGeoTransform()[5],
srcNodata=src_no_data,
dstNodata=trg_no_data)
# compress raster dataset
compress_raster(tmp_path, trg_path)
# clear source and mask dataset
del src_ds, mask_ds
......@@ -2466,5 +2583,55 @@ def merge_tifs(trg_ds, tifs):
List of paths to the GeoTiffs to mosaic.
"""
LOGGER.info('Creating mosaic: {}'.format(trg_ds))
gdal.Warp(str(trg_ds), [str(tif) for tif in tifs])
# create mosaic
tmp_path = _tmp_path(trg_ds)
LOGGER.info('Create mosaic: {}'.format(trg_ds))
gdal.Warp(str(tmp_path), [str(tif) for tif in tifs])
# compress raster
compress_raster(tmp_path, trg_ds)
def compress_raster(src_ds, trg_ds):
    """Compress a raster dataset using :py:func:`gdal.Translate`.

    A copy of ``src_ds`` is written to ``trg_ds`` with the GeoTIFF creation
    options ``COMPRESS=DEFLATE``, ``PREDICTOR=1`` and ``TILED=YES``. The
    uncompressed ``src_ds`` is removed from disk, but only after the
    compressed copy was created successfully.

    Parameters
    ----------
    src_ds : `str` or :py:class:`pathlib.Path`
        Path to the raster dataset to compress.
    trg_ds : `str` or :py:class:`pathlib.Path`
        Path to save the compressed raster dataset.

    """
    # check if the raster dataset exists
    src_ds = pathlib.Path(src_ds)
    if not src_ds.exists():
        LOGGER.info('{} does not exist'.format(src_ds))
        return

    # compress raster dataset
    LOGGER.info('Compressing: {}'.format(trg_ds))
    out_ds = gdal.Translate(str(trg_ds), str(src_ds), creationOptions=[
        'COMPRESS=DEFLATE', 'PREDICTOR=1', 'TILED=YES'])

    # gdal.Translate returns None on failure (when GDAL exceptions are not
    # enabled): keep the uncompressed dataset in that case, since it is the
    # only remaining copy of the data
    if out_ds is None:
        LOGGER.warning('Compressing {} failed, keeping {}.'.format(
            trg_ds, src_ds))
        return

    # dereference the output dataset to flush it to disk before removing the
    # uncompressed source
    del out_ds

    # remove uncompressed raster dataset
    src_ds.unlink()
def _tmp_path(path):
    """Create a temporary filename from a given path.

    The temporary path lies in the same directory as ``path`` and keeps its
    file extension; ``_tmp`` is appended to the file stem, e.g.
    ``/data/image.tif`` becomes ``/data/image_tmp.tif``. A path without a
    file extension simply gets ``_tmp`` appended to its name.

    Parameters
    ----------
    path : `str` or :py:class:`pathlib.Path`
        Path to a file for which a temporary path is required.

    Returns
    -------
    tmp_path : :py:class:`pathlib.Path`
        Path to the temporary file.

    Raises
    ------
    ValueError
        Raised if ``path`` has no final component (e.g. ``'.'`` or ``'/'``).

    """
    path = pathlib.Path(path)

    # only rewrite the final path component: a plain string replacement of the
    # suffix would also alter parent directories containing the extension and
    # would garble paths without any extension (empty suffix)
    return path.with_name('{}_tmp{}'.format(path.stem, path.suffix))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment