pysegcnn.core.dataset.ImageDataset

class pysegcnn.core.dataset.ImageDataset(root_dir, use_bands=[], tile_size=None, pad=False, gt_pattern='(.*)gt.tif', sort=False, seed=0, transforms=[])

Base class for multispectral image data.

Inherits from torch.utils.data.Dataset to comply with the PyTorch dataset interface. This allows the dataset to be passed directly to torch.utils.data.DataLoader during model training.

__init__(root_dir, use_bands=[], tile_size=None, pad=False, gt_pattern='(.*)gt.tif', sort=False, seed=0, transforms=[])

Initialize.

Parameters
root_dir : str

The root directory, i.e. the path to the dataset.

use_bands : list [str], optional

A list of the spectral bands to use. The default is [].

tile_size : int or None, optional

The size of the tiles. If not None, each scene is divided into square tiles of shape (tile_size, tile_size). The default is None.
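The tiling can be pictured as a regular grid laid over the scene. The sketch below computes the top-left corner of each tile, assuming row-major tile order and a scene whose height and width are evenly divisible by tile_size. The helper tile_corners is hypothetical and not part of pysegcnn.

```python
def tile_corners(height, width, tile_size):
    """Return the (row, col) top-left corner of each square tile.

    Hypothetical helper: assumes row-major order and that both
    dimensions are evenly divisible by tile_size.
    """
    if height % tile_size or width % tile_size:
        raise ValueError("scene size must be divisible by tile_size; "
                         "consider pad=True")
    return [(row, col)
            for row in range(0, height, tile_size)
            for col in range(0, width, tile_size)]


# A 512 x 768 scene with tile_size=256 yields 2 x 3 = 6 tiles.
corners = tile_corners(512, 768, 256)
print(len(corners))  # 6
print(corners[:3])   # [(0, 0), (0, 256), (0, 512)]
```

In general, a divisible scene of shape (H, W) yields (H // tile_size) * (W // tile_size) tiles.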

pad : bool, optional

Whether to center-pad the input images. Set pad=True if the images are not evenly divisible by tile_size. The image data are padded with a constant value of zero, and the corresponding ground truth image is padded with a “no data” label. The default is False.
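Center padding splits the missing pixels between both sides of each axis. The following is a minimal sketch, not pysegcnn's implementation: it assumes any odd remainder goes to the bottom/right side, uses nested lists instead of arrays to stay self-contained, and assumes 255 as the “no data” label. The helpers center_pad_width and pad_2d are hypothetical.

```python
def center_pad_width(size, tile_size):
    """Return (before, after) padding so that size becomes divisible by tile_size."""
    missing = -size % tile_size          # 0 if size is already divisible
    before = missing // 2
    return before, missing - before      # odd remainder goes after (assumption)


def pad_2d(image, tile_size, value):
    """Center pad a 2-D nested list with a constant value."""
    top, bottom = center_pad_width(len(image), tile_size)
    left, right = center_pad_width(len(image[0]), tile_size)
    width = left + len(image[0]) + right
    padded = [[value] * width for _ in range(top)]
    padded += [[value] * left + list(row) + [value] * right for row in image]
    padded += [[value] * width for _ in range(bottom)]
    return padded


NODATA = 255  # assumed "no data" label for the ground truth

image = [[1, 2, 3] for _ in range(3)]        # 3 x 3 band
gt = [[0, 1, 1] for _ in range(3)]           # matching 3 x 3 label map
padded_image = pad_2d(image, 4, value=0)     # image data: constant zero
padded_gt = pad_2d(gt, 4, value=NODATA)      # ground truth: no-data label
print(len(padded_image), len(padded_image[0]))  # 4 4
```

Padding the ground truth with a dedicated label, rather than with a real class, lets the padded border be excluded from the loss and the accuracy metrics.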

gt_pattern : str, optional

A regular expression matching the ground truth naming convention. All directories and subdirectories of root_dir are searched for files matching gt_pattern. The default is (.*)gt\.tif.
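The recursive search can be sketched with os.walk and re. This sketch assumes the pattern is applied with re.search to file names only and that the first capture group is the scene prefix shared with the band files; pysegcnn's actual matching logic may differ, and find_ground_truth is a hypothetical name.

```python
import os
import re
import tempfile


def find_ground_truth(root_dir, gt_pattern=r'(.*)gt\.tif'):
    """Map scene prefix -> ground truth path for every matching file."""
    regex = re.compile(gt_pattern)
    matches = {}
    for dirpath, _, filenames in os.walk(root_dir):
        for name in sorted(filenames):
            match = regex.search(name)
            if match:
                # The first group is the prefix shared with the band files.
                matches[match.group(1)] = os.path.join(dirpath, name)
    return matches


# Build a tiny directory tree: one band file and one ground truth file.
with tempfile.TemporaryDirectory() as root:
    scene_dir = os.path.join(root, 'scene_01')
    os.makedirs(scene_dir)
    for name in ('scene_01_B2.tif', 'scene_01_gt.tif'):
        open(os.path.join(scene_dir, name), 'w').close()
    found = find_ground_truth(root)

print(list(found))  # ['scene_01_']
```

Escaping the dot (\.) makes it match a literal period rather than any character.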

sort : bool, optional

Whether to chronologically sort the samples. Useful for time series data. The default is False.
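Chronological sorting requires an acquisition date per sample. A minimal sketch, assuming each sample is a dict with a hypothetical 'date' key in YYYYMMDD format (in pysegcnn the date would presumably be derived from the scene identifier via parse_scene_id):

```python
from datetime import datetime

samples = [
    {'id': 'scene_b', 'date': '20180710'},
    {'id': 'scene_a', 'date': '20170402'},
    {'id': 'scene_c', 'date': '20181125'},
]

# sorted() is stable, so samples with the same date keep their relative order.
ordered = sorted(samples, key=lambda s: datetime.strptime(s['date'], '%Y%m%d'))
print([s['id'] for s in ordered])  # ['scene_a', 'scene_b', 'scene_c']
```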

seed : int, optional

The random seed used to split the dataset into training, validation and test sets. Fixing it makes the split reproducible. The default is 0.
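The effect of the seed can be illustrated with a seeded shuffle. This is a sketch of a plain random split, assuming 0.8/0.1/0.1 fractions and Python's random module; random_split is hypothetical, and pysegcnn's own split functions may work differently (e.g. splitting by scene or by date rather than by sample).

```python
import random


def random_split(n_samples, seed=0, train=0.8, valid=0.1):
    """Split sample indices into train/validation/test lists reproducibly."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)   # local RNG: global state untouched
    n_train = int(train * n_samples)
    n_valid = int(valid * n_samples)
    return (indices[:n_train],
            indices[n_train:n_train + n_valid],
            indices[n_train + n_valid:])


train_a, valid_a, test_a = random_split(100, seed=0)
train_b, valid_b, test_b = random_split(100, seed=0)
print(train_a == train_b)                        # True: same seed, same split
print(len(train_a), len(valid_a), len(test_a))   # 80 10 10
```

Using a dedicated random.Random instance keeps the split independent of any other code that draws from the global random state.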

transforms : list, optional

A list of pysegcnn.core.transforms.Augment instances. Each item in transforms generates a distinct transformed version of the dataset. The full dataset consists of the original, untransformed dataset plus one transformed version per item, so n transforms yield (n + 1) times as many samples as the original dataset. If transforms=[], only the original dataset is used. The default is [].
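Because each transform adds one full copy of the data, a flat index can be mapped onto a (version, sample) pair with divmod. The sketch below assumes the untransformed copy comes first; AugmentedView is hypothetical, and the plain functions flip and negate stand in for pysegcnn.core.transforms.Augment instances.

```python
class AugmentedView:
    """Concatenate the original samples with one transformed copy per transform."""

    def __init__(self, samples, transforms=()):
        self.samples = list(samples)
        self.transforms = list(transforms)

    def __len__(self):
        return len(self.samples) * (len(self.transforms) + 1)

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        version, sample_idx = divmod(idx, len(self.samples))
        sample = self.samples[sample_idx]
        if version == 0:
            return sample                            # original, untransformed
        return self.transforms[version - 1](sample)  # transformed copy


def flip(row):
    """Stand-in for a flip augmentation."""
    return row[::-1]


def negate(row):
    """Stand-in for an intensity augmentation."""
    return [-v for v in row]


view = AugmentedView([[1, 2], [3, 4]], transforms=[flip, negate])
print(len(view))   # 6
print(view[2])     # [2, 1]
print(view[5])     # [-3, -4]
```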

Methods

__init__(root_dir[, use_bands, tile_size, …])

Initialize.

build_samples(scene)

Stack the bands of a sample in a single array.

compose_scenes()

Build the list of samples of the dataset.

get_labels()

Return an enumeration of the class labels of the dataset.

get_sensor()

Return an enumeration of the bands of the sensor of the dataset.

get_size()

Return the size of the images in the dataset.

parse_scene_id(scene_id)

Parse the scene identifier.

preprocess(data, gt)

Preprocess a sample before feeding it to a model.

read_scene(idx)

Read the data of the sample with index idx.

to_tensor(x, dtype)

Convert x to torch.Tensor.
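get_sensor and get_labels return enumerations. The sketch below shows what such enumerations might look like and how the band names passed via use_bands could be resolved against them. The Landsat8 enum lists the real OLI band numbers for the blue, green, red and near-infrared bands, but the class itself, the Labels values and band_numbers are hypothetical, and treating an empty use_bands as "all bands" is an assumption.

```python
import enum


class Landsat8(enum.Enum):
    """Subset of Landsat 8 OLI bands: name -> band number."""
    blue = 2
    green = 3
    red = 4
    nir = 5


class Labels(enum.Enum):
    """Hypothetical class labels: name -> integer value in the ground truth."""
    clear = 0
    cloud = 1
    no_data = 255


def band_numbers(sensor, use_bands):
    """Resolve band names to band numbers; an empty list selects all bands (assumption)."""
    if not use_bands:
        return [band.value for band in sensor]
    return [sensor[name].value for name in use_bands]  # KeyError if unknown


print(band_numbers(Landsat8, ['red', 'green', 'blue']))  # [4, 3, 2]
print(band_numbers(Landsat8, []))                        # [2, 3, 4, 5]
print([label.name for label in Labels])                  # ['clear', 'cloud', 'no_data']
```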