From 12364f5472ee7ba78198d614c4c2627dfded6bd7 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Tue, 25 Aug 2020 15:43:51 +0200
Subject: [PATCH] Added Landsat radiometric calibration function

---
 pysegcnn/core/utils.py | 379 +++++++++++++++++++++++++++++++++++++----
 1 file changed, 346 insertions(+), 33 deletions(-)

diff --git a/pysegcnn/core/utils.py b/pysegcnn/core/utils.py
index c9e6cce..84d2b3c 100644
--- a/pysegcnn/core/utils.py
+++ b/pysegcnn/core/utils.py
@@ -23,6 +23,7 @@ import pathlib
 import tarfile
 import zipfile
 import datetime
+import warnings
 
 # externals
 import gdal
@@ -32,6 +33,9 @@ import numpy as np
 # module level logger
 LOGGER = logging.getLogger(__name__)
 
+# suffixes for radiometrically calibrated scenes
+SUFFIXES = ['toa_ref', 'toa_rad', 'toa_brt']
+
 
 def img2np(path, tile_size=None, tile=None, pad=False, cval=0):
     """Read an image to a `numpy.ndarray`.
@@ -734,6 +738,11 @@ def destack_tiff(image, outpath=None, overwrite=False, remove=False,
         filenames for each band in ``image`` are, "filename(``image``) +
         + _B(i)_ + ``suffix``.tif". The default is ''.
 
+    Raises
+    ------
+    FileNotFoundError
+        Raised if ``image`` does not exist.
+
     Returns
     -------
     None.
@@ -762,59 +771,60 @@ def destack_tiff(image, outpath=None, overwrite=False, remove=False,
         if not outpath.exists():
             outpath.mkdir(parents=True, exist_ok=True)
 
-    # open the raster
+    # open the TIFF
     img = gdal.Open(str(image))
 
     # check whether the file was already processed
     processed = list(outpath.glob(imgname + '_B*'))
-    if len(processed) == img.RasterCount and not overwrite:
+    if len(processed) >= img.RasterCount and not overwrite:
         LOGGER.info('{} already processed.'.format(imgname))
-        del img
-        return
 
-    # image driver
-    driver = gdal.GetDriverByName('GTiff')
-    driver.Register()
+    # destack the TIFF
+    else:
+        # image driver
+        driver = gdal.GetDriverByName('GTiff')
+        driver.Register()
 
-    # image size and tiles
-    cols = img.RasterXSize
-    rows = img.RasterYSize
-    bands = img.RasterCount
+        # image size and tiles
+        cols = img.RasterXSize
+        rows = img.RasterYSize
+        bands = img.RasterCount
 
-    # print progress
-    LOGGER.info('Processing: {}'.format(imgname))
+        # print progress
+        LOGGER.info('Processing: {}'.format(imgname))
 
-    # iterate the bands of the raster
-    for b in range(1, bands + 1):
+        # iterate the bands of the raster
+        for b in range(1, bands + 1):
 
-        # read the data of band b
-        band = img.GetRasterBand(b)
-        data = band.ReadAsArray()
+            # read the data of band b
+            band = img.GetRasterBand(b)
+            data = band.ReadAsArray()
 
-        # output file: replace for band name
-        fname = str(outpath.joinpath(imgname + '_B{:d}.tif'.format(b)))
-        if suffix:
-            fname = fname.replace('.tif', '_{}.tif'.format(suffix))
-        outDs = driver.Create(fname, cols, rows, 1, band.DataType)
+            # output file: replace for band name
+            fname = str(outpath.joinpath(imgname + '_B{:d}.tif'.format(b)))
+            if suffix:
+                fname = fname.replace('.tif', '_{}.tif'.format(suffix))
+            outDs = driver.Create(fname, cols, rows, 1, band.DataType)
 
-        # define output band
-        outband = outDs.GetRasterBand(1)
+            # define output band
+            outband = outDs.GetRasterBand(1)
 
-        # write array to output band
-        outband.WriteArray(data)
-        outband.FlushCache()
+            # write array to output band
+            outband.WriteArray(data)
+            outband.FlushCache()
 
-        # Set the geographic information
-        outDs.SetProjection(img.GetProjection())
-        outDs.SetGeoTransform(img.GetGeoTransform())
+            # Set the geographic information
+            outDs.SetProjection(img.GetProjection())
+            outDs.SetGeoTransform(img.GetGeoTransform())
 
-        # clear memory
-        del outband, band, data, outDs
+            # clear memory
+            del outband, band, data, outDs
 
     # remove old stacked GeoTIFF
     del img
     if remove:
         os.unlink(image)
+    return
 
 
 def standard_eo_structure(source_path, target_path, overwrite=False,
@@ -967,3 +977,306 @@ def extract_archive(inpath, outpath, overwrite=False):
     tar.extractall(target)
 
     return target
+
+
+def read_landsat_metadata(file):
+    """Parse the Landsat metadata *_MTL.txt file.
+
+    Parameters
+    ----------
+    file : `str` or `pathlib.Path`
+        Path to a Landsat *_MTL.txt file.
+
+    Raises
+    ------
+    FileNotFoundError
+        Raised if ``file`` does not exist.
+
+    Returns
+    -------
+    metadata : `dict`
+        The metadata text file as dictionary, where each line is a (key, value)
+        pair.
+    """
+    file = pathlib.Path(file)
+    # check if the metadata file exists
+    if not file.exists():
+        raise FileNotFoundError('{} does not exist'.format(file))
+
+    # read metadata file
+    metadata = {}
+    LOGGER.info('Parsing metadata file: {}'.format(file.name))
+    with open(file, 'r') as metafile:
+        # iterate over the lines of the metadata file
+        for line in metafile:
+            try:
+                # the current line as (key, value pair)
+                (key, value) = line.split('=')
+
+                # store current line in dictionary: skip lines defining the
+                # parameter groups
+                if 'GROUP' not in key:
+                    metadata[key.strip()] = value.strip()
+
+            # catch value error of line.split('='), i.e. if there is no '='
+            # sign in the current line
+            except ValueError:
+                continue
+
+    return metadata
+
+
+def get_radiometric_constants(metadata):
+    """Retrieve the radiometric calibration constants.
+
+    Parameters
+    ----------
+    metadata : `dict`
+        The dictionary returned by ``read_landsat_metadata``.
+
+    Returns
+    -------
+    oli : `dict`
+        Radiometric rescaling factors of the OLI sensor.
+    tir : `dict`
+        Thermal conversion constants of the TIRS sensor.
+    """
+    # regular expression patterns matching the radiometric rescaling factors
+    oli_pattern = re.compile('(RADIANCE|REFLECTANCE)_(MULT|ADD)_BAND_\\d{1,2}')
+    tir_pattern = re.compile('K(1|2)_CONSTANT_BAND_\\d{2}')
+
+    # rescaling factors to calculate top of atmosphere radiance and reflectance
+    oli = {key: float(value) for key, value in metadata.items() if
+           oli_pattern.search(key) is not None}
+
+    # rescaling factors to calculate at-satellite temperatures in Kelvin
+    tir = {key: float(value) for key, value in metadata.items() if
+           tir_pattern.search(key) is not None}
+
+    return oli, tir
+
+
+def landsat_radiometric_calibration(scene, outpath=None, exclude=[],
+                                    radiance=False, overwrite=False,
+                                    remove_raw=True):
+    """Radiometric calibration of Landsat Collection Level 1 scenes.
+
+    Convert the Landsat OLI bands to top of atmosphere radiance or reflectance
+    and the TIRS bands to top of atmosphere brightness temperature.
+
+    Conversion is performed following the `equations`_ provided by the USGS.
+
+    The filename of each band is extended by one of the following suffixes,
+    depending on the type of the radiometric calibration:
+
+        'toa_ref': top of atmosphere reflectance
+        'toa_rad': top of atmopshere radiance
+        'toa_brt': top of atmosphere brightness temperature
+
+    Parameters
+    ----------
+    scene : `str` or `pathlib.Path`
+        Path to a Landsat scene in digital number format.
+    outpath : `str` or `pathlib.Path`, optional
+        Path to save the calibrated images. The default is None, which means
+        saving to ``scene``.
+    exclude : `list` [`str`], optional
+        Bands to exclude from the radiometric calibration. The default is [].
+    radiance : `bool`, optional
+        Whether to calculate top of atmosphere radiance. The default is False,
+        which means calculating top of atmopshere reflectance.
+    overwrite : `bool`, optional
+        Whether to overwrite the calibrated images. The default is False.
+    remove_raw : `bool`, optional
+        Whether to remove the raw digitial number images. The default is True.
+
+    Raises
+    ------
+    FileNotFoundError
+        Raised if ``scene`` does not exist.
+
+    Returns
+    -------
+    None.
+
+    .. _equations:
+        https://www.usgs.gov/land-resources/nli/landsat/using-usgs-landsat-level-1-data-product
+
+    """
+    scene = pathlib.Path(scene)
+    # check if the input scene exists
+    if not scene.exists():
+        raise FileNotFoundError('{} does not exist.'.format(scene))
+
+    # default: output directory is equal to the input directory
+    if outpath is None:
+        outpath = scene
+    else:
+        outpath = pathlib.Path(outpath)
+        # check if output directory exists
+        if not outpath.exists():
+            outpath.mkdir(parents=True, exist_ok=True)
+
+    # the scene metadata file
+    try:
+        mpattern = re.compile('[mMtTlL].txt')
+        metafile = [f for f in scene.iterdir() if mpattern.search(str(f))]
+        metadata = read_landsat_metadata(metafile.pop())
+    except IndexError:
+        LOGGER.error('Can not calibrate scene {}: {} does not exist.'
+                     .format(scene.name, scene.name + '_MTL.txt'))
+        return
+
+    # radiometric calibration constants
+    oli, tir = get_radiometric_constants(metadata)
+
+    # log current Landsat scene ID
+    LOGGER.info('Landsat scene id: {}'.format(metadata['LANDSAT_SCENE_ID']))
+
+    # images to process
+    ipattern = re.compile('B\\d{1,2}(.*)\\.[tT][iI][fF]')
+    images = [file for file in scene.iterdir() if ipattern.search(str(file))]
+
+    # pattern to match calibrated images
+    cal_pattern = re.compile('({}|{}|{}).[tT][iI][fF]'.format(*SUFFIXES))
+
+    # check if any images were already processe
+    processed = [file for file in images if cal_pattern.search(str(file))]
+    if any(processed):
+        LOGGER.info('The following images have already been processed:')
+
+        # images that were already processed
+        LOGGER.info(('\n ' + (len(__name__) + 1) * ' ').join(
+            [str(file.name) for file in processed]))
+
+        # overwrite: remove processed images and redo calibration
+        if overwrite:
+            LOGGER.info('Preparing to overwrite ...')
+
+            # remove processed images
+            for toremove in processed:
+                # remove from disk
+                os.unlink(toremove)
+                LOGGER.info('rm {}'.format(toremove))
+
+                # remove from list to process
+                images.remove(toremove)
+
+        # not overwriting, terminate calibration
+        else:
+            return
+
+    # exclude defined bands from the calibration procedure
+    for i in images:
+        current_band = re.search('B\\d{1,2}', str(i))[0]
+        if current_band in exclude:
+            images.remove(i)
+
+    # image driver
+    driver = gdal.GetDriverByName('GTiff')
+    driver.Register()
+
+    # iterate over the different bands
+    for image in images:
+        LOGGER.info('Processing: {}'.format(image.name))
+
+        # read the image
+        img = gdal.Open(str(image))
+        band = img.GetRasterBand(1)
+
+        # read data as array
+        data = band.ReadAsArray()
+
+        # mask of erroneous values, i.e. mask of values < 0
+        mask = data < 0
+
+        # output filename
+        fname = outpath.joinpath(image.stem)
+
+        # get the current band
+        band = re.search('B\\d{1,2}', str(image))[0].replace('B', 'BAND_')
+
+        # check if the current band is a thermal band
+        if band in ['BAND_10', 'BAND_11']:
+
+            # output file name for TIRS bands
+            fname = pathlib.Path(str(fname) + '_toa_brt.tif')
+
+            # calculate top of atmosphere brightness temperature
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", category=RuntimeWarning)
+
+                # top of atmosphere radiance
+                rad = (oli['RADIANCE_MULT_{}'.format(band)] * data +
+                       oli['RADIANCE_ADD_{}'.format(band)])
+
+                # top of atmosphere brightness temperature
+                den = np.log(tir['K1_CONSTANT_{}'.format(band)] / rad + 1)
+                toa = tir['K2_CONSTANT_{}'.format(band)] / den
+
+                # clear memory
+                del den, data, rad
+        else:
+
+            # whether to calculate top of atmosphere radiance or reflectance
+            if radiance:
+
+                # output file name for OLI bands: radiance
+                fname = pathlib.Path(str(fname) + '_toa_rad.tif')
+
+                # calculate top of atmosphere radiance
+                toa = (oli['RADIANCE_MULT_{}'.format(band)] * data +
+                       oli['RADIANCE_ADD_{}'.format(band)])
+
+                # clear memory
+                del data
+
+            else:
+
+                # output file name for OLI bands: reflectance
+                fname = pathlib.Path(str(fname) + '_toa_ref.tif')
+
+                # solar zenith angle in radians
+                zenith = np.radians(90 - float(metadata['SUN_ELEVATION']))
+
+                # calculate top of the atmosphere reflectance
+                ref = (oli['REFLECTANCE_MULT_{}'.format(band)] * data +
+                       oli['REFLECTANCE_ADD_{}'.format(band)])
+                toa = ref / np.cos(zenith)
+
+                # clear memory
+                del ref, data
+
+        # mask erroneous values
+        toa[mask] = np.nan
+
+        # output file
+        outDs = driver.Create(str(fname), img.RasterXSize, img.RasterYSize,
+                              img.RasterCount, gdal.GDT_Float32)
+        outband = outDs.GetRasterBand(1)
+
+        # write array
+        outband.WriteArray(toa)
+        outband.FlushCache()
+
+        # Set the geographic information
+        outDs.SetProjection(img.GetProjection())
+        outDs.SetGeoTransform(img.GetGeoTransform())
+
+        # clear memory
+        del outband, band, img, outDs, toa, mask
+
+    # check if raw images should be removed
+    if remove_raw:
+
+        # raw digital number images
+        _dn = [i for i in scene.iterdir() if ipattern.search(str(i)) and
+               not cal_pattern.search(str(i))]
+
+        # remove raw digital number images
+        for toremove in _dn:
+            # remove from disk
+            os.unlink(toremove)
+            LOGGER.info('rm {}'.format(toremove))
+
+    return
-- 
GitLab