From 39438a4d4bd5a5e2e11f2b1da261ba94752a9982 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Tue, 14 Jul 2020 17:23:39 +0200 Subject: [PATCH] Added an option to chronologically sort a dataset --- pytorch/dataset.py | 80 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 63 insertions(+), 17 deletions(-) diff --git a/pytorch/dataset.py b/pytorch/dataset.py index 3218360..36a6456 100644 --- a/pytorch/dataset.py +++ b/pytorch/dataset.py @@ -33,6 +33,7 @@ sys.path.append('..') from pytorch.constants import (Landsat8, Sentinel2, SparcsLabels, Cloud95Labels, ProSnowLabels) from pytorch.graphics import plot_sample +from pytorch.utils import parse_landsat8_date, parse_sentinel2_date # generic image dataset class class ImageDataset(Dataset): @@ -309,10 +310,18 @@ class ImageDataset(Dataset): class StandardEoDataset(ImageDataset): - def __init__(self, root_dir, use_bands, tile_size): + def __init__(self, root_dir, use_bands, tile_size, sort): # initialize super class ImageDataset super().__init__(root_dir, use_bands, tile_size) + # function that parses the date from a Landsat 8 scene id + self.date_parser = None + + # whether to sort the list of samples: + # for time series data, set sort=True to obtain the scenes in + # chronological order + self.sort = sort + # returns the band number of a Landsat8 or Sentinel2 tif file # x: path to a tif file def get_band_number(self, path): @@ -356,18 +365,22 @@ class StandardEoDataset(ImageDataset): # to the tif files of each scene # if the scenes are divided into tiles, each tile has its own entry # with corresponding tile id - def compose_scenes(self, pattern='*mask.png'): + def compose_scenes(self): # list of all samples in the dataset scenes = [] for scene in os.listdir(self.root): + # get the date of the current scene + date = self.date_parser(scene) + # list the spectral bands of the scene bands = glob.glob(os.path.join(self.root, scene, '*B*.tif')) # get the ground truth mask try: - gt = glob.glob(os.path.join(self.root, scene, pattern)).pop() + gt = glob.glob( + os.path.join(self.root, scene, '*mask.png')).pop() except IndexError: gt = None @@ -380,18 +393,28 @@ class StandardEoDataset(ImageDataset): # store tile number data['tile'] = tile + # store date + data['date'] = date + # append to list scenes.append(data) + # sort list of scenes in chronological order + if self.sort: + scenes.sort(key=lambda k: k['date']) + return scenes # SparcsDataset class: inherits from the generic ImageDataset class class SparcsDataset(StandardEoDataset): def __init__(self, root_dir, use_bands=['red', 'green', 'blue'], - tile_size=None): + tile_size=None, sort=False): # initialize super class ImageDataset - super().__init__(root_dir, use_bands, tile_size) + super().__init__(root_dir, use_bands, tile_size, sort) + + # function that parses the date from a Landsat 8 scene id + self.date_parser = parse_landsat8_date # list of all scenes in the root directory # each scene is divided into tiles blocks @@ -424,8 +447,11 @@ class SparcsDataset(StandardEoDataset): class ProSnowDataset(StandardEoDataset): - def __init__(self, root_dir, use_bands, tile_size): - super().__init__(root_dir, use_bands, tile_size) + def __init__(self, root_dir, use_bands, tile_size, sort): + super().__init__(root_dir, use_bands, tile_size, sort) + + # function that parses the date from a Sentinel 2 scene id + self.date_parser = parse_sentinel2_date # list of samples in the dataset self.scenes = self.compose_scenes() @@ -452,13 +478,22 @@ class ProSnowDataset(StandardEoDataset): class ProSnowGarmisch(ProSnowDataset): - def __init__(self, root_dir, use_bands=[], tile_size=None): - super().__init__(root_dir, use_bands, tile_size) + def __init__(self, root_dir, use_bands=[], tile_size=None, sort=True): + super().__init__(root_dir, use_bands, tile_size, sort) def get_size(self): return (615, 543) +class ProSnowObergurgl(ProSnowDataset): + + def __init__(self, root_dir, use_bands=[], tile_size=None, sort=True): + super().__init__(root_dir, use_bands, tile_size, sort) + + def get_size(self): + return (310, 270) + + class Cloud95Dataset(ImageDataset): def __init__(self, root_dir, use_bands=[], tile_size=None, exclude=None): @@ -570,17 +605,28 @@ if __name__ == '__main__': cloud_path = os.path.join(wd, '_Datasets/Cloud95/Training') # path to the ProSnow dataset - prosnow_path = os.path.join(wd, '_Datasets/ProSnow/Garmisch') + prosnow_path = os.path.join(wd, '_Datasets/ProSnow/') # the csv file containing the names of the informative patches - patches = 'training_patches_95-cloud_nonempty.csv' + # patches = 'training_patches_95-cloud_nonempty.csv' # instanciate the Cloud-95 dataset - cloud_dataset = Cloud95Dataset(cloud_path, tile_size=192, exclude=patches) + # cloud_dataset = Cloud95Dataset(cloud_path, + # tile_size=192, + # exclude=patches) # instanciate the SparcsDataset class - sparcs_dataset = SparcsDataset(sparcs_path, tile_size=None, - use_bands=['nir', 'red', 'green']) - - # instanciate the ProSnow class - prosnow_dataset = ProSnowGarmisch(prosnow_path) + sparcs_dataset = SparcsDataset(sparcs_path, + tile_size=None, + use_bands=['nir', 'red', 'green'], + sort=False) + + # instanciate the ProSnow datasets + garmisch = ProSnowGarmisch(os.path.join(prosnow_path, 'Garmisch'), + tile_size=None, + use_bands=['nir', 'red', 'green'], + sort=True) + obergurgl = ProSnowObergurgl(os.path.join(prosnow_path, 'Obergurgl'), + tile_size=None, + use_bands=['nir', 'red', 'green'], + sort=True) -- GitLab