From 82df58504ecf3d87c7a7310060eb7d93666ab935 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 23 Dec 2020 16:39:26 +0100
Subject: [PATCH] Added the option to set a description for each raster band.

---
 pysegcnn/core/utils.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/pysegcnn/core/utils.py b/pysegcnn/core/utils.py
index 96bce4f..b49161d 100644
--- a/pysegcnn/core/utils.py
+++ b/pysegcnn/core/utils.py
@@ -252,7 +252,8 @@ def img2np(path, tile_size=None, tile=None, pad=False, cval=0):
     return image
 
 
-def np2tif(array, filename, src_ds=None, epsg=None, geotransform=None):
+def np2tif(array, filename, names=None, src_ds=None, epsg=None,
+           geotransform=None):
     """Save a :py:class`numpy.ndarray` as a GeoTIFF.
 
     The spatial coordinate reference system can be specified in two ways:
@@ -270,13 +271,16 @@ def np2tif(array, filename, src_ds=None, epsg=None, geotransform=None):
         (bands, height, width) are supported.
     filename : `str` or :py:class:`pathlib.Path`
         The filename of the GeoTIFF.
-    src_ds : :py:class:`osgeo.gdal.Dataset`
+    names : `list` [`str`], optional
+        The names of the bands in ``array`` in order. The default is `None`.
+        If `None`, no band description is added.
+    src_ds : :py:class:`osgeo.gdal.Dataset`, optional
         The source dataset from which the spatial reference is inherited. The
         default is `None`.
     epsg : `int`, optional
         The EPSG code of the target coordinate reference system. The default is
         `None`.
-    geotransform : `tuple`
+    geotransform : `tuple`, optional
         A tuple with six elements of the form,
         (x_top_left, x_res, x_shift, y_top_left, -y_res, y_shift), describing
         the spatial reference.
@@ -321,6 +325,10 @@ def np2tif(array, filename, src_ds=None, epsg=None, geotransform=None):
     for b in range(bands):
         trg_band = trg_ds.GetRasterBand(b + 1)
         trg_band.WriteArray(array[b, ...])
+
+        # set the band description, if specified
+        if names is not None:
+            trg_band.SetDescription(names[b])
         trg_band.FlushCache()
 
     # set spatial reference
-- 
GitLab