From a991fd992fb89eae8db712301b78639ef1e9de58 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Mon, 20 Jul 2020 17:12:02 +0200
Subject: [PATCH] Added data augmentation methods to processing pipeline

---
 main/config.py     |  47 ++++--
 pytorch/dataset.py | 345 +++++++++++++++++----------------------------
 2 files changed, 165 insertions(+), 227 deletions(-)

diff --git a/main/config.py b/main/config.py
index 1aaa85d..5a41701 100755
--- a/main/config.py
+++ b/main/config.py
@@ -15,26 +15,23 @@ import torch.nn as nn
 import torch.optim as optim
 
 from pytorch.models import UNet
+from pytorch.transforms import Augment, FlipLr, FlipUd, Noise
 
 # ------------------------- Dataset configuration -----------------------------
 # -----------------------------------------------------------------------------
 
 # define path to working directory
 # wd = '//projectdata.eurac.edu/projects/cci_snow/dfrisinghelli/'
-# wd = 'C:/Eurac/2020/'
-wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/'
+wd = 'C:/Eurac/2020/'
+# wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/'
 
 # define which dataset to train on
-# dataset_name = 'Sparcs'
-dataset_name = 'Cloud95'
+dataset_name = 'Sparcs'
+# dataset_name = 'Cloud95'
 
 # path to the dataset
-# dataset_path = os.path.join(wd, '_Datasets/Sparcs')
-dataset_path = os.path.join(wd, '_Datasets/Cloud95/Training')
-
-# the csv file containing the names of the informative patches of the
-# Cloud95 dataset
-patches = 'training_patches_95-cloud_nonempty.csv'
+dataset_path = os.path.join(wd, '_Datasets/Sparcs')
+# dataset_path = os.path.join(wd, '_Datasets/Cloud95/Training')
 
 # define the bands to use to train the segmentation network:
 # either a list of bands, e.g. ['red', 'green', 'nir', 'swir2', ...]
@@ -43,7 +40,33 @@ bands = ['red', 'green', 'blue', 'nir']
 
 # define the size of the network input
 # if None, the size will default to the size of a scene
-tile_size = 192
+tile_size = 125
+
+# whether to sort the dataset in chronological order, useful for time series
+# data
+sort = False
+
+# -----------------------------------------------------------------------------
+# -----------------------------------------------------------------------------
+
+# ------------------------- Dataset augmentation ------------------------------
+# -----------------------------------------------------------------------------
+
+# transformations to artificially increase the training data size using data
+# augmentation methods
+transforms = [None]
+augmentations = [
+    Augment([
+        FlipLr(),
+        FlipUd(),
+        Noise(mode='speckle', mean=0.1, var=0.05)
+        ])
+    ]
+
+# each transformation in this list is treated as a new sample
+transforms.extend(augmentations)
+
+# if no augmentation is required, comment line 67!
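The pytorch.transforms module imported above is new in this patch, but its source is not part of the diff. Judging from how the transforms are used (dataset.py calls scene['transform'](x, y) and expects the transformed image and mask back), the module could look roughly like the sketch below; the class layout and the skimage-style speckle noise are assumptions, not the actual implementation.

import numpy as np


class FlipLr:
    # flip image and ground truth horizontally (along the width axis)
    def __call__(self, image, gt):
        return np.flip(image, axis=-1), np.flip(gt, axis=-1)


class FlipUd:
    # flip image and ground truth vertically (along the height axis)
    def __call__(self, image, gt):
        return np.flip(image, axis=-2), np.flip(gt, axis=-2)


class Noise:
    # add random noise to the image; the ground truth mask is left untouched
    def __init__(self, mode='speckle', mean=0, var=0.05):
        self.mode = mode
        self.mean = mean
        self.var = var

    def __call__(self, image, gt):
        noise = np.random.normal(self.mean, self.var ** 0.5, image.shape)
        if self.mode == 'speckle':
            # multiplicative (speckle) noise: x + x * n
            return image + image * noise, gt
        # additive (gaussian) noise: x + n
        return image + noise, gt


class Augment:
    # chain several transformations, applied one after the other
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, gt):
        for transform in self.transforms:
            image, gt = transform(image, gt)
        return image, gt

Note that np.flip returns views with negative strides, which is presumably why the to_tensor() method introduced in dataset.py below copies the arrays before building the tensors.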
 
 # -----------------------------------------------------------------------------
 # -----------------------------------------------------------------------------
@@ -80,7 +103,7 @@ kwargs = {'kernel_size': 3,  # the size of the convolving kernel
 state_path = os.path.join(wd, 'git/deep-learning/main/_models/')
 
 # whether to use a pretrained model
-pretrained = True
+pretrained = False
 
 # name of the pretrained model
 pretrained_model = 'UNet_SparcsDataset_t125_b128_rgbn.pt'
diff --git a/pytorch/dataset.py b/pytorch/dataset.py
index 0e9f0bd..eac30cd 100644
--- a/pytorch/dataset.py
+++ b/pytorch/dataset.py
@@ -29,13 +29,15 @@ from torch.utils.data import Dataset
 
 # locals
 from pytorch.constants import (Landsat8, Sentinel2, SparcsLabels,
                                Cloud95Labels, ProSnowLabels)
-from pytorch.utils import parse_landsat8_date, parse_sentinel2_date
+from pytorch.utils import (img2np, is_divisible, tile_offsets,
+                           parse_landsat8_date, parse_sentinel2_date)
 
 
 # generic image dataset class
 class ImageDataset(Dataset):
 
-    def __init__(self, root_dir, use_bands, tile_size):
+    def __init__(self, root_dir, use_bands, tile_size, sort=False,
+                 transforms=[None]):
         super().__init__()
 
         # the root directory: path to the image dataset
@@ -62,10 +64,30 @@ class ImageDataset(Dataset):
         if self.tile_size is None:
             self.tiles = 1
         else:
-            self.tiles = self.is_divisible(self.size, self.tile_size)
+            self.tiles = is_divisible(self.size, self.tile_size)
+
+        # whether to sort the list of samples:
+        # for time series data, set sort=True to obtain the scenes in
+        # chronological order
+        self.sort = sort
+
+        # transformations to artificially increase the training data size;
+        # each transformation is applied to the original image
+        self.transforms = transforms
+        if self.transforms is None:
+            self.transforms = [self.transforms]
 
         # the samples of the dataset
-        self.scenes = []
+        self.scenes = self.compose_scenes()
+
+        # check whether the compose_scenes() method is correctly implemented
+        for scene in self.scenes:
+            assert isinstance(scene, dict), \
+                'method compose_scenes() should return a list of dict.'
+            assert all(band in scene for band in self.use_bands), \
+                'dict expected to have keys {}'.format(self.use_bands)
+            assert 'date' in scene and 'tile' in scene and 'gt' in scene, \
+                'dict expected to have keys {}'.format(['date', 'tile', 'gt'])
 
     # the __len__() method returns the number of samples in the dataset
     def __len__(self):
@@ -88,6 +110,13 @@ class ImageDataset(Dataset):
         #   y : (height, width)
         x, y = self.preprocess(data, gt)
 
+        # optional transformation
+        if scene['transform'] is not None:
+            x, y = scene['transform'](x, y)
+
+        # convert to torch tensors
+        x, y = self.to_tensor(x, y)
+
         return x, y
 
     # the compose_scenes() method has to be implemented by the class inheriting
@@ -149,6 +178,15 @@ class ImageDataset(Dataset):
         raise NotImplementedError('Inherit the ImageDataset class and '
                                   'implement the method.')
 
+    # the date_parser() method has to be implemented by the class inheriting
+    # the ImageDataset class
+    # the input to the date_parser() method is a string describing a scene id,
+    # e.g. an id of a Landsat or a Sentinel scene
+    # date_parser() should return an instance of datetime.datetime
+    def date_parser(self, scene):
+        raise NotImplementedError('Inherit the ImageDataset class and '
+                                  'implement the method.')
+
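To make the contract enforced by these assertions and abstract methods concrete: compose_scenes() must return a list of dictionaries, one per sample, mapping each band in use_bands to the path of its tif file plus a few bookkeeping keys. A purely illustrative Sparcs-style entry (the paths and the date are made up):

import datetime

# illustrative only: one sample as compose_scenes() might return it
scene = {
    'red': '/path/to/scene_B4.tif',    # one key per band in use_bands
    'green': '/path/to/scene_B3.tif',
    'blue': '/path/to/scene_B2.tif',
    'gt': '/path/to/scene_mask.tif',   # path to the ground truth mask
    'date': datetime.datetime(2014, 8, 24),  # returned by date_parser()
    'tile': 3,                         # which tile of the scene to read
    'transform': None,                 # or a callable, e.g. Augment([...])
}

The 'transform' key is read by __getitem__() but not covered by the assertions above, so a missing transform only surfaces when a sample is accessed.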
     # _read_scene() reads all the bands and the ground truth mask in a
     # scene/tile to a numpy array and returns a dictionary with
     # (key, value) = ('band_name', np.ndarray(band_data))
@@ -158,10 +196,9 @@ class ImageDataset(Dataset):
         scene = self.scenes[idx]
 
         # read each band of the scene into a numpy array
-        scene_data = {key: (self.img2np(value, tile_size=self.tile_size,
-                                        tile=scene['tile'])
-                            if key != 'tile' and key != 'date' else value)
-                      for key, value in scene.items()}
+        scene_data = {key: img2np(value, self.tile_size, scene['tile'])
+                      if isinstance(value, str) else value for key, value
+                      in scene.items()}
 
         return scene_data
 
@@ -175,149 +212,19 @@ class ImageDataset(Dataset):
 
         return stack, gt
 
-    # the following functions are utility functions for common image
-    # manipulation operations
-
-    # this function reads an image to a numpy array
-    def img2np(self, path, tile_size=None, tile=None):
-
-        # open the tif file
-        if path is None:
-            print('Path is of NoneType, returning.')
-            return
-        img = gdal.Open(path)
-
-        # check whether to read the image in tiles
-        if tile_size is None:
-
-            # create empty numpy array to store whole image
-            image = np.empty(shape=(img.RasterCount, img.RasterYSize,
-                                    img.RasterXSize))
-
-            # iterate over the bands of the image
-            for b in range(img.RasterCount):
-
-                # read the data of band b
-                band = img.GetRasterBand(b+1)
-                data = band.ReadAsArray()
-
-                # append band b to numpy image array
-                image[b, :, :] = data
-
-        else:
-
-            # check whether the image is evenly divisible in square tiles
-            # of size (tile_size x tile_size)
-            ntiles = self.is_divisible((img.RasterXSize, img.RasterYSize),
-                                       tile_size)
-
-            # get the indices of the top left corner for each tile
-            topleft = self.tile_offsets((img.RasterYSize, img.RasterXSize),
-                                        tile_size)
-
-            # check whether to read all tiles or a single tile
-            if tile is None:
-
-                # create empty numpy array to store all tiles
-                image = np.empty(shape=(ntiles, img.RasterCount,
-                                        tile_size, tile_size))
-
-                # iterate over the tiles
-                for k, v in topleft.items():
-
-                    # iterate over the bands of the image
-                    for b in range(img.RasterCount):
-
-                        # read the data of band b
-                        band = img.GetRasterBand(b+1)
-                        data = band.ReadAsArray(v[1], v[0],
-                                                tile_size, tile_size)
-
-                        # append band b to numpy image array
-                        image[k, b, :, :] = data
-
-            else:
-
-                # create empty numpy array to store a single tile
-                image = np.empty(shape=(img.RasterCount, tile_size, tile_size))
-
-                # the tile of interest
-                tile = topleft[tile]
-
-                # iterate over the bands of the image
-                for b in range(img.RasterCount):
-
-                    # read the data of band b
-                    band = img.GetRasterBand(b+1)
-                    data = band.ReadAsArray(tile[1], tile[0],
-                                            tile_size, tile_size)
-
-                    # append band b to numpy image array
-                    image[b, :, :] = data
-
-        # check if there are more than 1 band
-        if not img.RasterCount > 1:
-            image = image.squeeze()
-
-        # close tif file
-        del img
-
-        # return the image
-        return image
-
-    # this function checks whether an image is evenly divisible
-    # in square tiles of defined size tile_size
-    def is_divisible(self, img_size, tile_size):
-        # calculate number of pixels per tile
-        pixels_per_tile = tile_size ** 2
-
-        # check whether the image is evenly divisible in square tiles of size
-        # (tile_size x tile_size)
-        ntiles = ((img_size[0] * img_size[1]) /
-                  pixels_per_tile)
-        assert ntiles.is_integer(), ('Image not evenly divisible in '
-                                     ' {} x {} tiles.').format(tile_size,
-                                                               tile_size)
-
-        return int(ntiles)
-
-    # this function returns the top-left corners for each tile
-    # if the image is evenly divisible in square tiles of
-    # defined size tile_size
-    def tile_offsets(self, img_size, tile_size):
-
-        # check if divisible
-        _ = self.is_divisible(img_size, tile_size)
-
-        # number of tiles along the width (columns) of the image
-        ntiles_columns = int(img_size[1] / tile_size)
-
-        # number of tiles along the height (rows) of the image
-        ntiles_rows = int(img_size[0] / tile_size)
-
-        # get the indices of the top left corner for each tile
-        indices = {}
-        k = 0
-        for i in range(ntiles_rows):
-            for j in range(ntiles_columns):
-                indices[k] = (i * tile_size, j * tile_size)
-                k += 1
-
-        return indices
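These removed helpers now live in pytorch.utils (together with img2np), but their arithmetic is unchanged. A small worked example of the contract, with illustrative numbers:

# a 384 x 384 image is evenly divisible into four 192 x 192 tiles
img_size, tile_size = (384, 384), 192

ntiles = (img_size[0] * img_size[1]) / tile_size ** 2   # -> 4.0
assert ntiles.is_integer()

# tile_offsets() then enumerates the top-left (row, col) corner of each
# tile in row-major order, which is the order in which img2np reads them:
# {0: (0, 0), 1: (0, 192), 2: (192, 0), 3: (192, 192)}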
-                  pixels_per_tile)
-        assert ntiles.is_integer(), ('Image not evenly divisible in '
-                                     ' {} x {} tiles.').format(tile_size,
-                                                               tile_size)
-
-        return int(ntiles)
-
-    # this function returns the top-left corners for each tile
-    # if the image is evenly divisible in square tiles of
-    # defined size tile_size
-    def tile_offsets(self, img_size, tile_size):
-
-        # check if divisible
-        _ = self.is_divisible(img_size, tile_size)
-
-        # number of tiles along the width (columns) of the image
-        ntiles_columns = int(img_size[1] / tile_size)
-
-        # number of tiles along the height (rows) of the image
-        ntiles_rows = int(img_size[0] / tile_size)
-
-        # get the indices of the top left corner for each tile
-        indices = {}
-        k = 0
-        for i in range(ntiles_rows):
-            for j in range(ntiles_columns):
-                indices[k] = (i * tile_size, j * tile_size)
-                k += 1
-
-        return indices
+    def to_tensor(self, x, y):
+        return (torch.tensor(x.copy(), dtype=torch.float32),
+                torch.tensor(y.copy(), dtype=torch.uint8) if y is not None
+                else y)
 
 
 class StandardEoDataset(ImageDataset):
 
-    def __init__(self, root_dir, use_bands, tile_size, sort):
-        # initialize super class ImageDataset
-        super().__init__(root_dir, use_bands, tile_size)
-
-        # function that parses the date from a Landsat 8 scene id
-        self.date_parser = None
+    def __init__(self, root_dir, use_bands, tile_size, sort=False,
+                 transforms=[None]):
 
-        # whether to sort the list of samples:
-        # for time series data, set sort=True to obtain the scenes in
-        # chronological order
-        self.sort = sort
+        # initialize super class ImageDataset
+        super().__init__(root_dir, use_bands, tile_size, sort, transforms)
 
     # returns the band number of a Landsat8 or Sentinel2 tif file
     # x: path to a tif file
@@ -342,7 +249,7 @@ class StandardEoDataset(ImageDataset):
 
         return band
 
-    # _store_bands() writes the paths to the data of each scene to a dictionary
+    # store_bands() writes the paths to the data of each scene to a dictionary
     # only the bands of interest are stored
     def store_bands(self, bands, gt):
 
@@ -384,17 +291,23 @@ class StandardEoDataset(ImageDataset):
 
         # iterate over the tiles
         for tile in range(self.tiles):
 
-            # store the bands and the ground truth mask of the tile
-            data = self.store_bands(bands, gt)
+            # iterate over the transformations to apply
+            for transf in self.transforms:
+
+                # store the bands and the ground truth mask of the tile
+                data = self.store_bands(bands, gt)
+
+                # store tile number
+                data['tile'] = tile
 
-            # store tile number
-            data['tile'] = tile
+                # store date
+                data['date'] = date
 
-            # store date
-            data['date'] = date
+                # store optional transformation
+                data['transform'] = transf
 
-            # append to list
-            scenes.append(data)
+                # append to list
+                scenes.append(data)
 
         # sort list of scenes in chronological order
         if self.sort:
@@ -407,16 +320,10 @@ class StandardEoDataset(ImageDataset):
 
 class SparcsDataset(StandardEoDataset):
 
     def __init__(self, root_dir, use_bands=['red', 'green', 'blue'],
-                 tile_size=None, sort=False):
-        # initialize super class ImageDataset
-        super().__init__(root_dir, use_bands, tile_size, sort)
-
-        # function that parses the date from a Landsat 8 scene id
-        self.date_parser = parse_landsat8_date
+                 tile_size=None, sort=False, transforms=[None]):
 
-        # list of all scenes in the root directory
-        # each scene is divided into tiles blocks
-        self.scenes = self.compose_scenes()
+        # initialize super class StandardEoDataset
+        super().__init__(root_dir, use_bands, tile_size, sort, transforms)
 
     # image size of the Sparcs dataset: (height, width)
     def get_size(self):
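Since the transformation loop in compose_scenes() above is nested inside the tile loop, the number of samples scales multiplicatively: each scene contributes tiles x transformations entries. A rough, hypothetical count, assuming 1000 x 1000 pixel Sparcs scenes and the two-entry transforms list from config.py:

tiles_per_scene = (1000 * 1000) // (125 ** 2)   # 64 tiles at tile_size = 125
samples_per_scene = tiles_per_scene * 2         # [None] + one Augment -> 128
# i.e. the augmented dataset is twice the size of the original one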
@@ -436,23 +343,21 @@ class SparcsDataset(StandardEoDataset):
 
     def preprocess(self, data, gt):
 
         # if the preprocessing is not done externally, implement it here
+        return data, gt
 
-        # convert to torch tensors
-        x = torch.tensor(data, dtype=torch.float32)
-        y = torch.tensor(gt, dtype=torch.uint8) if gt is not None else gt
-        return x, y
+    # function that parses the date from a Landsat 8 scene id
+    def date_parser(self, scene):
+        return parse_landsat8_date(scene)
 
 
-class ProSnowDataset(StandardEoDataset):
-    def __init__(self, root_dir, use_bands, tile_size, sort):
-        super().__init__(root_dir, use_bands, tile_size, sort)
+class ProSnowDataset(StandardEoDataset):
 
-        # function that parses the date from a Sentinel 2 scene id
-        self.date_parser = parse_sentinel2_date
+    def __init__(self, root_dir, use_bands, tile_size, sort=True,
+                 transforms=[None]):
 
-        # list of samples in the dataset
-        self.scenes = self.compose_scenes()
+        # initialize super class StandardEoDataset
+        super().__init__(root_dir, use_bands, tile_size, sort, transforms)
 
     # Sentinel 2 bands
     def get_bands(self):
@@ -467,17 +372,18 @@ class ProSnowDataset(StandardEoDataset):
 
     def preprocess(self, data, gt):
 
         # if the preprocessing is not done externally, implement it here
+        return data, gt
 
-        # convert to torch tensors
-        x = torch.tensor(data, dtype=torch.float32)
-        y = torch.tensor(gt, dtype=torch.uint8) if gt is not None else gt
-        return x, y
+    # function that parses the date from a Sentinel 2 scene id
+    def date_parser(self, scene):
+        return parse_sentinel2_date(scene)
 
 
 class ProSnowGarmisch(ProSnowDataset):
-    def __init__(self, root_dir, use_bands=[], tile_size=None, sort=True):
-        super().__init__(root_dir, use_bands, tile_size, sort)
+    def __init__(self, root_dir, use_bands=[], tile_size=None, sort=True,
+                 transforms=[None]):
+        super().__init__(root_dir, use_bands, tile_size, sort, transforms)
 
     def get_size(self):
        return (615, 543)
@@ -485,8 +391,9 @@ class ProSnowGarmisch(ProSnowDataset):
 
 class ProSnowObergurgl(ProSnowDataset):
-    def __init__(self, root_dir, use_bands=[], tile_size=None, sort=True):
-        super().__init__(root_dir, use_bands, tile_size, sort)
+    def __init__(self, root_dir, use_bands=[], tile_size=None, sort=True,
+                 transforms=[None]):
+        super().__init__(root_dir, use_bands, tile_size, sort, transforms)
 
     def get_size(self):
         return (310, 270)
@@ -494,21 +401,16 @@ class ProSnowObergurgl(ProSnowDataset):
 
 class Cloud95Dataset(ImageDataset):
 
-    def __init__(self, root_dir, use_bands=[], tile_size=None, exclude=None):
+    def __init__(self, root_dir, use_bands=[], tile_size=None, sort=False,
+                 transforms=[None]):
 
-        # initialize super class ImageDataset
-        super().__init__(root_dir, use_bands, tile_size)
-
-        # whether to exclude patches with more than 80% black pixels, i.e.
-        # patches resulting from the black margins around a Landsat 8 scene
-        self.exclude = exclude
-
-        # function that parses the date from a Landsat 8 scene id
-        self.date_parser = parse_landsat8_date
+        # the csv file containing the names of the informative patches;
+        # patches with more than 80% black pixels, i.e. patches resulting from
+        # the black margins around a Landsat 8 scene, are excluded
+        self.exclude = 'training_patches_95-cloud_nonempty.csv'
 
-        # list of all scenes in the root directory
-        # each scene is divided into tiles blocks
-        self.scenes = self.compose_scenes()
+        # initialize super class ImageDataset
+        super().__init__(root_dir, use_bands, tile_size, sort, transforms)
 
     # image size of the Cloud-95 dataset: (height, width)
     def get_size(self):
@@ -529,10 +431,14 @@ class Cloud95Dataset(ImageDataset):
         # normalize the data
         # here, we use the normalization of the authors of Cloud-95, i.e.
         # Mohajerani and Saeedi (2019, 2020)
-        x = torch.tensor(data / 65535, dtype=torch.float32)
-        y = torch.tensor(gt / 255, dtype=torch.uint8)
+        data /= 65535
+        gt /= 255
 
-        return x, y
+        return data, gt
+
+    # function that parses the date from a Landsat 8 scene id
+    def date_parser(self, scene):
+        return parse_landsat8_date(scene)
 
     def compose_scenes(self):
 
@@ -577,23 +483,29 @@ class Cloud95Dataset(ImageDataset):
 
         # iterate over the tiles
         for tile in range(self.tiles):
 
-            # initialize dictionary to store bands of current patch
-            scene = {}
+            # iterate over the transformations to apply
+            for transf in self.transforms:
+
+                # initialize dictionary to store bands of current patch
+                scene = {}
+
+                # iterate over the bands of interest
+                for band in band_dirs.keys():
+                    # save path to current band tif file to dictionary
+                    scene[band] = os.path.join(band_dirs[band],
+                                               file.replace(biter, band))
 
-            # iterate over the bands of interest
-            for band in band_dirs.keys():
-                # save path to current band tif file to dictionary
-                scene[band] = os.path.join(band_dirs[band],
-                                           file.replace(biter, band))
+                # store tile number
+                scene['tile'] = tile
 
-            # store tile number
-            scene['tile'] = tile
+                # store date
+                scene['date'] = date
 
-            # store date
-            scene['date'] = date
+                # store optional transformation
+                scene['transform'] = transf
 
-            # append patch to list of all patches
-            scenes.append(scene)
+                # append patch to list of all patches
+                scenes.append(scene)
 
         # sort list of scenes in chronological order
         if self.sort:
@@ -624,19 +536,22 @@ if __name__ == '__main__':
 
     # instantiate the Cloud-95 dataset
     # cloud_dataset = Cloud95Dataset(cloud_path,
     #                                tile_size=192,
-    #                                exclude=patches)
+    #                                use_bands=[],
+    #                                sort=False)
 
     # instantiate the SparcsDataset class
     sparcs_dataset = SparcsDataset(sparcs_path,
                                    tile_size=None,
                                    use_bands=['nir', 'red', 'green'],
-                                   sort=False)
+                                   sort=False,
+                                   transforms=[None])
 
     # instantiate the ProSnow datasets
     garmisch = ProSnowGarmisch(os.path.join(prosnow_path, 'Garmisch'),
                                tile_size=None,
                                use_bands=['nir', 'red', 'green'],
-                               sort=True)
+                               sort=True,
+                               transforms=[None])
     # obergurgl = ProSnowObergurgl(os.path.join(prosnow_path, 'Obergurgl'),
     #                              tile_size=None,
     #                              use_bands=['nir', 'red', 'green'],
--
GitLab
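For reference, a hypothetical end-to-end usage of the augmented pipeline (the dataset path is a placeholder; the classes and arguments are used exactly as defined in this patch):

from torch.utils.data import DataLoader

from pytorch.dataset import SparcsDataset
from pytorch.transforms import Augment, FlipLr, FlipUd, Noise

# one identity entry (None) plus one augmentation pipeline: every tile is
# returned twice, once as-is and once flipped and speckled
sparcs = SparcsDataset('/path/to/_Datasets/Sparcs',
                       use_bands=['red', 'green', 'blue', 'nir'],
                       tile_size=125,
                       sort=False,
                       transforms=[None,
                                   Augment([FlipLr(),
                                            FlipUd(),
                                            Noise(mode='speckle',
                                                  mean=0.1, var=0.05)])])

# __getitem__ returns (x, y) as torch tensors, ready for a DataLoader
dataloader = DataLoader(sparcs, batch_size=64, shuffle=True)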