utils.py 99.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
"""Utility functions mainly for image IO and reshaping.

License
-------

    Copyright (c) 2020 Daniel Frisinghelli

    This source code is licensed under the GNU General Public License v3.

    See the LICENSE file in the repository's root directory.

"""
Frisinghelli Daniel's avatar
Frisinghelli Daniel committed
13
14

# !/usr/bin/env python
15
16
17
# -*- coding: utf-8 -*-

# builtins
18
import os
19
import re
20
import shutil
21
import logging
22
23
24
import pathlib
import tarfile
import zipfile
25
import datetime
26
import warnings
27
import platform
28
import subprocess
29
import xml.etree.ElementTree as ET
30

31
# externals
32
import torch
33
import numpy as np
34
import pandas as pd
35
import xarray as xr
36
import rasterio
37
from osgeo import gdal, ogr, osr
38

39
40
41
# locals
from pysegcnn.core.constants import Gdal2Numpy

42
43
44
# module level logger
LOGGER = logging.getLogger(__name__)

45
46
47
# suffixes for radiometrically calibrated scenes
SUFFIXES = ['toa_ref', 'toa_rad', 'toa_brt']

48
49
50
# maximum number of filename characters on Windows
MAX_FILENAME_CHARS_WINDOWS = 260

51
52
53
54
# file suffixes for hierarchical data format
# NOTE: every entry must carry the leading dot, because the list is compared
# against pathlib.Path().suffix (which includes the dot), e.g. in read_hdf4
# and hdf2tifs; the previous entry 'nc' could therefore never match and
# netCDF files were wrongly rejected
HIERARCHICAL_DATA_FORMAT = ['.h4', '.hdf', '.hdf4', '.hdf5', '.he2', '.h5',
                            '.he5', '.nc']

55

Frisinghelli Daniel's avatar
Frisinghelli Daniel committed
56
def img2np(path, tile_size=None, tile=None, pad=False, cval=0):
    r"""Read an image to a :py:class:`numpy.ndarray`.

    If ``tile_size`` is not `None`, the input image is divided into square
    tiles of size ``(tile_size, tile_size)``. If the image is not evenly
    divisible and ``pad=False``, a ``ValueError`` is raised. However, if
    ``pad=True``, center padding with constant value ``cval`` is applied.

    The tiling works as follows:

        +-----------+-----------+-----------+-----------+
        |           |           |           |           |
        |  tile_00  |  tile_01  |    ...    |  tile_0n  |
        |           |           |           |           |
        +-----------+-----------+-----------+-----------+
        |           |           |           |           |
        |  tile_10  |  tile_11  |    ...    |  tile_1n  |
        |           |           |           |           |
        +-----------+-----------+-----------+-----------+
        |           |           |           |           |
        |    ...    |    ...    |    ...    |    ...    |
        |           |           |           |           |
        +-----------+-----------+-----------+-----------+
        |           |           |           |           |
        |  tile_m0  |  tile_m1  |    ...    |  tile_mn  |
        |           |           |           |           |
        +-----------+-----------+-----------+-----------+

    where :math:`m = n`. Each tile has its id, which starts at `0` in the
    topleft corner of the input image, i.e. `tile_00` has :math:`id=0`, and
    increases along the width axis, i.e. `tile_0n` has :math:`id=n`, `tile_10`
    has :math:`id=n+1`, ..., `tile_mn` has :math:`id=(m \\cdot n) - 1`.

    If ``tile`` is an integer, only the tile with ``id=tile`` is returned.

    Parameters
    ----------
    path : `str` or :py:class:`pathlib.Path` or :py:class:`numpy.ndarray`
        The image to read.
    tile_size : `None` or `int`, optional
        The size of a tile. The default is `None`.
    tile : `int`, optional
        The tile id. The default is `None`.
    pad : `bool`, optional
        Whether to center pad the input image. The default is `False`.
    cval : `float`, optional
        The constant padding value. The default is `0`.

    Raises
    ------
    FileNotFoundError
        Raised if ``path`` is a path that does not exist.
    TypeError
        Raised if ``path`` is not `str` or `None` or :py:class:`numpy.ndarray`.

    Returns
    -------
    image : :py:class:`numpy.ndarray`
        The image array. The output shape is:

            - `(tiles, bands, tile_size, tile_size)` if ``tile_size`` is not
            `None`. If the image does only have one band,
            `(tiles, tile_size, tile_size)`

            - `(bands, height, width)` if ``tile_size=None``. If the image does
            only have one band, `(height, width)`.

    """
    # check the type of path
    if isinstance(path, str) or isinstance(path, pathlib.Path):

        # check if the path is a url
        if str(path).startswith('http'):
            # gdal virtual file system for url paths
            img = gdal.Open('/vsicurl/{}'.format(str(path)))
        else:
            # image is stored in a file system
            img = gdal.Open(str(path))

        # number of bands
        bands = img.RasterCount

        # spatial size
        height = img.RasterYSize
        width = img.RasterXSize

        # data type: map the gdal data type of the first band to the
        # corresponding numpy data type via the Gdal2Numpy enumeration
        dtype = getattr(Gdal2Numpy,
                        gdal.GetDataTypeName(img.GetRasterBand(1).DataType))
        dtype = dtype.value

    elif path is None:
        LOGGER.warning('Path is of NoneType, returning.')
        return

    # accept numpy arrays as input
    elif isinstance(path, np.ndarray):
        # input array
        img = path

        # check the dimensions of the input array
        if img.ndim > 2:
            bands = img.shape[0]
            height = img.shape[1]
            width = img.shape[2]
        else:
            bands = 1
            height = img.shape[0]
            width = img.shape[1]

            # expand input array to fit band dimension
            img = np.expand_dims(img, axis=0)

        # input array data type
        dtype = img.dtype

    else:
        raise TypeError('Input of type {} not supported'.format(type(path)))

    # check whether to read the image in tiles
    if tile_size is None:

        # number of tiles
        ntiles = 1

        # create empty numpy array to store whole image
        image = np.empty(shape=(ntiles, bands, height, width), dtype=dtype)

        # iterate over the bands of the image
        for b in range(bands):

            # read the data of band b (gdal bands are 1-indexed)
            if isinstance(img, np.ndarray):
                data = img[b, ...]
            else:
                band = img.GetRasterBand(b+1)
                data = band.ReadAsArray()

            # append band b to numpy image array
            image[0, b, :, :] = data

    else:

        # check whether the image is evenly divisible in square tiles
        # of size (tile_size x tile_size)
        ntiles, padding = is_divisible((height, width), tile_size, pad)

        # image size after padding
        # padding order is (bottom, left, top, right)
        y_size = height + padding[0] + padding[2]
        x_size = width + padding[1] + padding[3]

        # print progress
        LOGGER.debug('Image size: {}'.format((height, width)))
        LOGGER.debug('Dividing image into {} tiles of size {} ...'
                     .format(ntiles, (tile_size, tile_size)))
        LOGGER.debug('Padding image (b, l, t, r): {}'.format(tuple(padding)))
        LOGGER.debug('Padded image size: {}'.format((y_size, x_size)))

        # get the indices of the top left corner for each tile
        topleft = tile_topleft_corner((y_size, x_size), tile_size)

        # whether to read all tiles or a single tile
        if tile is not None:
            ntiles = 1

        # create empty numpy array to store the tiles; initializing with
        # cval means padded regions keep the constant padding value
        image = np.ones((ntiles, bands, tile_size, tile_size),
                        dtype=dtype) * cval

        # iterate over the topleft corners of the tiles
        for k, corner in topleft.items():

            # in case a single tile is required, skip the rest of the tiles
            if tile is not None:
                if k != tile:
                    continue

                # set the key to 0 for correct array indexing when reading
                # a single tile from the image
                LOGGER.debug('Processing tile {} ...'.format(k))
                k = 0
            else:
                LOGGER.debug('Creating tile {} with top-left corner {} ...'
                             .format(k, corner))

            # calculate shift between padded and original image:
            # (row, col) is the tile's topleft corner in the ORIGINAL image,
            # (y_tl, x_tl) is the offset within the tile where original data
            # starts (non-zero only for tiles touching the padded top/left)
            row = corner[0] - padding[2] if corner[0] > 0 else corner[0]
            col = corner[1] - padding[1] if corner[1] > 0 else corner[1]
            y_tl = row + padding[2] if row == 0 else 0
            x_tl = col + padding[1] if col == 0 else 0

            # iterate over the bands of the image
            for b in range(bands):

                # check if the current tile extend exists in the image;
                # clips the tile to the original image boundaries
                nrows, ncols = check_tile_extend(
                    (height, width), (row, col), tile_size)

                # read the current tile from band b
                if isinstance(img, np.ndarray):
                    data = img[b, row:row+nrows, col:col+ncols]
                else:
                    band = img.GetRasterBand(b+1)
                    data = band.ReadAsArray(col, row, ncols, nrows)

                # append band b to numpy image array
                image[k, b, y_tl:nrows, x_tl:ncols] = data[0:(nrows - y_tl),
                                                           0:(ncols - x_tl)]

    # check if there are more than 1 band: drop the band axis for
    # single-band images
    if not bands > 1:
        image = image.squeeze(axis=1)

    # check if there are more than 1 tile: drop the tile axis when the image
    # was not tiled or a single tile was requested
    if not ntiles > 1:
        image = image.squeeze(axis=0)

    # close tif file: releasing the reference closes the gdal dataset
    del img

    # return the image
    return image


279
def np2tif(array, filename, no_data=None, names=None, src_ds=None, epsg=None,
           geotransform=None, overwrite=False, compress=False):
    """Save a :py:class`numpy.ndarray` as a GeoTIFF.

    The spatial coordinate reference system can be specified in two ways:
        - by providing a source dataset (``src_ds``) from which the spatial
        reference is inherited
        - by providing the `EPSG`_ code (``epsg``) of the target coordinate
        reference system together with a tuple (``geotransform``) describing
        the spatial extent of the ``array``.

    Parameters
    ----------
    array : :py:class:`numpy.ndarray`
        The array to save as GeoTIFF. Two-dimensional arrays with shape
        (height, width) and three-dimensional arrays with shape
        (bands, height, width) are supported.
    filename : `str` or :py:class:`pathlib.Path`
        The filename of the GeoTIFF.
    no_data : `None` or `int` or `float`
        The NoData value for each band in the output raster. The default is
        `None`, which means no NoData value is specified.
    names : `list` [`str`], optional
        The names of the bands in ``array`` in order. The default is `None`.
        If `None`, no band description is added.
    src_ds : :py:class:`osgeo.gdal.Dataset`, optional
        The source dataset from which the spatial reference is inherited. The
        default is `None`.
    epsg : `int`, optional
        The EPSG code of the target coordinate reference system. The default is
        `None`.
    geotransform : `tuple`, optional
        A tuple with six elements of the form,
        (x_top_left, x_res, x_shift, y_top_left, -y_res, y_shift), describing
        the spatial reference.
    overwrite : `bool`, optional
        Whether to overwrite ``filename`` if it exists. The default is `False`.
    compress : `bool`, optional
        Whether to compress the GeoTIFF. The default is `False`.

    .. _EPSG:
        https://epsg.io/

    Raises
    ------
    ValueError
        Raised if ``filename`` is does not end with a file suffix, e.g. ".tif".

        Raised if not both ``epsg`` and ``geotransform`` are specified when
        ``src_ds=None``.

    """
    # create the GeoTIFF driver
    driver = gdal.GetDriverByName('GTiff')

    # shape of the input array
    if array.ndim > 2:
        # three-dimensional array
        bands, height, width = array.shape
    else:
        # two-dimensional array: expand to three-dimensions
        bands, height, width = (1,) + array.shape
        array = np.expand_dims(array, 0)

    # data type: derive the gdal data type from the numpy dtype name
    dtype = gdal.GetDataTypeByName(array.dtype.name)

    # check output filename: require a suffix, then force it to ".tif"
    filename = pathlib.Path(filename)
    if not filename.suffix:
        raise ValueError('{} is not a file.'.format(filename))
    filename = pathlib.Path(str(filename).replace(filename.suffix, '.tif'))

    # check if file exists
    if filename.exists() and not overwrite:
        LOGGER.info('{} already exists.'.format(filename))
        return

    # create temporary file
    # NOTE(review): _tmp_path is defined elsewhere in this module — the
    # dataset is first written to a temporary location and moved/compressed
    # into place by compress_raster below
    tmp_path = _tmp_path(filename)

    # create output GeoTIFF
    tmp_ds = driver.Create(str(tmp_path), width, height, bands, dtype)

    # iterate over the number of bands and write to output file
    # (gdal bands are 1-indexed)
    for b in range(bands):
        trg_band = tmp_ds.GetRasterBand(b + 1)
        trg_band.WriteArray(array[b, ...])

        # set the band description, if specified
        if names is not None:
            trg_band.SetDescription(names[b])

        # set the NoData value, if specified
        if no_data is not None:
            trg_band.SetNoDataValue(no_data)
        trg_band.FlushCache()

    # set spatial reference
    if src_ds is not None:
        # inherit spatial reference from source dataset
        tmp_ds.SetProjection(src_ds.GetProjection())
        tmp_ds.SetGeoTransform(src_ds.GetGeoTransform())
    else:
        # check whether both the epsg code and the geotransform tuple are
        # specified
        if epsg is not None and geotransform is not None:
            # create the spatial reference from the epsg code
            sr = osr.SpatialReference()
            sr.ImportFromEPSG(epsg)

            # set spatial reference from epsg
            tmp_ds.SetProjection(sr.ExportToWkt())
            tmp_ds.SetGeoTransform(geotransform)
        else:
            raise ValueError('Both "epsg" and "geotransform" required to set '
                             'spatial reference if "src_ds" is None.')

    # clear dataset: dropping the references flushes and closes the gdal
    # dataset so the file on disk is complete
    del trg_band, tmp_ds

    # compress raster: moves/compresses the temporary file to ``filename``
    compress_raster(tmp_path, filename, compress=compress)
402
403


404
405
def read_hdf4(path):
    """Read a file in Hierarchical Data Format 4 (HDF4).

    Parameters
    ----------
    path : `str` or py:class:`pathlib.Path`
        The path to the hdf4 file to read.

    Raises
    ------
    ValueError
        Raised if ``path`` is not an hdf file.

    Returns
    -------
    hdf_ds : :py:class:`xarray.Dataset`
        The HDF4 file as :py:class:`xarray.Dataset`.

    """
    # reject paths that do not carry a known hdf file suffix
    path = pathlib.Path(path)
    if path.suffix not in HIERARCHICAL_DATA_FORMAT:
        raise ValueError('{} is not an hdf file.'.format(path))

    # open each gdal subdataset as a single-variable xarray dataset, named
    # after the last component of its subdataset descriptor
    layers = [
        xr.open_rasterio(descriptor).to_dataset(name=descriptor.split(':')[-1])
        for descriptor, _ in gdal.Open(str(path)).GetSubDatasets()
    ]

    # combine the per-subdataset variables into one dataset and drop
    # size-1 dimensions
    return xr.merge(layers).squeeze()
441
442


443
def hdf2tifs(path, outpath, overwrite=False, create_stack=True, **kwargs):
    """Convert a file in Hierarchical Data Format (HDF) to GeoTIFFs.

    The GeoTIFFs share the same filename as ``path``, appended by the name of
    the respective subdatasets.

    The GeoTIFFs are saved to a subdirectory of ``outpath`` named after the
    filename of ``path`` (with dots replaced by underscores).

    The output GeoTIFFs are compressed by default.

    Parameters
    ----------
    path : `str` or py:class:`pathlib.Path`
        The path to the hdf file to convert.
    outpath : `str` or py:class:`pathlib.Path`
        Path to save the GeoTIFF files.
    overwrite : `bool`, optional
        Whether to overwrite existing GeoTIFF files in ``outpath``. The default
        is `False`.
    create_stack : `bool`, optional
        Whether to create a GeoTIFF stack of all the subdatasets in ``path``.
        The default is `True`.
    **kwargs :
        Additional keyword arguments passed to :py:func:`gdal.Translate`.

    Raises
    ------
    ValueError
        Raised if ``path`` is not an hdf file.

    """
    # check if the path points to an hdf file
    path = pathlib.Path(path)
    if path.suffix not in HIERARCHICAL_DATA_FORMAT:
        raise ValueError('{} is not an hdf file.'.format(path))

    # create the output directory for the GeoTiffs: a subdirectory named
    # after the hdf file, with dots replaced to keep a clean directory name
    outpath = pathlib.Path(outpath)
    outpath = outpath.joinpath(path.stem.replace('.', '_'))

    # check whether the output path exists
    if not outpath.exists():
        LOGGER.info('mkdir {}'.format(outpath))
        outpath.mkdir(parents=True, exist_ok=True)

    # check whether the output path contains GeoTIFF files
    tifs = [f for f in outpath.iterdir() if f.suffix in ['.tif', '.TIF']]

    # check whether to overwrite existing files
    if tifs:
        LOGGER.info('The following files already exist in {}'
                    .format(str(outpath)))
        # the joiner indents continuation lines to align with the logger's
        # module-name prefix
        LOGGER.info(('\n ' + (len(__name__) + 1) * ' ').join(
            ['{}'.format(str(tif.name)) for tif in tifs]))

        if not overwrite:
            # return if not overwriting
            LOGGER.info('Aborting...')
            return

        # remove existing files and prepare to overwrite
        LOGGER.info('Overwrite {}'.format(str(outpath)))
        for tif in tifs:
            tif.unlink()

    # header file: e.g. "scene.hdf" -> "scene.hdf.hdr"
    hdr = pathlib.Path('.'.join([str(path), 'hdr']))

    # check if header file exists and contains projection
    wkt = None
    if hdr.exists():
        LOGGER.info('Found header file: {}'.format(hdr))
        with open(hdr, 'r') as file:
            # search for WKT-projection string
            content = file.read()
            wkt = re.search('PROJCS[^}]*', content)
        if wkt is not None:
            # keep the matched WKT string only
            wkt = wkt[0]

    # read the hdf dataset
    hdf = gdal.Open(str(path)).GetSubDatasets()

    # check if the dataset is not empty
    if hdf:

        # iterate over the different subdatasets in the hdf
        for ds in hdf:

            # name of the current subdataset
            hdf_ds = gdal.Open(ds[0])
            name = ds[0].split(':')[-1].lower()

            # filename of the GeoTIFF
            tif_name = outpath.joinpath(
                path.name.replace(path.suffix, '_{}.tif'.format(name)))

            # convert hdf subdataset to GeoTIFF
            LOGGER.info('Converting: {}'.format(tif_name.name))
            gdal.Translate(str(tif_name), hdf_ds, outputSRS=wkt,
                           creationOptions=[
                               'COMPRESS=DEFLATE', 'PREDICTOR=1', 'TILED=YES'],
                           **kwargs)

            # set metadata field: copy the hdf subdataset metadata to the
            # generated GeoTIFF
            tif_ds = gdal.Open(str(tif_name))
            tif_ds.SetMetadata(hdf_ds.GetMetadata())

            del tif_ds

        # check whether to create a GeoTIFF stack
        if create_stack:
            # filename for the GeoTIFF stack
            stk = tif_name.parent.joinpath(
                path.name.replace(path.suffix, '.tif'))
            LOGGER.info('Creating stack: {}'.format(stk))

            # generated GeoTIFF files
            tifs = sorted([str(f) for f in outpath.iterdir() if f.suffix in
                           ['.tif', '.TIF']])

            # create stacked GeoTIFF
            stack_tifs(str(stk), tifs)

    return


def stack_tifs(filename, tifs, **kwargs):
    """Create a stacked GeoTIFF from a list of single-band GeoTIFFs.

    The output GeoTIFF stack is compressed by default.

    Parameters
    ----------
    filename : `str` or py:class:`pathlib.Path`
        The filename of the stacked GeoTIFF ending with `.tif`.
    tifs : `list` [`str`]
        The list of the paths to the GeoTIFF files to stack.
    **kwargs :
        Additional keyword arguments passed to :py:func:`gdal.Translate`.

    """
    # assemble the single-band files into a virtual raster dataset
    vrt_name = str(filename).replace('.tif', '.vrt')
    vrt_ds = gdal.BuildVRT(str(vrt_name), tifs, separate=True)

    # carry each band description over from its source file to the stack
    for band_index, tif in enumerate(tifs, start=1):
        src_ds = gdal.Open(str(tif))
        vrt_ds.GetRasterBand(band_index).SetDescription(
            src_ds.GetRasterBand(1).GetDescription())

    # suppress gdal warnings while materializing the stack
    gdal.PushErrorHandler('CPLQuietErrorHandler')

    # create the compressed GeoTIFF stack from the virtual raster
    gdal.Translate(str(filename), vrt_ds, creationOptions=[
        'COMPRESS=DEFLATE', 'PREDICTOR=1', 'TILED=YES'], **kwargs)

    # release the virtual raster handle
    del vrt_ds

    return

602

603
def is_divisible(img_size, tile_size, pad=False):
    """Check whether an image is evenly divisible into square tiles.

    Parameters
    ----------
    img_size : `tuple`
        The image size (height, width).
    tile_size : `int`
        The size of the tile.
    pad : `bool`, optional
        Whether to center pad the input image. The default is `False`.

    Raises
    ------
    ValueError
        Raised if the image is not evenly divisible and ``pad=False``.

    Returns
    -------
    ntiles : `int`
        The number of tiles fitting ``img_size``.
    padding : `tuple`
        The amount of padding (bottom, left, top, right).

    """
    height, width = img_size

    # check divisibility along EACH axis rather than comparing the total
    # image area to the tile area: an image such as (2, 8) with tile_size=4
    # has area 16 == tile_size ** 2, but no (4, 4) tile fits its height
    if height % tile_size == 0 and width % tile_size == 0:
        # evenly divisible: no padding required
        ntiles = (height * width) // tile_size ** 2
        return int(ntiles), (0, 0, 0, 0)

    # not evenly divisible and padding not allowed
    if not pad:
        raise ValueError('Image of size {} not evenly divisible in ({}, {}) '
                         'tiles.'.format(img_size, tile_size, tile_size))

    # calculate the desired image size, i.e. the smallest size that is
    # evenly divisible into square tiles of size (tile_size, tile_size)
    h_new = int(np.ceil(height / tile_size) * tile_size)
    w_new = int(np.ceil(width / tile_size) * tile_size)

    # calculate center offset
    dh = h_new - height
    dw = w_new - width

    # split each offset symmetrically; an odd offset places the extra pixel
    # on the top (height) and right (width) edges, respectively
    padding = (dh // 2, dw // 2, dh // 2 + dh % 2, dw // 2 + dw % 2)

    # calculate number of tiles on padded image
    ntiles = (h_new * w_new) // tile_size ** 2

    return int(ntiles), padding
677
678
679


def check_tile_extend(img_size, topleft, tile_size):
    """Check if a tile exceeds the image size.

    Parameters
    ----------
    img_size : `tuple`
        The image size (height, width).
    topleft : `tuple`
        The topleft corner of the tile (y, x).
    tile_size : `int`
        The size of the tile.

    Returns
    -------
    nrows : `int`
        Number of rows of the tile within the image.
    ncols : `int`
        Number of columns of the tile within the image.

    """
    # a tile extending beyond an image edge is clipped to the rows/columns
    # remaining between its topleft corner and that edge; a tile fully inside
    # the image keeps the full tile size along both axes
    nrows = min(tile_size, img_size[0] - topleft[0])
    ncols = min(tile_size, img_size[1] - topleft[1])

    return nrows, ncols
727

Frisinghelli Daniel's avatar
Frisinghelli Daniel committed
728

729
def tile_topleft_corner(img_size, tile_size):
    """Return the topleft corners of the tiles in the image.

    Parameters
    ----------
    img_size : `tuple`
        The image size (height, width).
    tile_size : `int`
        The size of the tile.

    Returns
    -------
    indices : `dict`
        The keys of ``indices`` are the tile ids (`int`) and the values are the
        topleft corners (y, x) of the tiles.

    """
    # verify the image is divisible into square tiles of size
    # (tile_size, tile_size); raises a ValueError otherwise
    _, _ = is_divisible(img_size, tile_size, pad=False)

    # number of tiles along the height (rows) and width (columns)
    tiles_down = int(img_size[0] / tile_size)
    tiles_across = int(img_size[1] / tile_size)

    # tile ids are assigned row-major: id k sits in row k // tiles_across and
    # column k % tiles_across of the tile grid
    return {k: ((k // tiles_across) * tile_size,
                (k % tiles_across) * tile_size)
            for k in range(tiles_down * tiles_across)}

766

767
def reconstruct_scene(tiles):
    """Reconstruct a tiled image.

    Parameters
    ----------
    tiles : :py:class:`numpy.ndarray`
        The tiled image, shape: `(tiles, bands, tile_size, tile_size)` or
        `(tiles, tile_size, tile_size)`.

    Returns
    -------
    image : :py:class:`numpy.ndarray`
        The reconstructed image, shape: `(bands, height, width)`.

    """
    # infer the number of bands and the tile size from the input shape
    if tiles.ndim > 3:
        nbands, tile_size = tiles.shape[1], tiles.shape[2]
    else:
        nbands, tile_size = 1, tiles.shape[1]

    # the reconstructed image is square: height = width =
    # sqrt(ntiles) * tile_size
    img_size = 2 * (int(np.sqrt(tiles.shape[0]) * tile_size),)

    # topleft corner (y, x) of each tile id in the reconstructed image
    topleft = tile_topleft_corner(img_size, tile_size)

    # paste every tile at its topleft corner
    scene = np.zeros(shape=(nbands,) + img_size, dtype=tiles.dtype)
    for tid, (y, x) in topleft.items():
        scene[..., y:y + tile_size, x:x + tile_size] = tiles[tid, ...]

    # drop size-1 dimensions, e.g. the band axis for single-band images
    return scene.squeeze()


806
def accuracy_function(outputs, labels):
    """Calculate prediction accuracy.

    Parameters
    ----------
    outputs : :py:class:`torch.Tensor` or `array_like`
        The model prediction.
    labels : :py:class:`torch.Tensor` or `array_like`
        The ground truth.

    Returns
    -------
    accuracy : `float`
        Mean prediction accuracy.

    """
    # torch tensors: compare on the tensor backend
    if isinstance(outputs, torch.Tensor):
        matches = (outputs == labels).float()
        return matches.mean().item()

    # any other array_like: fall back to numpy
    correct = np.asarray(outputs) == np.asarray(labels)
    return correct.mean().item()


828
def parse_landsat_scene(scene_id):
    """Parse a Landsat scene identifier.

    Parameters
    ----------
    scene_id : `str`
        A Landsat scene identifier.

    Returns
    -------
    scene : `dict` or `None`
        A dictionary containing scene metadata. If `None`, ``scene_id`` is not
        a valid Landsat scene identifier.

    """
    # building blocks of the Landsat Collection 1 naming convention
    sensor = 'L[COTEM]0[1-8]_'
    level = 'L[0-9][A-Z][A-Z]_'
    swath = '[0-2][0-9]{2}[0-1][0-9]{2}_'
    date = '[0-9]{4}[0-1][0-9][0-3][0-9]_'
    doy = '[0-9]{4}[0-3][0-9]{2}'
    collection = '0[0-9]_'
    category = '[A-Z]([A-Z]|[0-9])'

    # compiled pattern for the Collection 1 naming convention
    c1_regex = re.compile(
        sensor + level + swath + date + date + collection + category)

    # compiled pattern for the pre-collection naming convention: no
    # underscores, no leading zero in the sensor id, day-of-year dates
    c0_regex = re.compile(
        sensor.replace('_', '').replace('0', '') + swath.replace('_', '') +
        doy + '[A-Z]{3}' + '[0-9]{2}')

    # sensor abbreviations for Landsat 1-7
    lsensors = {'E': 'Enhanced Thematic Mapper Plus',
                'T': 'Thematic Mapper',
                'M': 'Multispectral Scanner'}

    # sensor abbreviations for Landsat 8
    l8sensors = {'C': 'Operational Land Imager (OLI) & Thermal Infrared Sensor'
                      ' (TIRS)',
                 'O': 'Operational Land Imager (OLI)',
                 'T': 'Thermal Infrared Sensor (TIRS)',
                 }

    # match the identifier against both conventions: pre-collection first,
    # then Collection 1, as in the original lookup order
    c0_match = c0_regex.search(scene_id)
    c1_match = c1_regex.search(scene_id)

    if c0_match is not None:
        # pre-collection identifier: fixed-width fields, day-of-year date
        match = c0_match[0]
        sensors = l8sensors if int(match[2]) > 7 else lsensors
        scene = {
            'id': match,
            'satellite': 'Landsat {}'.format(match[2]),
            'sensor': sensors[match[1]],
            'path': match[3:6],
            'row': match[6:9],
            'date': doy2date(match[9:13], match[13:16]),
            'gsi': match[16:19],
            'version': match[19:],
        }

    elif c1_match is not None:
        # Collection 1 identifier: underscore-separated fields
        match = c1_match[0]
        parts = match.split('_')
        sensors = l8sensors if int(parts[0][3]) > 7 else lsensors
        scene = {
            'id': match,
            'satellite': 'Landsat {}'.format(parts[0][2:]),
            'sensor': sensors[parts[0][1]],
            'path': parts[2][0:3],
            'row': parts[2][3:],
            'date': datetime.datetime.strptime(parts[3], '%Y%m%d'),
            'collection': int(parts[5]),
            'version': parts[6],
        }

    else:
        # not a valid Landsat scene identifier
        scene = None

    return scene
916

917
918

def parse_sentinel2_scene(scene_id):
Frisinghelli Daniel's avatar
Frisinghelli Daniel committed
919
920
921
922
923
924
    """Parse a Sentinel-2 scene identifier.

    Parameters
    ----------
    scene_id : `str`
        A Sentinel-2 scene identifier.
925

Frisinghelli Daniel's avatar
Frisinghelli Daniel committed
926
927
928
929
930
931
932
    Returns
    -------
    scene : `dict` or `None`
        A dictionary containing scene metadata. If `None`, ``scene_id`` is not
        a valid Sentinel-2 scene identifier.

    """
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
    # Sentinel 2 Level-1C products naming convention after 6th December 2016
    mission = 'S2[A-B]_'
    level = 'MSIL1C_'
    date = '[0-9]{4}[0-1][0-9][0-3][0-9]'
    time = 'T[0-2][0-9][0-5][0-9][0-5][0-9]_'
    processing = 'N[0-9]{4}_'
    orbit = 'R[0-1][0-9]{2}_'
    tile = 'T[0-9]{2}[A-Z]{3}_'
    level_1C = (mission + level + date + time + processing +
                orbit + tile + date + time.replace('_', ''))
    S2_L1C_New = re.compile(level_1C)

    # Sentinel 2 Level-1C products naming convention before 6th December 2016
    file_class = '[A-Z]{4}_'
    file_category = '[A-Z]{3}_'
948
    file_semantic = '(MSI(L[0-1]([ABCP]|_)|CTI)|USER2A)_'
949
950
    site = '[A-Z_]{4}_'
    S2_L1C_Old = re.compile(mission + file_class + file_category +
951
952
953
954
955
956
957
958
959
960
961
                            file_semantic + site + date + time + orbit + 'V' +
                            date + time + date + time.replace('_', ''))

    # Sentinel 2 granule naming convention before 6th December 2016
    granule_semantic = 'L[0-1][ABC_]_(GR|DS|TL|TC|CO)_'
    det_or_tile = '(D[0-1][1-2]|T[0-9]{2}[A-Z]{3})(_)?'
    aorbit = '(A[0-9]{6}_' + '{}'.format('|S' + date + time + ')')
    baseline = '(N[0-9]{2})?'
    S2_L1C_Granule = re.compile(mission + file_class + file_category +
                                granule_semantic + site + 'V' + date + time +
                                aorbit + det_or_tile + baseline)
962

963
964
965
966
967
968
    # Sentitel 2 granule naming convention
    S2_L1C_Granule_Only = re.compile('L[0-1][ABC]_' + tile + 'A[0-9]{6}_' +
                                     date + time.replace('_', ''))

    # Sentinel 2 tile naming convetion
    S2_L1C_Tile = re.compile(tile + date + time.replace('_', ''))
969

970
971
972
    # whether a scene identifier matches the Sentinel naming convention
    scene = {}
    if S2_L1C_Old.search(scene_id):
973

974
975
976
        # the match of the regular expression
        match = S2_L1C_Old.search(scene_id)[0]

977
        # split scene into respective part
978
979
980
981
        parts = match.split('_')

        # the metadata of the scene identifier
        scene['id'] = match
982
        scene['satellite'] = parts[0]
983
984
        scene['file class'] = parts[1]
        scene['file category'] = parts[2]
985
986
987
        scene['file semantic'] = parts[3]
        scene['site'] = parts[4]
        scene['orbit'] = parts[6]
988
989
        scene['date'] = datetime.datetime.strptime(
           parts[7].split('T')[0].replace('V', ''), '%Y%m%d')
990
        scene['tile'] = None
991
992
993
994
995
996
997
998
999
1000

    elif S2_L1C_New.search(scene_id):

        # the match of the regular expression
        match = S2_L1C_New.search(scene_id)[0]

        # split scene into respective parts
        parts = match.split('_')

        # the metadata of the scene identifier
For faster browsing, not all history is shown. View entire blame