From 39438a4d4bd5a5e2e11f2b1da261ba94752a9982 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Tue, 14 Jul 2020 17:23:39 +0200
Subject: [PATCH] Added an option to chronologically sort a dataset

---
 pytorch/dataset.py | 80 ++++++++++++++++++++++++++++++++++++----------
 1 file changed, 63 insertions(+), 17 deletions(-)

diff --git a/pytorch/dataset.py b/pytorch/dataset.py
index 3218360..36a6456 100644
--- a/pytorch/dataset.py
+++ b/pytorch/dataset.py
@@ -33,6 +33,7 @@ sys.path.append('..')
 from pytorch.constants import (Landsat8, Sentinel2, SparcsLabels,
                                Cloud95Labels, ProSnowLabels)
 from pytorch.graphics import plot_sample
+from pytorch.utils import parse_landsat8_date, parse_sentinel2_date
 
 # generic image dataset class
 class ImageDataset(Dataset):
@@ -309,10 +310,18 @@ class ImageDataset(Dataset):
 
 class StandardEoDataset(ImageDataset):
 
-    def __init__(self, root_dir, use_bands, tile_size):
+    def __init__(self, root_dir, use_bands, tile_size, sort):
         # initialize super class ImageDataset
         super().__init__(root_dir, use_bands, tile_size)
 
+        # function that parses the date from a Landsat 8 scene id
+        self.date_parser = None
+
+        # whether to sort the list of samples:
+        # for time series data, set sort=True to obtain the scenes in
+        # chronological order
+        self.sort = sort
+
     # returns the band number of a Landsat8 or Sentinel2 tif file
     # x: path to a tif file
     def get_band_number(self, path):
@@ -356,18 +365,22 @@ class StandardEoDataset(ImageDataset):
     # to the tif files of each scene
     # if the scenes are divided into tiles, each tile has its own entry
     # with corresponding tile id
-    def compose_scenes(self, pattern='*mask.png'):
+    def compose_scenes(self):
 
         # list of all samples in the dataset
         scenes = []
         for scene in os.listdir(self.root):
 
+            # get the date of the current scene
+            date = self.date_parser(scene)
+
             # list the spectral bands of the scene
             bands = glob.glob(os.path.join(self.root, scene, '*B*.tif'))
 
             # get the ground truth mask
             try:
-                gt = glob.glob(os.path.join(self.root, scene, pattern)).pop()
+                gt = glob.glob(
+                    os.path.join(self.root, scene, '*mask.png')).pop()
             except IndexError:
                 gt = None
 
@@ -380,18 +393,28 @@ class StandardEoDataset(ImageDataset):
                 # store tile number
                 data['tile'] = tile
 
+                # store date
+                data['date'] = date
+
                 # append to list
                 scenes.append(data)
 
+        # sort list of scenes in chronological order
+        if self.sort:
+            scenes.sort(key=lambda k: k['date'])
+
         return scenes
 
 # SparcsDataset class: inherits from the generic ImageDataset class
 class SparcsDataset(StandardEoDataset):
 
     def __init__(self, root_dir, use_bands=['red', 'green', 'blue'],
-                 tile_size=None):
+                 tile_size=None, sort=False):
         # initialize super class ImageDataset
-        super().__init__(root_dir, use_bands, tile_size)
+        super().__init__(root_dir, use_bands, tile_size, sort)
+
+        # function that parses the date from a Landsat 8 scene id
+        self.date_parser = parse_landsat8_date
 
         # list of all scenes in the root directory
         # each scene is divided into tiles blocks
@@ -424,8 +447,11 @@ class SparcsDataset(StandardEoDataset):
 
 class ProSnowDataset(StandardEoDataset):
 
-    def __init__(self, root_dir, use_bands, tile_size):
-        super().__init__(root_dir, use_bands, tile_size)
+    def __init__(self, root_dir, use_bands, tile_size, sort):
+        super().__init__(root_dir, use_bands, tile_size, sort)
+
+        # function that parses the date from a Sentinel 2 scene id
+        self.date_parser = parse_sentinel2_date
 
         # list of samples in the dataset
         self.scenes = self.compose_scenes()
@@ -452,13 +478,22 @@ class ProSnowDataset(StandardEoDataset):
 
 class ProSnowGarmisch(ProSnowDataset):
 
-    def __init__(self, root_dir, use_bands=[], tile_size=None):
-        super().__init__(root_dir, use_bands, tile_size)
+    def __init__(self, root_dir, use_bands=[], tile_size=None, sort=True):
+        super().__init__(root_dir, use_bands, tile_size, sort)
 
     def get_size(self):
         return (615, 543)
 
 
+class ProSnowObergurgl(ProSnowDataset):
+
+    def __init__(self, root_dir, use_bands=[], tile_size=None, sort=True):
+        super().__init__(root_dir, use_bands, tile_size, sort)
+
+    def get_size(self):
+        return (310, 270)
+
+
 class Cloud95Dataset(ImageDataset):
 
     def __init__(self, root_dir, use_bands=[], tile_size=None, exclude=None):
@@ -570,17 +605,28 @@ if __name__ == '__main__':
     cloud_path = os.path.join(wd, '_Datasets/Cloud95/Training')
 
     # path to the ProSnow dataset
-    prosnow_path = os.path.join(wd, '_Datasets/ProSnow/Garmisch')
+    prosnow_path = os.path.join(wd, '_Datasets/ProSnow/')
 
     # the csv file containing the names of the informative patches
-    patches = 'training_patches_95-cloud_nonempty.csv'
+    # patches = 'training_patches_95-cloud_nonempty.csv'
 
     # instanciate the Cloud-95 dataset
-    cloud_dataset = Cloud95Dataset(cloud_path, tile_size=192, exclude=patches)
+    # cloud_dataset = Cloud95Dataset(cloud_path,
+    #                                tile_size=192,
+    #                                exclude=patches)
 
     # instanciate the SparcsDataset class
-    sparcs_dataset = SparcsDataset(sparcs_path, tile_size=None,
-                                   use_bands=['nir', 'red', 'green'])
-
-    # instanciate the ProSnow class
-    prosnow_dataset = ProSnowGarmisch(prosnow_path)
+    sparcs_dataset = SparcsDataset(sparcs_path,
+                                   tile_size=None,
+                                   use_bands=['nir', 'red', 'green'],
+                                   sort=False)
+
+    # instanciate the ProSnow datasets
+    garmisch = ProSnowGarmisch(os.path.join(prosnow_path, 'Garmisch'),
+                               tile_size=None,
+                               use_bands=['nir', 'red', 'green'],
+                               sort=True)
+    obergurgl = ProSnowObergurgl(os.path.join(prosnow_path, 'Obergurgl'),
+                                 tile_size=None,
+                                 use_bands=['nir', 'red', 'green'],
+                                 sort=True)
-- 
GitLab