From a991fd992fb89eae8db712301b78639ef1e9de58 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Mon, 20 Jul 2020 17:12:02 +0200
Subject: [PATCH] Added data augmentation methods to processing pipeline

---
 main/config.py     |  47 ++++--
 pytorch/dataset.py | 345 +++++++++++++++++----------------------------
 2 files changed, 165 insertions(+), 227 deletions(-)

diff --git a/main/config.py b/main/config.py
index 1aaa85d..5a41701 100755
--- a/main/config.py
+++ b/main/config.py
@@ -15,26 +15,23 @@ import torch.nn as nn
 import torch.optim as optim
 
 from pytorch.models import UNet
+from pytorch.transforms import Augment, FlipLr, FlipUd, Noise
 
 # ------------------------- Dataset configuration -----------------------------
 # -----------------------------------------------------------------------------
 
 # define path to working directory
 # wd = '//projectdata.eurac.edu/projects/cci_snow/dfrisinghelli/'
-# wd = 'C:/Eurac/2020/'
-wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/'
+wd = 'C:/Eurac/2020/'
+# wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/'
 
 # define which dataset to train on
-# dataset_name = 'Sparcs'
-dataset_name = 'Cloud95'
+dataset_name = 'Sparcs'
+# dataset_name = 'Cloud95'
 
 # path to the dataset
-# dataset_path = os.path.join(wd, '_Datasets/Sparcs')
-dataset_path = os.path.join(wd, '_Datasets/Cloud95/Training')
-
-# the csv file containing the names of the informative patches of the
-# Cloud95 dataset
-patches = 'training_patches_95-cloud_nonempty.csv'
+dataset_path = os.path.join(wd, '_Datasets/Sparcs')
+# dataset_path = os.path.join(wd, '_Datasets/Cloud95/Training')
 
 # define the bands to use to train the segmentation network:
 # either a list of bands, e.g. ['red', 'green', 'nir', 'swir2', ...]
@@ -43,7 +40,33 @@ bands = ['red', 'green', 'blue', 'nir']
 
 # define the size of the network input
 # if None, the size will default to the size of a scene
-tile_size = 192
+tile_size = 125
+
+# whether to sort the dataset in chronological order, useful for time series
+# data
+sort = False
+
+# -----------------------------------------------------------------------------
+# -----------------------------------------------------------------------------
+
+# ------------------------- Dataset augmentation ------------------------------
+# -----------------------------------------------------------------------------
+
+# transformations used to artificially increase the training data size via
+# data augmentation; the initial None entry keeps the original, untransformed
+# samples in the dataset
+transforms = [None]
+augmentations = [
+    Augment([
+        FlipLr(),
+        FlipUd(),
+        Noise(mode='speckle', mean=0.1, var=0.05)
+        ])
+    ]
+
+# each transformation in this list is treated as a new, additional sample
+transforms.extend(augmentations)
+
+# if no augmentation is required, comment out the transforms.extend()
+# statement above
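
For reference, the dataset's __getitem__() (see pytorch/dataset.py below) calls
each entry as scene['transform'](x, y), so every transformation must be a
callable taking and returning an (image, ground truth) pair. The class names
match the pytorch.transforms import above, but that module is not part of this
patch, so the following is only an illustrative sketch of the assumed
interface, not its actual implementation (in particular, whether Augment
composes its transformations deterministically or samples them at random is up
to the real module):

    import numpy as np

    class FlipLr:
        # flip the image and the ground truth horizontally
        def __call__(self, image, gt):
            return np.flip(image, axis=-1), np.flip(gt, axis=-1)

    class Augment:
        # chain a sequence of transformations on an (image, gt) pair
        def __init__(self, transforms):
            self.transforms = transforms

        def __call__(self, image, gt):
            for transform in self.transforms:
                image, gt = transform(image, gt)
            return image, gt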
 
 # -----------------------------------------------------------------------------
 # -----------------------------------------------------------------------------
@@ -80,7 +103,7 @@ kwargs = {'kernel_size': 3,  # the size of the convolving kernel
 state_path = os.path.join(wd, 'git/deep-learning/main/_models/')
 
 # whether to use a pretrained model
-pretrained = True
+pretrained = False
 
 # name of the pretrained model
 pretrained_model = 'UNet_SparcsDataset_t125_b128_rgbn.pt'
diff --git a/pytorch/dataset.py b/pytorch/dataset.py
index 0e9f0bd..eac30cd 100644
--- a/pytorch/dataset.py
+++ b/pytorch/dataset.py
@@ -29,13 +29,15 @@ from torch.utils.data import Dataset
 # locals
 from pytorch.constants import (Landsat8, Sentinel2, SparcsLabels,
                                Cloud95Labels, ProSnowLabels)
-from pytorch.utils import parse_landsat8_date, parse_sentinel2_date
+from pytorch.utils import (img2np, is_divisible, tile_offsets,
+                           parse_landsat8_date, parse_sentinel2_date)
 
 
 # generic image dataset class
 class ImageDataset(Dataset):
 
-    def __init__(self, root_dir, use_bands, tile_size):
+    def __init__(self, root_dir, use_bands, tile_size, sort=False,
+                 transforms=[None]):
         super().__init__()
 
         # the root directory: path to the image dataset
@@ -62,10 +64,30 @@ class ImageDataset(Dataset):
         if self.tile_size is None:
             self.tiles = 1
         else:
-            self.tiles = self.is_divisible(self.size, self.tile_size)
+            self.tiles = is_divisible(self.size, self.tile_size)
+
+        # whether to sort the list of samples:
+        # for time series data, set sort=True to obtain the scenes in
+        # chronological order
+        self.sort = sort
+
+        # whether to artificially increase the training data size using
+        # transformations to apply to the original image
+        self.transforms = transforms
+        if self.transforms is None:
+            self.transforms = [None]
 
         # the samples of the dataset
-        self.scenes = []
+        self.scenes = self.compose_scenes()
+
+        # check whether the compose_scenes() method is correctly implemented
+        for scene in self.scenes:
+            assert isinstance(scene, dict), \
+                'method compose_scenes() should return a list of dicts.'
+            assert all(band in scene for band in self.use_bands), \
+                'dict expected to have keys {}'.format(self.use_bands)
+            assert all(key in scene for key in
+                       ('date', 'tile', 'gt', 'transform')), \
+                'dict expected to have keys {}'.format(
+                    ['date', 'tile', 'gt', 'transform'])
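
Together, these asserts pin down the contract for compose_scenes(): each
sample is a dict mapping every band in use_bands to its data (here, a file
path), plus the 'date', 'tile', 'gt' and 'transform' keys consumed by
__getitem__(). A hypothetical Sparcs sample could look like this (all paths
and the date are invented for illustration):

    import datetime

    # hypothetical sample dict; paths and date are made up
    scene = {
        'red': '/data/Sparcs/scene_1/band_red.tif',
        'green': '/data/Sparcs/scene_1/band_green.tif',
        'blue': '/data/Sparcs/scene_1/band_blue.tif',
        'gt': '/data/Sparcs/scene_1/mask.tif',
        'date': datetime.datetime(2020, 7, 20),
        'tile': 0,          # which tile of the scene this sample covers
        'transform': None,  # or an Augment(...) instance
    }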
 
     # the __len__() method returns the number of samples in the dataset
     def __len__(self):
@@ -88,6 +110,13 @@ class ImageDataset(Dataset):
         # y : (height, width)
         x, y = self.preprocess(data, gt)
 
+        # optional transformation
+        if scene['transform'] is not None:
+            x, y = scene['transform'](x, y)
+
+        # convert to torch tensors
+        x, y = self.to_tensor(x, y)
+
         return x, y
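
With the optional transform and the tensor conversion folded into
__getitem__(), consuming a dataset reduces to plain indexing. A sketch,
assuming a hypothetical root path and a transforms list as in main/config.py:

    from pytorch.dataset import SparcsDataset
    from pytorch.transforms import Augment, FlipLr

    # hypothetical root path; transforms analogous to main/config.py
    dataset = SparcsDataset('/data/Sparcs',
                            use_bands=['red', 'green', 'blue'],
                            tile_size=125,
                            sort=False,
                            transforms=[None, Augment([FlipLr()])])

    # x: torch.float32 tensor of shape (bands, 125, 125)
    # y: torch.uint8 tensor of shape (125, 125)
    x, y = dataset[0]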
 
     # the compose_scenes() method has to be implemented by the class inheriting
@@ -149,6 +178,15 @@ class ImageDataset(Dataset):
         raise NotImplementedError('Inherit the ImageDataset class and '
                                   'implement the method.')
 
+    # the date_parser() method has to be implemented by the class inheriting
+    # the ImageDataset class
+    # the input to the date_parser() method is a string describing a scene id,
+    # e.g. an id of a Landsat or a Sentinel scene
+    # date_parser() should return an instance of datetime.datetime
+    def date_parser(self, scene):
+        raise NotImplementedError('Inherit the ImageDataset class and '
+                                  'implement the method.')
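
The real parsers live in pytorch.utils and are not part of this patch; purely
to illustrate the contract (scene id string in, datetime.datetime out), a
standalone sketch for a collection-1 style Landsat 8 product id might read:

    import datetime

    def parse_date_from_id(scene_id):
        # illustrative only: assumes the acquisition date is the fourth
        # underscore-separated field, as in collection-1 Landsat 8 ids;
        # the actual format handled by pytorch.utils is an assumption here
        return datetime.datetime.strptime(scene_id.split('_')[3], '%Y%m%d')

    parse_date_from_id('LC08_L1TP_195028_20200720_20200807_01_T1')
    # datetime.datetime(2020, 7, 20, 0, 0)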
+
     # _read_scene() reads all the bands and the ground truth mask in a
     # scene/tile to a numpy array and returns a dictionary with
     # (key, value) = ('band_name', np.ndarray(band_data))
@@ -158,10 +196,9 @@ class ImageDataset(Dataset):
         scene = self.scenes[idx]
 
         # read each band of the scene into a numpy array
-        scene_data = {key: (self.img2np(value, tile_size=self.tile_size,
-                                        tile=scene['tile'])
-                            if key != 'tile' and key != 'date' else value)
-                      for key, value in scene.items()}
+        scene_data = {key: (img2np(value, self.tile_size, scene['tile'])
+                            if isinstance(value, str) else value)
+                      for key, value in scene.items()}
 
         return scene_data
 
@@ -175,149 +212,19 @@ class ImageDataset(Dataset):
 
         return stack, gt
 
-    # the following functions are utility functions for common image
-    # manipulation operations
-
-    # this function reads an image to a numpy array
-    def img2np(self, path, tile_size=None, tile=None):
-
-        # open the tif file
-        if path is None:
-            print('Path is of NoneType, returning.')
-            return
-        img = gdal.Open(path)
-
-        # check whether to read the image in tiles
-        if tile_size is None:
-
-            # create empty numpy array to store whole image
-            image = np.empty(shape=(img.RasterCount, img.RasterYSize,
-                                    img.RasterXSize))
-
-            # iterate over the bands of the image
-            for b in range(img.RasterCount):
-
-                # read the data of band b
-                band = img.GetRasterBand(b+1)
-                data = band.ReadAsArray()
-
-                # append band b to numpy image array
-                image[b, :, :] = data
-
-        else:
-
-            # check whether the image is evenly divisible in square tiles
-            # of size (tile_size x tile_size)
-            ntiles = self.is_divisible((img.RasterXSize, img.RasterYSize),
-                                       tile_size)
-
-            # get the indices of the top left corner for each tile
-            topleft = self.tile_offsets((img.RasterYSize, img.RasterXSize),
-                                        tile_size)
-
-            # check whether to read all tiles or a single tile
-            if tile is None:
-
-                # create empty numpy array to store all tiles
-                image = np.empty(shape=(ntiles, img.RasterCount,
-                                        tile_size, tile_size))
-
-                # iterate over the tiles
-                for k, v in topleft.items():
-
-                    # iterate over the bands of the image
-                    for b in range(img.RasterCount):
-
-                        # read the data of band b
-                        band = img.GetRasterBand(b+1)
-                        data = band.ReadAsArray(v[1], v[0],
-                                                tile_size, tile_size)
-
-                        # append band b to numpy image array
-                        image[k, b, :, :] = data
-
-            else:
-
-                # create empty numpy array to store a single tile
-                image = np.empty(shape=(img.RasterCount, tile_size, tile_size))
-
-                # the tile of interest
-                tile = topleft[tile]
-
-                # iterate over the bands of the image
-                for b in range(img.RasterCount):
-
-                    # read the data of band b
-                    band = img.GetRasterBand(b+1)
-                    data = band.ReadAsArray(tile[1], tile[0],
-                                            tile_size, tile_size)
-
-                    # append band b to numpy image array
-                    image[b, :, :] = data
-
-        # check if there are more than 1 band
-        if not img.RasterCount > 1:
-            image = image.squeeze()
-
-        # close tif file
-        del img
-
-        # return the image
-        return image
-
-    # this function checks whether an image is evenly divisible
-    # in square tiles of defined size tile_size
-    def is_divisible(self, img_size, tile_size):
-        # calculate number of pixels per tile
-        pixels_per_tile = tile_size ** 2
-
-        # check whether the image is evenly divisible in square tiles of size
-        # (tile_size x tile_size)
-        ntiles = ((img_size[0] * img_size[1]) / pixels_per_tile)
-        assert ntiles.is_integer(), ('Image not evenly divisible in '
-                                     ' {} x {} tiles.').format(tile_size,
-                                                               tile_size)
-
-        return int(ntiles)
-
-    # this function returns the top-left corners for each tile
-    # if the image is evenly divisible in square tiles of
-    # defined size tile_size
-    def tile_offsets(self, img_size, tile_size):
-
-        # check if divisible
-        _ = self.is_divisible(img_size, tile_size)
-
-        # number of tiles along the width (columns) of the image
-        ntiles_columns = int(img_size[1] / tile_size)
-
-        # number of tiles along the height (rows) of the image
-        ntiles_rows = int(img_size[0] / tile_size)
-
-        # get the indices of the top left corner for each tile
-        indices = {}
-        k = 0
-        for i in range(ntiles_rows):
-            for j in range(ntiles_columns):
-                indices[k] = (i * tile_size, j * tile_size)
-                k += 1
-
-        return indices
+    def to_tensor(self, x, y):
+        return (torch.tensor(x.copy(), dtype=torch.float32),
+                torch.tensor(y.copy(), dtype=torch.uint8) if y is not None
+                else y)
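
img2np(), is_divisible() and tile_offsets() were factored out into
pytorch.utils (see the updated import at the top of this file). Based on the
removed implementations above, their semantics for a small example image are:

    from pytorch.utils import is_divisible, tile_offsets

    # a 250 x 500 image splits evenly into 8 square tiles of 125 x 125,
    # since (250 * 500) / 125 ** 2 = 8
    ntiles = is_divisible((250, 500), 125)   # -> 8

    # top-left (row, column) corner of each tile, keyed by tile number
    corners = tile_offsets((250, 500), 125)
    # {0: (0, 0), 1: (0, 125), 2: (0, 250), 3: (0, 375),
    #  4: (125, 0), 5: (125, 125), 6: (125, 250), 7: (125, 375)}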
 
 
 class StandardEoDataset(ImageDataset):
 
-    def __init__(self, root_dir, use_bands, tile_size, sort):
-        # initialize super class ImageDataset
-        super().__init__(root_dir, use_bands, tile_size)
-
-        # function that parses the date from a Landsat 8 scene id
-        self.date_parser = None
+    def __init__(self, root_dir, use_bands, tile_size, sort=False,
+                 transforms=[None]):
 
-        # whether to sort the list of samples:
-        # for time series data, set sort=True to obtain the scenes in
-        # chronological order
-        self.sort = sort
+        # initialize super class ImageDataset
+        super().__init__(root_dir, use_bands, tile_size, sort, transforms)
 
     # returns the band number of a Landsat8 or Sentinel2 tif file
     # x: path to a tif file
@@ -342,7 +249,7 @@ class StandardEoDataset(ImageDataset):
 
         return band
 
-    # _store_bands() writes the paths to the data of each scene to a dictionary
+    # store_bands() writes the paths to the data of each scene to a dictionary
     # only the bands of interest are stored
     def store_bands(self, bands, gt):
 
@@ -384,17 +291,23 @@ class StandardEoDataset(ImageDataset):
             # iterate over the tiles
             for tile in range(self.tiles):
 
-                # store the bands and the ground truth mask of the tile
-                data = self.store_bands(bands, gt)
+                # iterate over the transformations to apply
+                for transf in self.transforms:
+
+                    # store the bands and the ground truth mask of the tile
+                    data = self.store_bands(bands, gt)
+
+                    # store tile number
+                    data['tile'] = tile
 
-                # store tile number
-                data['tile'] = tile
+                    # store date
+                    data['date'] = date
 
-                # store date
-                data['date'] = date
+                    # store optional transformation
+                    data['transform'] = transf
 
-                # append to list
-                scenes.append(data)
+                    # append to list
+                    scenes.append(data)
 
         # sort list of scenes in chronological order
         if self.sort:
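
With the new inner loop over self.transforms, every (scene, tile) pair yields
one sample per entry in the transforms list, so the dataset size grows
multiplicatively. A quick sanity check with hypothetical figures:

    # hypothetical figures, only to show how the sample count scales
    n_scenes = 10       # scenes found in root_dir
    n_tiles = 8         # tiles per scene, from is_divisible()
    n_transforms = 2    # e.g. [None, Augment([...])] as in main/config.py
    assert n_scenes * n_tiles * n_transforms == 160   # == len(dataset)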
@@ -407,16 +320,10 @@ class StandardEoDataset(ImageDataset):
 class SparcsDataset(StandardEoDataset):
 
     def __init__(self, root_dir, use_bands=['red', 'green', 'blue'],
-                 tile_size=None, sort=False):
-        # initialize super class ImageDataset
-        super().__init__(root_dir, use_bands, tile_size, sort)
-
-        # function that parses the date from a Landsat 8 scene id
-        self.date_parser = parse_landsat8_date
+                 tile_size=None, sort=False, transforms=[None]):
 
-        # list of all scenes in the root directory
-        # each scene is divided into tiles blocks
-        self.scenes = self.compose_scenes()
+        # initialize super class StandardEoDataset
+        super().__init__(root_dir, use_bands, tile_size, sort, transforms)
 
     # image size of the Sparcs dataset: (height, width)
     def get_size(self):
@@ -436,23 +343,21 @@ class SparcsDataset(StandardEoDataset):
     def preprocess(self, data, gt):
 
         # if the preprocessing is not done externally, implement it here
+        return data, gt
 
-        # convert to torch tensors
-        x = torch.tensor(data, dtype=torch.float32)
-        y = torch.tensor(gt, dtype=torch.uint8) if gt is not None else gt
-        return x, y
+    # function that parses the date from a Landsat 8 scene id
+    def date_parser(self, scene):
+        return parse_landsat8_date(scene)
 
 
-class ProSnowDataset(StandardEoDataset):
 
-    def __init__(self, root_dir, use_bands, tile_size, sort):
-        super().__init__(root_dir, use_bands, tile_size, sort)
+class ProSnowDataset(StandardEoDataset):
 
-        # function that parses the date from a Sentinel 2 scene id
-        self.date_parser = parse_sentinel2_date
+    def __init__(self, root_dir, use_bands, tile_size, sort=True,
+                 transforms=[None]):
 
-        # list of samples in the dataset
-        self.scenes = self.compose_scenes()
+        # initialize super class StandardEoDataset
+        super().__init__(root_dir, use_bands, tile_size, sort, transforms)
 
     # Sentinel 2 bands
     def get_bands(self):
@@ -467,17 +372,18 @@ class ProSnowDataset(StandardEoDataset):
     def preprocess(self, data, gt):
 
         # if the preprocessing is not done externally, implement it here
+        return data, gt
 
-        # convert to torch tensors
-        x = torch.tensor(data, dtype=torch.float32)
-        y = torch.tensor(gt, dtype=torch.uint8) if gt is not None else gt
-        return x, y
+    # function that parses the date from a Sentinel 2 scene id
+    def date_parser(self, scene):
+        return parse_sentinel2_date(scene)
 
 
 class ProSnowGarmisch(ProSnowDataset):
 
-    def __init__(self, root_dir, use_bands=[], tile_size=None, sort=True):
-        super().__init__(root_dir, use_bands, tile_size, sort)
+    def __init__(self, root_dir, use_bands=[], tile_size=None, sort=True,
+                 transforms=[None]):
+        super().__init__(root_dir, use_bands, tile_size, sort, transforms)
 
     def get_size(self):
         return (615, 543)
@@ -485,8 +391,9 @@ class ProSnowGarmisch(ProSnowDataset):
 
 class ProSnowObergurgl(ProSnowDataset):
 
-    def __init__(self, root_dir, use_bands=[], tile_size=None, sort=True):
-        super().__init__(root_dir, use_bands, tile_size, sort)
+    def __init__(self, root_dir, use_bands=[], tile_size=None, sort=True,
+                 transforms=[None]):
+        super().__init__(root_dir, use_bands, tile_size, sort, transforms)
 
     def get_size(self):
         return (310, 270)
@@ -494,21 +401,16 @@ class ProSnowObergurgl(ProSnowDataset):
 
 class Cloud95Dataset(ImageDataset):
 
-    def __init__(self, root_dir, use_bands=[], tile_size=None, exclude=None):
+    def __init__(self, root_dir, use_bands=[], tile_size=None, sort=False,
+                 transforms=[None]):
 
-        # initialize super class ImageDataset
-        super().__init__(root_dir, use_bands, tile_size)
-
-        # whether to exclude patches with more than 80% black pixels, i.e.
-        # patches resulting from the black margins around a Landsat 8 scene
-        self.exclude = exclude
-
-        # function that parses the date from a Landsat 8 scene id
-        self.date_parser = parse_landsat8_date
+        # the csv file containing the names of the informative patches;
+        # patches with more than 80% black pixels, i.e. patches resulting
+        # from the black margins around a Landsat 8 scene, are excluded
+        self.exclude = 'training_patches_95-cloud_nonempty.csv'
 
-        # list of all scenes in the root directory
-        # each scene is divided into tiles blocks
-        self.scenes = self.compose_scenes()
+        # initialize super class ImageDataset
+        super().__init__(root_dir, use_bands, tile_size, sort, transforms)
 
     # image size of the Cloud-95 dataset: (height, width)
     def get_size(self):
@@ -529,10 +431,14 @@ class Cloud95Dataset(ImageDataset):
         # normalize the data
         # here, we use the normalization of the authors of Cloud-95, i.e.
         # Mohajerani and Saeedi (2019, 2020)
-        x = torch.tensor(data / 65535, dtype=torch.float32)
-        y = torch.tensor(gt / 255, dtype=torch.uint8)
+        data /= 65535
+        gt /= 255
 
-        return x, y
+        return data, gt
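
preprocess() now stays in numpy and defers the tensor conversion to
to_tensor(). Dividing the 16-bit Landsat data by 65535 maps it to [0, 1], and
dividing the 8-bit mask by 255 maps it to {0, 1}; the in-place division is
safe here because img2np() fills float64 arrays (np.empty), whereas it would
raise on integer arrays. A tiny numeric check:

    import numpy as np

    data = np.array([0., 32768., 65535.])   # 16-bit range
    gt = np.array([0., 255., 255.])         # 8-bit mask
    data /= 65535   # -> [0., 0.50000763..., 1.]
    gt /= 255       # -> [0., 1., 1.]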
+
+    # function that parses the date from a Landsat 8 scene id
+    def date_parser(self, scene):
+        return parse_landsat8_date(scene)
 
     def compose_scenes(self):
 
@@ -577,23 +483,29 @@ class Cloud95Dataset(ImageDataset):
             # iterate over the tiles
             for tile in range(self.tiles):
 
-                # initialize dictionary to store bands of current patch
-                scene = {}
+                # iterate over the transformations to apply
+                for transf in self.transforms:
+
+                    # initialize dictionary to store bands of current patch
+                    scene = {}
+
+                    # iterate over the bands of interest
+                    for band in band_dirs.keys():
+                        # save path to current band tif file to dictionary
+                        scene[band] = os.path.join(band_dirs[band],
+                                                   file.replace(biter, band))
 
-                # iterate over the bands of interest
-                for band in band_dirs.keys():
-                    # save path to current band tif file to dictionary
-                    scene[band] = os.path.join(band_dirs[band],
-                                               file.replace(biter, band))
+                    # store tile number
+                    scene['tile'] = tile
 
-                # store tile number
-                scene['tile'] = tile
+                    # store date
+                    scene['date'] = date
 
-                # store date
-                scene['date'] = date
+                    # store optional transformation
+                    scene['transform'] = transf
 
-                # append patch to list of all patches
-                scenes.append(scene)
+                    # append patch to list of all patches
+                    scenes.append(scene)
 
         # sort list of scenes in chronological order
         if self.sort:
@@ -624,19 +536,22 @@ if __name__ == '__main__':
     # instantiate the Cloud-95 dataset
     # cloud_dataset = Cloud95Dataset(cloud_path,
     #                                tile_size=192,
-    #                                exclude=patches)
+    #                                use_bands=[],
+    #                                sort=False)
 
     # instantiate the SparcsDataset class
     sparcs_dataset = SparcsDataset(sparcs_path,
                                    tile_size=None,
                                    use_bands=['nir', 'red', 'green'],
-                                   sort=False)
+                                   sort=False,
+                                   transforms=[None])
 
     # instantiate the ProSnow datasets
     garmisch = ProSnowGarmisch(os.path.join(prosnow_path, 'Garmisch'),
                                tile_size=None,
                                use_bands=['nir', 'red', 'green'],
-                               sort=True)
+                               sort=True,
+                               transforms=[None])
     # obergurgl = ProSnowObergurgl(os.path.join(prosnow_path, 'Obergurgl'),
     #                              tile_size=None,
     #                              use_bands=['nir', 'red', 'green'],
-- 
GitLab