From 959125677f82ecf71f8d76556dcdf350be1b86b0 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Mon, 10 Aug 2020 16:56:22 +0200 Subject: [PATCH] Added representation to dataset; improved kwargs handling --- pysegcnn/core/dataset.py | 145 +++++++++++++++++++++++---------------- 1 file changed, 87 insertions(+), 58 deletions(-) diff --git a/pysegcnn/core/dataset.py b/pysegcnn/core/dataset.py index 9b2704c..a3371c0 100644 --- a/pysegcnn/core/dataset.py +++ b/pysegcnn/core/dataset.py @@ -19,14 +19,13 @@ import glob import enum import itertools - # externals import numpy as np import torch from torch.utils.data import Dataset # locals -from pysegcnn.core.constants import (Landsat8, Sentinel2, SparcsLabels, +from pysegcnn.core.constants import (Landsat8, Sentinel2, Label, SparcsLabels, Cloud95Labels, ProSnowLabels) from pysegcnn.core.utils import (img2np, is_divisible, tile_topleft_corner, parse_landsat_scene, parse_sentinel2_scene) @@ -35,6 +34,35 @@ from pysegcnn.core.utils import (img2np, is_divisible, tile_topleft_corner, # generic image dataset class class ImageDataset(Dataset): + # allowed keyword arguments and default values + default_kwargs = { + + # which bands to use, if use_bands=[], use all available bands + 'use_bands': [], + + # each scene is divided into (tile_size x tile_size) blocks + # each of these blocks is treated as a single sample + 'tile_size': None, + + # a pattern to match the ground truth file naming convention + 'gt_pattern': '*gt.tif', + + # whether to chronologically sort the samples + 'sort': False, + + # the transformations to apply to the original image + # artificially increases the training data size + 'transforms': [], + + # whether to pad the image to be evenly divisible in square tiles + # of size (tile_size x tile_size) + 'pad': False, + + # the value to pad the samples + 'cval': 0, + + } + def __init__(self, root_dir, **kwargs): super().__init__() @@ -64,48 +92,15 @@ class ImageDataset(Dataset): def _init_kwargs(self, **kwargs): - # define allowed keyword arguments - self.default_kwargs = { - # which bands to use, if use_bands=[], use all available bands - 'use_bands': [], - - # each scene is divided into (tile_size x tile_size) blocks - # each of these blocks is treated as a single sample - 'tile_size': None, - - # a pattern to match the ground truth file naming convention - 'gt_pattern': '*gt.tif', + # check if the keyword arguments are correctly specified + if not set(self.default_kwargs.keys()).issubset(kwargs.keys()): + raise TypeError('Valid keyword arguments are: \n' + + '\n'.join('- {}'.format(k) for k in + self.default_kwargs.keys())) - # whether to chronologically sort the samples - 'sort': False, - - # the transformations to apply to the original image - # artificially increases the training data size - 'transforms': [], - - # whether to pad the image to be evenly divisible in square tiles - # of size (tile_size x tile_size) - 'pad': False, - - # the value to pad the samples - 'cval': 0, - - } - - # set default kwargs + # update default arguments with specified keyword argument values + self.default_kwargs.update(kwargs) for k, v in self.default_kwargs.items(): - # store default keyword arguments as instance attributes - setattr(self, k, v) - - # check whether the keyword arguments are correctly specified - for k, v in kwargs.items(): - if k not in self.default_kwargs.keys(): - raise TypeError('"{}" is not a valid keyword argument. ' - 'Valid keyword arguments are: \n'.format(k) + - '\n'.join('- {}'.format(k) for k in - self.default_kwargs.keys())) - - # store keyword argument as instance attribute setattr(self, k, v) # check which bands to use @@ -152,8 +147,8 @@ class ImageDataset(Dataset): self.labels[self.cval] = {'label': 'No data', 'color': 'black'} def _build_labels(self): - return {band.value[0]: {'label': band.name.replace('_', ' '), - 'color': band.value[1]} + return {band.id: {'label': band.name.replace('_', ' '), + 'color': band.color} for band in self._label_class} def _assert_compose_scenes(self): @@ -186,17 +181,18 @@ class ImageDataset(Dataset): 'enum.Enum, containing an enumeration of the ' 'spectral bands of the sensor the dataset is ' 'derived from. Examples can be found in ' - 'pytorch.constants.py.' + 'pysegcnn.core.constants.py.' .format(self.__class__.__name__)) def _assert_get_labels(self): - if not isinstance(self._label_class, enum.EnumMeta): + if not issubclass(self._label_class, Label): raise TypeError('{}.get_labels() should return an instance of ' - 'enum.Enum, containing an enumeration of the ' + 'pysegcnn.core.constants.Label, ' + 'containing an enumeration of the ' 'class labels, together with the corresponing id ' 'in the ground truth mask and a color for ' 'visualization. Examples can be found in ' - 'pytorch.constants.py.' + 'pysegcnn.core.constants.py.' .format(self.__class__.__name__)) # the __len__() method returns the number of samples in the dataset @@ -212,12 +208,12 @@ class ImageDataset(Dataset): # select a scene scene = self.read_scene(idx) - # get samples: (tiles x channels x height x width) + # get samples + # data: (tiles, bands, height, width) + # gt: (height, width) data, gt = self.build_samples(scene) - # preprocess input and return torch tensors of shape: - # x : (bands, height, width) - # y : (height, width) + # preprocess samples x, y = self.preprocess(data, gt) # optional transformation @@ -253,14 +249,13 @@ class ImageDataset(Dataset): raise NotImplementedError('Inherit the ImageDataset class and ' 'implement the method.') - # the get_bands() method has to be implemented by the class inheriting + # the get_sensor() method has to be implemented by the class inheriting # the ImageDataset class - # get_bands() should return a dictionary with the following - # (key: int, value: str) pairs: - # - (1, band_1_name) - # - (2, band_2_name) + # get_sensor() should return an enum.Enum with the following + # (name: str, value: int) tuples: + # - (red, 2) + # - (green, 3) # - ... - # - (n, band_n_name) def get_sensor(self, *args, **kwargs): raise NotImplementedError('Inherit the ImageDataset class and ' 'implement the method.') @@ -335,6 +330,40 @@ class ImageDataset(Dataset): return (torch.tensor(x.copy(), dtype=torch.float32), torch.tensor(y.copy(), dtype=torch.uint8)) + def __repr__(self): + + # representation string to print + fs = self.__class__.__name__ + '(\n' + + # sensor + fs += ' (sensor):\n - ' + self.sensor.__name__ + + # bands used for the segmentation + fs += '\n (bands):\n ' + fs += '\n '.join('- Band {}: {}'.format(i, b) for i, b in + enumerate(self.use_bands)) + + # scenes + fs += '\n (scene):\n ' + fs += '- size (h, w): {}\n '.format((self.height, self.width)) + fs += '- number of scenes: {}\n '.format( + len(np.unique([f['id'] for f in self.scenes]))) + fs += '- padding (bottom, left, top, right): {}'.format(self.padding) + + # tiles + fs += '\n (tiles):\n ' + fs += '- number of tiles per scene: {}\n '.format(self.tiles) + fs += '- tile size: {}\n '.format((self.tile_size, + self.tile_size)) + fs += '- number of tiles: {}'.format(len(self.scenes)) + + # classes of interest + fs += '\n (classes):\n ' + fs += '\n '.join('- Class {}: {}'.format(k, v['label']) for + k, v in self.labels.items()) + fs += '\n)' + return fs + class StandardEoDataset(ImageDataset): -- GitLab