Skip to content
Snippets Groups Projects
Commit 95912567 authored by Frisinghelli Daniel
Browse files

Added representation to dataset; improved kwargs handling

parent 9374ee3e
No related branches found
No related tags found
No related merge requests found
......@@ -19,14 +19,13 @@ import glob
import enum
import itertools
# externals
import numpy as np
import torch
from torch.utils.data import Dataset
# locals
from pysegcnn.core.constants import (Landsat8, Sentinel2, SparcsLabels,
from pysegcnn.core.constants import (Landsat8, Sentinel2, Label, SparcsLabels,
Cloud95Labels, ProSnowLabels)
from pysegcnn.core.utils import (img2np, is_divisible, tile_topleft_corner,
parse_landsat_scene, parse_sentinel2_scene)
......@@ -35,6 +34,35 @@ from pysegcnn.core.utils import (img2np, is_divisible, tile_topleft_corner,
# generic image dataset class
class ImageDataset(Dataset):
# allowed keyword arguments and default values
default_kwargs = {
# which bands to use, if use_bands=[], use all available bands
'use_bands': [],
# each scene is divided into (tile_size x tile_size) blocks
# each of these blocks is treated as a single sample
'tile_size': None,
# a pattern to match the ground truth file naming convention
'gt_pattern': '*gt.tif',
# whether to chronologically sort the samples
'sort': False,
# the transformations to apply to the original image
# artificially increases the training data size
'transforms': [],
# whether to pad the image to be evenly divisible in square tiles
# of size (tile_size x tile_size)
'pad': False,
# the value to pad the samples
'cval': 0,
}
def __init__(self, root_dir, **kwargs):
super().__init__()
......@@ -64,48 +92,15 @@ class ImageDataset(Dataset):
def _init_kwargs(self, **kwargs):
# define allowed keyword arguments
self.default_kwargs = {
# which bands to use, if use_bands=[], use all available bands
'use_bands': [],
# each scene is divided into (tile_size x tile_size) blocks
# each of these blocks is treated as a single sample
'tile_size': None,
# a pattern to match the ground truth file naming convention
'gt_pattern': '*gt.tif',
# check if the keyword arguments are correctly specified
if not set(self.default_kwargs.keys()).issubset(kwargs.keys()):
raise TypeError('Valid keyword arguments are: \n' +
'\n'.join('- {}'.format(k) for k in
self.default_kwargs.keys()))
# whether to chronologically sort the samples
'sort': False,
# the transformations to apply to the original image
# artificially increases the training data size
'transforms': [],
# whether to pad the image to be evenly divisible in square tiles
# of size (tile_size x tile_size)
'pad': False,
# the value to pad the samples
'cval': 0,
}
# set default kwargs
# update default arguments with specified keyword argument values
self.default_kwargs.update(kwargs)
for k, v in self.default_kwargs.items():
# store default keyword arguments as instance attributes
setattr(self, k, v)
# check whether the keyword arguments are correctly specified
for k, v in kwargs.items():
if k not in self.default_kwargs.keys():
raise TypeError('"{}" is not a valid keyword argument. '
'Valid keyword arguments are: \n'.format(k) +
'\n'.join('- {}'.format(k) for k in
self.default_kwargs.keys()))
# store keyword argument as instance attribute
setattr(self, k, v)
# check which bands to use
......@@ -152,8 +147,8 @@ class ImageDataset(Dataset):
self.labels[self.cval] = {'label': 'No data', 'color': 'black'}
def _build_labels(self):
    """Build the lookup table of class labels from ``self._label_class``.

    Returns
    -------
    dict
        One entry per member of ``self._label_class``, keyed by the
        member's ``id``. Each value is a dictionary holding the display
        name under ``'label'`` (the member name with underscores replaced
        by spaces) and the member's ``color`` under ``'color'``.
    """
    # members are read through their ``id``/``color`` attributes instead of
    # indexing into ``value``
    return {member.id: {'label': member.name.replace('_', ' '),
                        'color': member.color}
            for member in self._label_class}
def _assert_compose_scenes(self):
......@@ -186,17 +181,18 @@ class ImageDataset(Dataset):
'enum.Enum, containing an enumeration of the '
'spectral bands of the sensor the dataset is '
'derived from. Examples can be found in '
'pytorch.constants.py.'
'pysegcnn.core.constants.py.'
.format(self.__class__.__name__))
def _assert_get_labels(self):
    """Check that ``self._label_class`` is a class derived from ``Label``.

    Raises
    ------
    TypeError
        If ``self._label_class`` is not a class, or is a class that does
        not inherit from ``Label``.
    """
    label_class = self._label_class
    # issubclass() raises its own, uninformative TypeError when its first
    # argument is not a class; test for a class first so the explanatory
    # message below is raised in that case as well
    if not (isinstance(label_class, type) and
            issubclass(label_class, Label)):
        raise TypeError('{}.get_labels() should return an instance of '
                        'pysegcnn.core.constants.Label, '
                        'containing an enumeration of the '
                        'class labels, together with the corresponing id '
                        'in the ground truth mask and a color for '
                        'visualization. Examples can be found in '
                        'pysegcnn.core.constants.py.'
                        .format(self.__class__.__name__))
# the __len__() method returns the number of samples in the dataset
......@@ -212,12 +208,12 @@ class ImageDataset(Dataset):
# select a scene
scene = self.read_scene(idx)
# get samples: (tiles x channels x height x width)
# get samples
# data: (tiles, bands, height, width)
# gt: (height, width)
data, gt = self.build_samples(scene)
# preprocess input and return torch tensors of shape:
# x : (bands, height, width)
# y : (height, width)
# preprocess samples
x, y = self.preprocess(data, gt)
# optional transformation
......@@ -253,14 +249,13 @@ class ImageDataset(Dataset):
raise NotImplementedError('Inherit the ImageDataset class and '
'implement the method.')
# the get_sensor() method has to be implemented by the class inheriting
# the ImageDataset class
# get_sensor() should return an enum.Enum with the following
# (name: str, value: int) tuples:
# - (red, 2)
# - (green, 3)
# - ...
def get_sensor(self, *args, **kwargs):
    """Placeholder that classes inheriting ``ImageDataset`` must override.

    Raises
    ------
    NotImplementedError
        Always; this base implementation only signals that the inheriting
        class did not provide its own ``get_sensor``.
    """
    message = 'Inherit the ImageDataset class and implement the method.'
    raise NotImplementedError(message)
......@@ -335,6 +330,40 @@ class ImageDataset(Dataset):
return (torch.tensor(x.copy(), dtype=torch.float32),
torch.tensor(y.copy(), dtype=torch.uint8))
def __repr__(self):
    """Return a multi-line text summary of the dataset.

    The summary lists, in this order: the name of ``self.sensor``, the
    entries of ``self.use_bands``, the scene size, scene count and padding,
    the tiling information and the labels stored in ``self.labels``.
    """
    # separator placed between consecutive bullets of one section
    bullet_sep = '\n '

    # header: name of the concrete dataset class
    parts = [self.__class__.__name__ + '(\n']

    # sensor
    parts.append(' (sensor):\n - ' + self.sensor.__name__)

    # bands used for the segmentation
    band_items = ['- Band {}: {}'.format(pos, band)
                  for pos, band in enumerate(self.use_bands)]
    parts.append('\n (bands):\n ' + bullet_sep.join(band_items))

    # scenes: the scene count is the number of distinct 'id' values found
    # in self.scenes
    scene_ids = [entry['id'] for entry in self.scenes]
    scene_items = [
        '- size (h, w): {}'.format((self.height, self.width)),
        '- number of scenes: {}'.format(len(np.unique(scene_ids))),
        '- padding (bottom, left, top, right): {}'.format(self.padding),
    ]
    parts.append('\n (scene):\n ' + bullet_sep.join(scene_items))

    # tiles: the tile count is the number of entries in self.scenes
    tile_items = [
        '- number of tiles per scene: {}'.format(self.tiles),
        '- tile size: {}'.format((self.tile_size, self.tile_size)),
        '- number of tiles: {}'.format(len(self.scenes)),
    ]
    parts.append('\n (tiles):\n ' + bullet_sep.join(tile_items))

    # classes of interest
    class_items = ['- Class {}: {}'.format(key, props['label'])
                   for key, props in self.labels.items()]
    parts.append('\n (classes):\n ' + bullet_sep.join(class_items))

    parts.append('\n)')
    return ''.join(parts)
class StandardEoDataset(ImageDataset):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment