From 725b420e4cd7cd29430b3acaabc82b7e2b9d1b1b Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Wed, 19 Aug 2020 17:27:52 +0200 Subject: [PATCH] Adding docstrings: part 3 --- pysegcnn/core/dataset.py | 1140 +++++++++++++++++++++++++++++-------- pysegcnn/core/graphics.py | 12 +- pysegcnn/core/layers.py | 32 +- pysegcnn/core/models.py | 2 +- pysegcnn/core/split.py | 6 +- pysegcnn/core/trainer.py | 396 ++++++++++++- 6 files changed, 1320 insertions(+), 268 deletions(-) diff --git a/pysegcnn/core/dataset.py b/pysegcnn/core/dataset.py index 98a2428..d09b713 100644 --- a/pysegcnn/core/dataset.py +++ b/pysegcnn/core/dataset.py @@ -35,45 +35,76 @@ from pysegcnn.core.utils import (img2np, is_divisible, tile_topleft_corner, LOGGER = logging.getLogger(__name__) -# generic image dataset class class ImageDataset(Dataset): - - # allowed keyword arguments and default values - default_kwargs = { - - # which bands to use, if use_bands=[], use all available bands - 'use_bands': [], - - # each scene is divided into (tile_size x tile_size) blocks - # each of these blocks is treated as a single sample - 'tile_size': None, - - # a pattern to match the ground truth file naming convention - 'gt_pattern': '*gt.tif', - - # whether to chronologically sort the samples - 'sort': False, - - # the random seed used to split dataset into training, validation and - # test data - 'seed': 0, - - # the transformations to apply to the original image - # artificially increases the training data size - 'transforms': [], - - # whether to pad the image to be evenly divisible in square tiles - # of size (tile_size x tile_size) - 'pad': False - - } - - def __init__(self, root_dir, **kwargs): + """Base class for multispectral image data. + + Inheriting from `torch.utils.data.Dataset` to be compliant to the PyTorch + standard. 
Furthermore, using instances of `torch.utils.data.Dataset` + enables the use of the handy `torch.utils.data.DataLoader` class during + model training. + + Parameters + ---------- + root_dir : `str` + The root directory, path to the dataset. + use_bands : `list` [`str`], optional + A list of the spectral bands to use. The default is []. + tile_size : `int` or `None`, optional + The size of the tiles. If not `None`, each scene is divided into square + tiles of shape (tile_size, tile_size). The default is None. + pad : `bool`, optional + Whether to center pad the input image. Set ``pad`` = True, if the + images are not evenly divisible by the ``tile_size``. The image data is + padded with a constant padding value of zero. For each image, the + corresponding ground truth image is padded with a "no data" label. + The default is False. + gt_pattern : `str`, optional + A pattern to match the ground truth naming convention. All directories + and subdirectories in ``root_dir`` are searched for files matching + ``gt_pattern``. The default is '*gt.tif'. + sort : `bool`, optional + Whether to chronologically sort the samples. Useful for time series + data. The default is False. + seed : `int`, optional + The random seed. Used to split the dataset into training, validation + and test set. Useful for reproducibility. The default is 0. + transforms : `list` [`pysegcnn.core.split.Augment`], optional + List of `pysegcnn.core.split.Augment` instances. Each item in + ``transforms`` generates a distinct transformed version of the dataset. + The total dataset is composed of the original untransformed dataset + together with each transformed version of it. + If ``transforms`` = [], only the original dataset is used. + The default is []. + + Returns + ------- + None. 
+ + """ + + def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, + gt_pattern='*gt.tif', sort=False, seed=0, transforms=[]): super().__init__() - # the root directory: path to the image dataset + # dataset configuration self.root = root_dir + self.use_bands = use_bands + self.tile_size = tile_size + self.pad = pad + self.gt_pattern = gt_pattern + self.sort = sort + self.seed = seed + self.transforms = transforms + + # initialize instance attributes + self._init_attributes() + # the samples of the dataset + self.scenes = self.compose_scenes() + self._assert_compose_scenes() + + def _init_attributes(self): + """Initialize the class instance attributes.""" # the size of a scene/patch in the dataset self.size = self.get_size() self._assert_get_size() @@ -88,26 +119,6 @@ class ImageDataset(Dataset): self._assert_get_labels() self.labels = self._build_labels() - # initialize keyword arguments - self._init_kwargs(**kwargs) - - # the samples of the dataset - self.scenes = self.compose_scenes() - self._assert_compose_scenes() - - def _init_kwargs(self, **kwargs): - - # check if the keyword arguments are correctly specified - if not set(kwargs.keys()).issubset(self.default_kwargs.keys()): - raise TypeError('Valid keyword arguments are: \n' + - '\n'.join('- {}'.format(k) for k in - self.default_kwargs.keys())) - - # update default arguments with specified keyword argument values - self.default_kwargs.update(kwargs) - for k, v in self.default_kwargs.items(): - setattr(self, k, v) - # check which bands to use self.use_bands = (self.use_bands if self.use_bands else [*self.bands.values()]) @@ -140,12 +151,25 @@ class ImageDataset(Dataset): .format(self.cval)) def _build_labels(self): + """Build the label dictionary. + + Returns + ------- + labels : `dict` [`int`, `dict`] + The label dictionary. The keys are the values of the class labels + in the ground truth ``y``. Each nested `dict` should have keys: + ``'color'`` + A named color (`str`). 
+ ``'label'`` + The name of the class label (`str`). + + """ return {band.id: {'label': band.name.replace('_', ' '), 'color': band.color} for band in self._label_class} def _assert_compose_scenes(self): - + """Check whether compose_scenes() is correctly implemented.""" # list of required keys self.keys = self.use_bands + ['gt', 'date', 'tile', 'transform', 'id'] @@ -163,12 +187,14 @@ class ImageDataset(Dataset): .format(self.keys)) def _assert_get_size(self): + """Check whether get_size() is correctly implemented.""" if not isinstance(self.size, tuple) and len(self.size) == 2: raise TypeError('{}.get_size() should return the spatial size of ' 'an image sample as tuple, i.e. (height, width).' .format(self.__class__.__name__)) def _assert_get_sensor(self): + """Check whether get_sensor() is correctly implemented.""" if not isinstance(self.sensor, enum.EnumMeta): raise TypeError('{}.get_sensor() should return an instance of ' 'enum.Enum, containing an enumeration of the ' @@ -178,6 +204,7 @@ class ImageDataset(Dataset): .format(self.__class__.__name__)) def _assert_get_labels(self): + """Check whether get_labels() is correctly implemented.""" if not issubclass(self._label_class, Label): raise TypeError('{}.get_labels() should return an instance of ' 'pysegcnn.core.constants.Label, ' @@ -188,16 +215,32 @@ class ImageDataset(Dataset): 'pysegcnn.core.constants.py.' .format(self.__class__.__name__)) - # the __len__() method returns the number of samples in the dataset def __len__(self): - # number of (tiles x channels x height x width) patches after each - # scene is decomposed to tiles blocks + """Return the number of samples in the dataset. + + Returns + ------- + nsamples : `int` + The number of samples in the dataset. + """ return len(self.scenes) - # the __getitem__() method returns a single sample of the dataset given an - # index, i.e. 
an array/tensor of shape (channels x height x width) def __getitem__(self, idx): + """Return the data of a sample of the dataset given an index ``idx``. + + Parameters + ---------- + idx : `int` + The index of the sample. + + Returns + ------- + x : `torch.Tensor` + The sample input data. + y : `torch.Tensor` + The sample ground truth. + """ # select a scene scene = self.read_scene(idx) @@ -209,88 +252,191 @@ class ImageDataset(Dataset): # preprocess samples x, y = self.preprocess(data, gt) - # optional transformation + # apply transformation if scene['transform'] is not None: x, y = scene['transform'](x, y) # convert to torch tensors - x, y = self.to_tensor(x, y) + x = self.to_tensor(x, dtype=torch.float32) + y = self.to_tensor(y, dtype=torch.uint8) return x, y - # the compose_scenes() method has to be implemented by the class inheriting - # the ImageDataset class - # compose_scenes() should return a list of dictionaries, where each - # dictionary represent one sample of the dataset, a scene or a tile - # of a scene, etc. - # the dictionaries should have the following (key, value) pairs: - # - (band_1, path_to_band_1.tif) - # - (band_2, path_to_band_2.tif) - # - ... - # - (band_n, path_to_band_n.tif) - # - (gt, path_to_ground_truth.tif) - # - (tile, None or int) - # - ... - def compose_scenes(self, *args, **kwargs): + def compose_scenes(self): + """Build the list of samples of the dataset. + + Each sample is represented by a dictionary. + + Raises + ------ + NotImplementedError + Raised if the `pysegcnn.core.dataset.ImageDataset` class is not + inherited. + + Returns + ------- + samples : `list` [`dict`] + Each dictionary representing a sample should have keys: + ``'band_name_1'`` + Path to the file of band_1. + ``'band_name_2'`` + Path to the file of band_2. + ``'band_name_n'`` + Path to the file of band_n. + ``'gt'`` + Path to the ground truth file. + ``'date'`` + The date of the sample. + ``'tile'`` + The tile id of the sample. 
+ ``'transform'`` + The transformation to apply. + ``'id'`` + The scene identifier. + + """ raise NotImplementedError('Inherit the ImageDataset class and ' 'implement the method.') - # the get_size() method has to be implemented by the class inheriting - # the ImageDataset class - # get_size() method should return the image size as tuple, (height, width) - def get_size(self, *args, **kwargs): + def get_size(self): + """Return the size of the images in the dataset. + + Raises + ------ + NotImplementedError + Raised if the `pysegcnn.core.dataset.ImageDataset` class is not + inherited. + + Returns + ------- + size : `tuple` + The image size (height, width). + + """ raise NotImplementedError('Inherit the ImageDataset class and ' 'implement the method.') - # the get_sensor() method has to be implemented by the class inheriting - # the ImageDataset class - # get_sensor() should return an enum.Enum with the following - # (name: str, value: int) tuples: - # - (red, 2) - # - (green, 3) - # - ... - def get_sensor(self, *args, **kwargs): + def get_sensor(self): + """Return an enumeration of the bands of the sensor of the dataset. + + Examples can be found in `pysegcnn.core.constants`. + + Raises + ------ + NotImplementedError + Raised if the `pysegcnn.core.dataset.ImageDataset` class is not + inherited. + + Returns + ------- + sensor : `enum.Enum` + An enumeration of the bands of the sensor. + + """ raise NotImplementedError('Inherit the ImageDataset class and ' 'implement the method.') - # the get_labels() method has to be implemented by the class inheriting - # the ImageDataset class - # get_labels() should return a dictionary with the following - # (key: int, value: str) pairs: - # - (0, label_1_name) - # - (1, label_2_name) - # - ... 
- # - (n, label_n_name) - # where the keys should be the values representing the values of the - # corresponding label in the ground truth mask - # the labels in the dictionary determine the classes to be segmented - def get_labels(self, *args, **kwargs): + def get_labels(self): + """Return an enumeration of the class labels of the dataset. + + Examples can be found in `pysegcnn.core.constants`. + + Raises + ------ + NotImplementedError + Raised if the `pysegcnn.core.dataset.ImageDataset` class is not + inherited. + + Returns + ------- + labels : `enum.Enum` + The class labels. + + """ raise NotImplementedError('Inherit the ImageDataset class and ' 'implement the method.') - # the preprocess() method has to be implemented by the class inheriting - # the ImageDataset class - # preprocess() should return two torch.tensors: - # - input data: tensor of shape (bands, height, width) - # - ground truth: tensor of shape (height, width) def preprocess(self, data, gt): + """Preprocess a sample before feeding it to a model. + + Parameters + ---------- + data : `numpy.ndarray` + The sample input data. + gt : `numpy.ndarray` + The sample ground truth. + + Raises + ------ + NotImplementedError + Raised if the `pysegcnn.core.dataset.ImageDataset` class is not + inherited. + + Returns + ------- + data : `numpy.ndarray` + The preprocessed input data. + gt : `numpy.ndarray` + The preprocessed ground truth data. + + """ raise NotImplementedError('Inherit the ImageDataset class and ' 'implement the method.') - # the parse_scene_id() method has to be implemented by the class inheriting - # the ImageDataset class - # the input to the parse_scene_id() method is a string describing a scene - # id, e.g. an id of a Landsat or a Sentinel scene - # parse_scene_id() should return a dictionary containing the scene metadata - def parse_scene_id(self, scene): + def parse_scene_id(self, scene_id): + """Parse the scene identifier. + + Parameters + ---------- + scene_id : `str` + A scene identifier. 
+ + Raises + ------ + NotImplementedError + Raised if the `pysegcnn.core.dataset.ImageDataset` class is not + inherited. + + Returns + ------- + scene : `dict` or `None` + A dictionary containing scene metadata. If `None`, ``scene_id`` is + not a valid scene identifier. + + """ raise NotImplementedError('Inherit the ImageDataset class and ' 'implement the method.') - # _read_scene() reads all the bands and the ground truth mask in a - # scene/tile to a numpy array and returns a dictionary with - # (key, value) = ('band_name', np.ndarray(band_data)) def read_scene(self, idx): - + """Read the data of the sample with index ``idx``. + + Parameters + ---------- + idx : `int` + The index of the sample. + + Returns + ------- + scene_data : `dict` + The sample data dictionary with keys: + ``'band_name_1'`` + data of band_1 (`numpy.ndarray`). + ``'band_name_2'`` + data of band_2 (`numpy.ndarray`). + ``'band_name_n'`` + data of band_n (`numpy.ndarray`). + ``'gt'`` + data of the ground truth (`numpy.ndarray`). + ``'date'`` + The date of the sample. + ``'tile'`` + The tile id of the sample. + ``'transform'`` + The transformation to apply. + ``'id'`` + The scene identifier. + + """ # select a scene from the root directory scene = self.scenes[idx] @@ -309,22 +455,71 @@ class ImageDataset(Dataset): return scene_data - # _build_samples() stacks all bands of a scene/tile into a - # numpy array of shape (bands x height x width) def build_samples(self, scene): - + """Stack the bands of a sample in a single array. + + Parameters + ---------- + scene : `dict` + The sample data dictionary with keys: + ``'band_name_1'`` + data of band_1 (`numpy.ndarray`). + ``'band_name_2'`` + data of band_2 (`numpy.ndarray`). + ``'band_name_n'`` + data of band_n (`numpy.ndarray`). + ``'gt'`` + data of the ground truth (`numpy.ndarray`). + ``'date'`` + The date of the sample. + ``'tile'`` + The tile id of the sample. + ``'transform'`` + The transformation to apply. + ``'id'`` + The scene identifier. 
+ + Returns + ------- + stack : `numpy.ndarray` + The input data of the sample. + gt : `numpy.ndarray` + The ground truth of the sample. + + """ # iterate over the channels to stack stack = np.stack([scene[band] for band in self.use_bands], axis=0) gt = scene['gt'] return stack, gt - def to_tensor(self, x, y): - return (torch.tensor(x.copy(), dtype=torch.float32), - torch.tensor(y.copy(), dtype=torch.uint8)) + def to_tensor(self, x, dtype): + """Convert ``x`` to `torch.Tensor`. + + Parameters + ---------- + x : array_like + The input data. + dtype : `torch.dtype` + The data type used to convert ``x``. + + Returns + ------- + x : `torch.Tensor` + The input data tensor. + + """ + return torch.tensor(np.asarray(x).copy(), dtype=dtype) def __repr__(self): + """Dataset representation. + + Returns + ------- + fs : `str` + Representation string. + """ # representation string to print fs = self.__class__.__name__ + '(\n' @@ -358,17 +553,106 @@ class ImageDataset(Dataset): return fs - class StandardEoDataset(ImageDataset): - - def __init__(self, root_dir, **kwargs): + """Base class for standard Earth Observation style datasets. + + `pysegcnn.core.dataset.StandardEoDataset` implements the + `~pysegcnn.core.dataset.StandardEoDataset.compose_scenes` method for + datasets with the following directory structure: + + root_dir/ + scene_id_1/ + scene_id_1_B1.tif + scene_id_1_B2.tif + . + . + . + scene_id_1_BN.tif + scene_id_2/ + scene_id_2_B1.tif + scene_id_2_B2.tif + . + . + . + scene_id_2_BN.tif + . + . + . + scene_id_N/ + . + . + . + + If your dataset shares this directory structure, you can directly inherit + `pysegcnn.core.dataset.StandardEoDataset` and implement the remaining + methods. + + See `pysegcnn.core.dataset.SparcsDataset` for an example. + + Parameters + ---------- + root_dir : `str` + The root directory, path to the dataset. + use_bands : `list` [`str`], optional + A list of the spectral bands to use. The default is []. 
+ tile_size : `int` or `None`, optional + The size of the tiles. If not `None`, each scene is divided into square + tiles of shape (tile_size, tile_size). The default is None. + pad : `bool`, optional + Whether to center pad the input image. Set ``pad`` = True, if the + images are not evenly divisible by the ``tile_size``. The image data is + padded with a constant padding value of zero. For each image, the + corresponding ground truth image is padded with a "no data" label. + The default is False. + gt_pattern : `str`, optional + A pattern to match the ground truth naming convention. All directories + and subdirectories in ``root_dir`` are searched for files matching + ``gt_pattern``. The default is '*gt.tif'. + sort : `bool`, optional + Whether to chronologically sort the samples. Useful for time series + data. The default is False. + seed : `int`, optional + The random seed. Used to split the dataset into training, validation + and test set. Useful for reproducibility. The default is 0. + transforms : `list` [`pysegcnn.core.split.Augment`], optional + List of `pysegcnn.core.split.Augment` instances. Each item in + ``transforms`` generates a distinct transformed version of the dataset. + The total dataset is composed of the original untransformed dataset + together with each transformed version of it. + If ``transforms`` = [], only the original dataset is used. + The default is []. + + Returns + ------- + None. + + """ + + def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, + gt_pattern='*gt.tif', sort=False, seed=0, transforms=[]): # initialize super class ImageDataset - super().__init__(root_dir, **kwargs) + super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern, + sort, seed, transforms) + + def _get_band_number(self, path): + """Return the band number of a scene .tif file. + + Parameters + ---------- + path : `str` + The path to the .tif file. + + Raises + ------ + ValueError + Raised if ``path`` is not a .tif file. 
- # returns the band number of a Landsat8 or Sentinel2 tif file - # path: path to a tif file - def get_band_number(self, path): + Returns + ------- + band : `int` or `str` + The band number. + """ # check whether the path leads to a tif file if not path.endswith(('tif', 'TIF')): raise ValueError('Expected a path to a tif file.') @@ -377,7 +661,7 @@ class StandardEoDataset(ImageDataset): fname = os.path.basename(path) # search for numbers following a "B" in the filename - band = re.search('B\dA|B\d{1,2}', fname)[0].replace('B', '') + band = re.search('B\\dA|B\\d{1,2}', fname)[0].replace('B', '') # try converting to an integer: # raises a ValueError for Sentinel2 8A band @@ -388,14 +672,34 @@ class StandardEoDataset(ImageDataset): return band - # store_bands() writes the paths to the data of each scene to a dictionary - # only the bands of interest are stored - def store_bands(self, bands, gt): - + def _store_bands(self, bands, gt): + """Write the bands of interest to a dictionary. + + Parameters + ---------- + bands : `list` [`str`] + Paths to the .tif files of the bands of the scene. + gt : `str` + Path to the ground truth of the scene. + + Returns + ------- + scene_data : `dict` + The scene data dictionary with keys: + ``'band_name_1'`` + Path to the .tif file of band_1. + ``'band_name_2'`` + Path to the .tif file of band_2. + ``'band_name_n'`` + Path to the .tif file of band_n. + ``'gt'`` + Path to the ground truth file. 
+ + """ # store the bands of interest in a dictionary scene_data = {} for i, b in enumerate(bands): - band = self.bands[self.get_band_number(b)] + band = self.bands[self._get_band_number(b)] if band in self.use_bands: scene_data[band] = b @@ -404,12 +708,33 @@ class StandardEoDataset(ImageDataset): return scene_data - # compose_scenes() creates a list of dictionaries containing the paths - # to the tif files of each scene - # if the scenes are divided into tiles, each tile has its own entry - # with corresponding tile id def compose_scenes(self): - + """Build the list of samples of the dataset. + + Each sample is represented by a dictionary. + + Returns + ------- + scenes : `list` [`dict`] + Each item in ``scenes`` is a `dict` with keys: + ``'band_name_1'`` + Path to the file of band_1. + ``'band_name_2'`` + Path to the file of band_2. + ``'band_name_n'`` + Path to the file of band_n. + ``'gt'`` + Path to the ground truth file. + ``'date'`` + The date of the sample. + ``'tile'`` + The tile id of the sample. + ``'transform'`` + The transformation to apply. + ``'id'`` + The scene identifier. + + """ # search the root directory scenes = [] self.gt = [] @@ -447,7 +772,7 @@ class StandardEoDataset(ImageDataset): for transf in self.transforms: # store the bands and the ground truth mask of the tile - data = self.store_bands(bands, gt) + data = self._store_bands(bands, gt) # the name of the scene data['id'] = scene['id'] @@ -472,122 +797,518 @@ class StandardEoDataset(ImageDataset): return scenes -# SparcsDataset class: inherits from the generic ImageDataset class class SparcsDataset(StandardEoDataset): - - def __init__(self, root_dir, **kwargs): + """Dataset class for the `Sparcs`_ dataset. + + .. _Sparcs: + https://www.usgs.gov/land-resources/nli/landsat/spatial-procedures-automated-removal-cloud-and-shadow-sparcs-validation + + Parameters + ---------- + root_dir : `str` + The root directory, path to the dataset. 
+ use_bands : `list` [`str`], optional + A list of the spectral bands to use. The default is []. + tile_size : `int` or `None`, optional + The size of the tiles. If not `None`, each scene is divided into square + tiles of shape (tile_size, tile_size). The default is None. + pad : `bool`, optional + Whether to center pad the input image. Set ``pad`` = True, if the + images are not evenly divisible by the ``tile_size``. The image data is + padded with a constant padding value of zero. For each image, the + corresponding ground truth image is padded with a "no data" label. + The default is False. + gt_pattern : `str`, optional + A pattern to match the ground truth naming convention. All directories + and subdirectories in ``root_dir`` are searched for files matching + ``gt_pattern``. The default is '*gt.tif'. + sort : `bool`, optional + Whether to chronologically sort the samples. Useful for time series + data. The default is False. + seed : `int`, optional + The random seed. Used to split the dataset into training, validation + and test set. Useful for reproducibility. The default is 0. + transforms : `list` [`pysegcnn.core.split.Augment`], optional + List of `pysegcnn.core.split.Augment` instances. Each item in + ``transforms`` generates a distinct transformed version of the dataset. + The total dataset is composed of the original untransformed dataset + together with each transformed version of it. + If ``transforms`` = [], only the original dataset is used. + The default is []. + + Returns + ------- + None. + + """ + + def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, + gt_pattern='*gt.tif', sort=False, seed=0, transforms=[]): # initialize super class StandardEoDataset - super().__init__(root_dir, **kwargs) + super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern, + sort, seed, transforms) - # image size of the Sparcs dataset: (height, width) def get_size(self): + """Image size of the Sparcs dataset. 
+ + Returns + ------- + size : `tuple` + The image size (height, width). + + """ return (1000, 1000) - # Landsat 8 bands of the Sparcs dataset def get_sensor(self): + """Landsat 8 bands of the Sparcs dataset. + + Returns + ------- + sensor : `enum.Enum` + An enumeration of the bands of the sensor. + + """ return Landsat8 - # class labels of the Sparcs dataset def get_labels(self): + """Class labels of the Sparcs dataset. + + Returns + ------- + labels : `enum.Enum` + The class labels. + + """ return SparcsLabels - # preprocessing of the Sparcs dataset def preprocess(self, data, gt): - + """Preprocess Sparcs dataset images. + + Parameters + ---------- + data : `numpy.ndarray` + The sample input data. + gt : `numpy.ndarray` + The sample ground truth. + + Returns + ------- + data : `numpy.ndarray` + The preprocessed input data. + gt : `numpy.ndarray` + The preprocessed ground truth data. + + """ # if the preprocessing is not done externally, implement it here return data, gt - # function that parses the date from a Landsat 8 scene id - def parse_scene_id(self, scene): - return parse_landsat_scene(scene) + def parse_scene_id(self, scene_id): + """Parse Sparcs scene identifiers (Landsat 8). + Parameters + ---------- + scene_id : `str` + A scene identifier. -class ProSnowDataset(StandardEoDataset): + Returns + ------- + scene : `dict` or `None` + A dictionary containing scene metadata. If `None`, ``scene_id`` is + not a valid Landsat scene identifier. + + """ + return parse_landsat_scene(scene_id) - def __init__(self, root_dir, **kwargs): + +class ProSnowDataset(StandardEoDataset): + """Dataset class for the ProSnow datasets. + + Parameters + ---------- + root_dir : `str` + The root directory, path to the dataset. + use_bands : `list` [`str`], optional + A list of the spectral bands to use. The default is []. + tile_size : `int` or `None`, optional + The size of the tiles. If not `None`, each scene is divided into square + tiles of shape (tile_size, tile_size). 
The default is None. + pad : `bool`, optional + Whether to center pad the input image. Set ``pad`` = True, if the + images are not evenly divisible by the ``tile_size``. The image data is + padded with a constant padding value of zero. For each image, the + corresponding ground truth image is padded with a "no data" label. + The default is False. + gt_pattern : `str`, optional + A pattern to match the ground truth naming convention. All directories + and subdirectories in ``root_dir`` are searched for files matching + ``gt_pattern``. The default is '*gt.tif'. + sort : `bool`, optional + Whether to chronologically sort the samples. Useful for time series + data. The default is False. + seed : `int`, optional + The random seed. Used to split the dataset into training, validation + and test set. Useful for reproducibility. The default is 0. + transforms : `list` [`pysegcnn.core.split.Augment`], optional + List of `pysegcnn.core.split.Augment` instances. Each item in + ``transforms`` generates a distinct transformed version of the dataset. + The total dataset is composed of the original untransformed dataset + together with each transformed version of it. + If ``transforms`` = [], only the original dataset is used. + The default is []. + + Returns + ------- + None. + + """ + + def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, + gt_pattern='*gt.tif', sort=False, seed=0, transforms=[]): # initialize super class StandardEoDataset - super().__init__(root_dir, **kwargs) + super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern, + sort, seed, transforms) - # Sentinel 2 bands def get_sensor(self): + """Sentinel 2 bands of the ProSnow datasets. + + Returns + ------- + sensor : `enum.Enum` + An enumeration of the bands of the sensor. + + """ return Sentinel2 - # class labels of the ProSnow dataset def get_labels(self): + """Class labels of the ProSnow datasets. + + Returns + ------- + labels : `enum.Enum` + The class labels. 
+ + """ return ProSnowLabels - # preprocessing of the ProSnow dataset def preprocess(self, data, gt): - + """Preprocess ProSnow dataset images. + + Parameters + ---------- + data : `numpy.ndarray` + The sample input data. + gt : `numpy.ndarray` + The sample ground truth. + + Returns + ------- + data : `numpy.ndarray` + The preprocessed input data. + gt : `numpy.ndarray` + The preprocessed ground truth data. + + """ # if the preprocessing is not done externally, implement it here return data, gt - # function that parses the date from a Sentinel 2 scene id - def parse_scene_id(self, scene): - return parse_sentinel2_scene(scene) + def parse_scene_id(self, scene_id): + """Parse ProSnow scene identifiers (Sentinel 2). + Parameters + ---------- + scene_id : `str` + A scene identifier. -class ProSnowGarmisch(ProSnowDataset): + Returns + ------- + scene : `dict` or `None` + A dictionary containing scene metadata. If `None`, ``scene_id`` is + not a valid Sentinel-2 scene identifier. - def __init__(self, root_dir, **kwargs): - # initialize super class StandardEoDatasets - super().__init__(root_dir, **kwargs) + """ + return parse_sentinel2_scene(scene_id) + + +class ProSnowGarmisch(ProSnowDataset): + """Dataset class for the ProSnow Garmisch dataset. + + Parameters + ---------- + root_dir : `str` + The root directory, path to the dataset. + use_bands : `list` [`str`], optional + A list of the spectral bands to use. The default is []. + tile_size : `int` or `None`, optional + The size of the tiles. If not `None`, each scene is divided into square + tiles of shape (tile_size, tile_size). The default is None. + pad : `bool`, optional + Whether to center pad the input image. Set ``pad`` = True, if the + images are not evenly divisible by the ``tile_size``. The image data is + padded with a constant padding value of zero. For each image, the + corresponding ground truth image is padded with a "no data" label. + The default is False. 
+ gt_pattern : `str`, optional + A pattern to match the ground truth naming convention. All directories + and subdirectories in ``root_dir`` are searched for files matching + ``gt_pattern``. The default is '*gt.tif'. + sort : `bool`, optional + Whether to chronologically sort the samples. Useful for time series + data. The default is False. + seed : `int`, optional + The random seed. Used to split the dataset into training, validation + and test set. Useful for reproducibility. The default is 0. + transforms : `list` [`pysegcnn.core.split.Augment`], optional + List of `pysegcnn.core.split.Augment` instances. Each item in + ``transforms`` generates a distinct transformed version of the dataset. + The total dataset is composed of the original untransformed dataset + together with each transformed version of it. + If ``transforms`` = [], only the original dataset is used. + The default is []. + + Returns + ------- + None. + + """ + + def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, + gt_pattern='*gt.tif', sort=False, seed=0, transforms=[]): + # initialize super class StandardEoDataset + super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern, + sort, seed, transforms) def get_size(self): + """Image size of the ProSnow Garmisch dataset. + + Returns + ------- + size : `tuple` + The image size (height, width). + + """ return (615, 543) class ProSnowObergurgl(ProSnowDataset): - - def __init__(self, root_dir, **kwargs): + """Dataset class for the ProSnow Obergurgl dataset. + + Parameters + ---------- + root_dir : `str` + The root directory, path to the dataset. + use_bands : `list` [`str`], optional + A list of the spectral bands to use. The default is []. + tile_size : `int` or `None`, optional + The size of the tiles. If not `None`, each scene is divided into square + tiles of shape (tile_size, tile_size). The default is None. + pad : `bool`, optional + Whether to center pad the input image. 
Set ``pad`` = True, if the + images are not evenly divisible by the ``tile_size``. The image data is + padded with a constant padding value of zero. For each image, the + corresponding ground truth image is padded with a "no data" label. + The default is False. + gt_pattern : `str`, optional + A pattern to match the ground truth naming convention. All directories + and subdirectories in ``root_dir`` are searched for files matching + ``gt_pattern``. The default is '*gt.tif'. + sort : `bool`, optional + Whether to chronologically sort the samples. Useful for time series + data. The default is False. + seed : `int`, optional + The random seed. Used to split the dataset into training, validation + and test set. Useful for reproducibility. The default is 0. + transforms : `list` [`pysegcnn.core.split.Augment`], optional + List of `pysegcnn.core.split.Augment` instances. Each item in + ``transforms`` generates a distinct transformed version of the dataset. + The total dataset is composed of the original untransformed dataset + together with each transformed version of it. + If ``transforms`` = [], only the original dataset is used. + The default is []. + + Returns + ------- + None. + + """ + + def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, + gt_pattern='*gt.tif', sort=False, seed=0, transforms=[]): # initialize super class StandardEoDataset - super().__init__(root_dir, **kwargs) + super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern, + sort, seed, transforms) def get_size(self): + """Image size of the ProSnow Obergurgl dataset. + + Returns + ------- + size : `tuple` + The image size (height, width). + + """ return (310, 270) class Cloud95Dataset(ImageDataset): - - def __init__(self, root_dir, **kwargs): + """Dataset class for the `Cloud95`_ dataset by `Mohajerani et al. (2020)`_. + + .. _Cloud95: + https://github.com/SorourMo/95-Cloud-An-Extension-to-38-Cloud-Dataset + .. _`Mohajerani et al. 
(2020)`: + https://arxiv.org/abs/2001.08768 + + Parameters + ---------- + root_dir : `str` + The root directory, path to the dataset. + use_bands : `list` [`str`], optional + A list of the spectral bands to use. The default is []. + tile_size : `int` or `None`, optional + The size of the tiles. If not `None`, each scene is divided into square + tiles of shape (tile_size, tile_size). The default is None. + pad : `bool`, optional + Whether to center pad the input image. Set ``pad`` = True, if the + images are not evenly divisible by the ``tile_size``. The image data is + padded with a constant padding value of zero. For each image, the + corresponding ground truth image is padded with a "no data" label. + The default is False. + gt_pattern : `str`, optional + A pattern to match the ground truth naming convention. All directories + and subdirectories in ``root_dir`` are searched for files matching + ``gt_pattern``. The default is '*gt.tif'. + sort : `bool`, optional + Whether to chronologically sort the samples. Useful for time series + data. The default is False. + seed : `int`, optional + The random seed. Used to split the dataset into training, validation + and test set. Useful for reproducibility. The default is 0. + transforms : `list` [`pysegcnn.core.split.Augment`], optional + List of `pysegcnn.core.split.Augment` instances. Each item in + ``transforms`` generates a distinct transformed version of the dataset. + The total dataset is composed of the original untransformed dataset + together with each transformed version of it. + If ``transforms`` = [], only the original dataset is used. + The default is []. + + Returns + ------- + None. 
+ + """ + + def __init__(self, root_dir, use_bands=[], tile_size=None, pad=False, + gt_pattern='*gt.tif', sort=False, seed=0, transforms=[]): + # initialize super class ImageDataset + super().__init__(root_dir, use_bands, tile_size, pad, gt_pattern, + sort, seed, transforms) # the csv file containing the names of the informative patches # patches with more than 80% black pixels, i.e. patches resulting from # the black margins around a Landsat 8 scene are excluded self.exclude = 'training_patches_95-cloud_nonempty.csv' - # initialize super class ImageDataset - super().__init__(root_dir, **kwargs) - - # image size of the Cloud-95 dataset: (height, width) def get_size(self): + """Image size of the Cloud-95 dataset. + + Returns + ------- + size : `tuple` + The image size (height, width). + + """ return (384, 384) - # Landsat 8 bands in the Cloud-95 dataset def get_sensor(self): + """Landsat 8 bands of the Cloud-95 dataset. + + Returns + ------- + sensor : `enum.Enum` + An enumeration of the bands of the sensor. + + """ return Landsat8 - # class labels of the Cloud-95 dataset def get_labels(self): + """Class labels of the Cloud-95 dataset. + + Returns + ------- + labels : `enum.Enum` + The class labels. + + """ return Cloud95Labels - # preprocess Cloud-95 dataset def preprocess(self, data, gt): - + """Preprocess Cloud-95 dataset images. + + Parameters + ---------- + data : `numpy.ndarray` + The sample input data. + gt : `numpy.ndarray` + The sample ground truth. + + Returns + ------- + data : `numpy.ndarray` + The preprocessed input data. + gt : `numpy.ndarray` + The preprocessed ground truth data. + + """ # normalize the data # here, we use the normalization of the authors of Cloud-95, i.e.
# Mohajerani and Saeedi (2019, 2020) data /= 65535 gt[gt != self.cval] /= 255 - return data, gt - # function that parses the date from a Landsat 8 scene id - def parse_scene_id(self, scene): - return parse_landsat_scene(scene) + def parse_scene_id(self, scene_id): + """Parse Cloud-95 scene identifiers (Landsat 8). - def compose_scenes(self): + Parameters + ---------- + scene_id : `str` + A scene identifier. + + Returns + ------- + scene : `dict` or `None` + A dictionary containing scene metadata. If `None`, ``scene_id`` is + not a valid Landsat scene identifier. + """ + return parse_landsat_scene(scene_id) + + def compose_scenes(self): + """Build the list of samples of the dataset. + + Each sample is represented by a dictionary. + + Returns + ------- + scenes : `list` [`dict`] + Each item in ``scenes`` is a `dict` with keys: + ``'band_name_1'`` + Path to the file of band_1. + ``'band_name_2'`` + Path to the file of band_2. + ``'band_name_n'`` + Path to the file of band_n. + ``'gt'`` + Path to the ground truth file. + ``'date'`` + The date of the sample. + ``'tile'`` + The tile id of the sample. + ``'transform'`` + The transformation to apply. + ``'id'`` + The scene identifier.
+ + """ # whether to exclude patches with more than 80% black pixels ipatches = [] if self.exclude is not None: @@ -620,7 +1341,7 @@ class Cloud95Dataset(ImageDataset): patchname = file.split('.')[0].replace(biter + '_', '') # get the date of the current scene - date = self.parse_scene_id(patchname)['date'] + scene_meta = self.parse_scene_id(patchname) # check whether the current file is an informative patch if ipatches and patchname not in ipatches: @@ -641,11 +1362,14 @@ class Cloud95Dataset(ImageDataset): scene[band] = os.path.join(band_dirs[band], file.replace(biter, band)) + # the name of the scene the patch was extracted from + scene['id'] = scene_meta['id'] + # store tile number scene['tile'] = tile # store date - scene['date'] = date + scene['date'] = scene_meta['date'] # store optional transformation scene['transform'] = transf @@ -661,49 +1385,9 @@ class Cloud95Dataset(ImageDataset): class SupportedDatasets(enum.Enum): + """Names and corresponding classes of the implemented datasets.""" + Sparcs = SparcsDataset Cloud95 = Cloud95Dataset Garmisch = ProSnowGarmisch Obergurgl = ProSnowObergurgl - - -if __name__ == '__main__': - - # define path to working directory - # wd = '//projectdata.eurac.edu/projects/cci_snow/dfrisinghelli/' - # wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli' - wd = 'C:/Eurac/2020/' - - # path to the preprocessed sparcs dataset - sparcs_path = os.path.join(wd, '_Datasets/Sparcs') - - # path to the Cloud-95 dataset - # cloud_path = os.path.join(wd, '_Datasets/Cloud95/Training') - - # path to the ProSnow dataset - # prosnow_path = os.path.join(wd, '_Datasets/ProSnow/') - - # instanciate the Cloud-95 dataset - # cloud_dataset = Cloud95Dataset(cloud_path, - # tile_size=192, - # use_bands=[], - # sort=False) - - # instanciate the SparcsDataset class - sparcs_dataset = SparcsDataset(sparcs_path, - tile_size=None, - use_bands=['red', 'green', 'blue', 'nir'], - sort=False, - transforms=[], - gt_pattern='*mask.png', - pad=True, - cval=99, - ) - 
- # instanciate the ProSnow datasets - # garmisch = ProSnowGarmisch(os.path.join(prosnow_path, 'Garmisch'), - # tile_size=None, - # use_bands=['nir', 'red', 'green', 'blue'], - # sort=True, - # transforms=[], - # gt_pattern='*class.img') diff --git a/pysegcnn/core/graphics.py b/pysegcnn/core/graphics.py index 5d765cd..e760d9d 100644 --- a/pysegcnn/core/graphics.py +++ b/pysegcnn/core/graphics.py @@ -83,15 +83,15 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10), Parameters ---------- - x : `numpy.ndarray` or `torch.tensor`, (b, h, w) + x : `numpy.ndarray` or `torch.Tensor`, (b, h, w) Array containing the raw data of the tile, shape=(bands, height, width) - y : `numpy.ndarray` or `torch.tensor`, (h, w) + y : `numpy.ndarray` or `torch.Tensor`, (h, w) Array containing the ground truth of tile ``x``, shape=(height, width) use_bands : `list` of `str` List describing the order of the bands in ``x``. labels : `dict` [`int`, `dict`] - The keys are the values of the class labels in the ground truth ``y``. - Each nested `dict` should have keys: + The label dictionary. The keys are the values of the class labels + in the ground truth ``y``. Each nested `dict` should have keys: ``'color'`` A named color (`str`). ``'label'`` @@ -185,8 +185,8 @@ def plot_confusion_matrix(cm, labels, normalize=True, cm : `numpy.ndarray` The confusion matrix. labels : `dict` [`int`, `dict`] - The keys are the values of the class labels in the ground truth ``y``. - Each nested `dict` should have keys: + The label dictionary. The keys are the values of the class labels + in the ground truth ``y``. Each nested `dict` should have keys: ``'color'`` A named color (`str`). ``'label'`` diff --git a/pysegcnn/core/layers.py b/pysegcnn/core/layers.py index 367bd72..ab6c59a 100644 --- a/pysegcnn/core/layers.py +++ b/pysegcnn/core/layers.py @@ -134,16 +134,16 @@ class Conv2dPool(nn.Module): Parameters ---------- - x : `torch.tensor` + x : `torch.Tensor` Output of previous layer. 
Returns ------- - y : `torch.tensor` + y : `torch.Tensor` Output of this block. - x : `torch.tensor` + x : `torch.Tensor` Output before max pooling. Stored for skip connections. - i : `torch.tensor` + i : `torch.Tensor` Indices of the max pooling operation. Used in unpooling operation. """ @@ -190,18 +190,18 @@ class Conv2dUnpool(nn.Module): Parameters ---------- - x : `torch.tensor` + x : `torch.Tensor` Output of previous layer. - feature : `torch.tensor` + feature : `torch.Tensor` Encoder feature used for the skip connection. - indices : `torch.tensor` + indices : `torch.Tensor` Indices of the max pooling operation. Used in unpooling operation. skip : `bool` Whether to apply skip connetion. Returns ------- - x : `torch.tensor` + x : `torch.Tensor` Output of this block. """ @@ -254,11 +254,11 @@ class Conv2dUpsample(nn.Module): Parameters ---------- - x : `torch.tensor` + x : `torch.Tensor` Output of previous layer. - feature : `torch.tensor` + feature : `torch.Tensor` Encoder feature used for the skip connection. - indices : `torch.tensor` + indices : `torch.Tensor` Indices of the max pooling operation. Used in unpooling operation. Not used here, but passed to preserve generic interface. Useful in `pysegcnn.core.layers.Decoder`. @@ -267,7 +267,7 @@ class Conv2dUpsample(nn.Module): Returns ------- - x : `torch.tensor` + x : `torch.Tensor` Output of this block. """ @@ -331,12 +331,12 @@ class Encoder(nn.Module): Parameters ---------- - x : `torch.tensor` + x : `torch.Tensor` Input image. Returns ------- - x : `torch.tensor` + x : `torch.Tensor` Output of the encoder. """ @@ -411,7 +411,7 @@ class Decoder(nn.Module): Parameters ---------- - x : `torch.tensor` + x : `torch.Tensor` Output of the encoder. enc_cache : `dict` Cache dictionary with keys: @@ -422,7 +422,7 @@ class Decoder(nn.Module): Returns ------- - x : `torch.tensor` + x : `torch.Tensor` Output of the decoder. 
""" diff --git a/pysegcnn/core/models.py b/pysegcnn/core/models.py index 52400d7..0cd85f4 100644 --- a/pysegcnn/core/models.py +++ b/pysegcnn/core/models.py @@ -275,7 +275,7 @@ class UNet(Network): Parameters ---------- - x : `torch.tensor` + x : `torch.Tensor` The input image, shape=(batch_size, channels, height, width). Returns diff --git a/pysegcnn/core/split.py b/pysegcnn/core/split.py index ea6bef9..2be0aac 100644 --- a/pysegcnn/core/split.py +++ b/pysegcnn/core/split.py @@ -214,9 +214,9 @@ def date_scene_split(ds, date, dateformat='%Y%m%d'): ---------- ds : `pysegcnn.core.dataset.ImageDataset` An instance of `~pysegcnn.core.dataset.ImageDataset`. - date : 'str' + date : `str` A date. - dateformat : 'str', optional + dateformat : `str`, optional The format of ``date``. ``dateformat`` is used by `datetime.datetime.strptime' to parse ``date`` to a `datetime.datetime` object. The default is '%Y%m%d'. @@ -283,7 +283,7 @@ def pairwise_disjoint(sets): class CustomSubset(Subset): - """Custom subset inheriting `torch.utils.data.dataset.Subset`.""" + """Custom subset inheriting `torch.utils.data.Subset`.""" def __repr__(self): """Representation of ``~pysegcnn.core.split.CustomSubset``.""" diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index 31c615f..1e09161 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -1,9 +1,8 @@ +"""Model configuration and training.""" + +# !/usr/bin/env python # -*- coding: utf-8 -*- -""" -Created on Wed Aug 12 10:24:34 2020 -@author: Daniel -""" # builtins import dataclasses import pathlib @@ -34,8 +33,22 @@ LOGGER = logging.getLogger(__name__) @dataclasses.dataclass class BaseConfig: + """Base `dataclasses.dataclass` for each configuration.""" def __post_init__(self): + """Check the type of each argument. + + Raises + ------ + TypeError + Raised if the conversion to the specified type of the argument + fails. + + Returns + ------- + None. 
+ + """ # check input types for field in dataclasses.fields(self): # the value of the current field @@ -55,6 +68,45 @@ class BaseConfig: @dataclasses.dataclass class DatasetConfig(BaseConfig): + """Dataset configuration class. + + Parameters + ---------- + dataset_name : `str` + The name of the dataset. + root_dir : `pathlib.Path` + The root directory, path to the dataset. + bands : `list` [`str`] + A list of the spectral bands to use. + tile_size : `int` + The size of the tiles. Each scene is divided into square tiles of shape + (tile_size, tile_size). + gt_pattern : `str` + A pattern to match the ground truth naming convention. All directories + and subdirectories in ``root_dir`` are searched for files matching + ``gt_pattern``. + seed : `int` + The random seed. Used to split the dataset into training, validation + and test set. Useful for reproducibility. The default is 0. + sort : `bool`, optional + Whether to chronologically sort the samples. Useful for time series + data. The default is False. + transforms : `list` [`pysegcnn.core.split.Augment`], optional + List of `pysegcnn.core.split.Augment` instances. Each item in + ``transforms`` generates a distinct transformed version of the dataset. + The total dataset is composed of the original untransformed dataset + together with each transformed version of it. + If ``transforms`` = [], only the original dataset is used. + The default is []. + pad : `bool`, optional + Whether to center pad the input image. Set ``pad`` = True, if the + images are not evenly divisible by the ``tile_size``. The image data is + padded with a constant padding value of zero. For each image, the + corresponding ground truth image is padded with a "no data" label. + The default is False. + + """ + dataset_name: str root_dir: pathlib.Path bands: list @@ -66,6 +118,21 @@ class DatasetConfig(BaseConfig): pad: bool = False def __post_init__(self): + """Check the type of each argument. 
+ + Raises + ------ + FileNotFoundError + Raised if ``root_dir`` does not exist. + TypeError + Raised if not each item in ``transforms`` is an instance of + `pysegcnn.core.split.Augment` in case ``transforms`` is not empty. + + Returns + ------- + None. + + """ # check input types super().__post_init__() @@ -84,7 +151,14 @@ class DatasetConfig(BaseConfig): Augment.__name__]))) def init_dataset(self): + """Instantiate the dataset. + + Returns + ------- + dataset : `pysegcnn.core.dataset.ImageDataset` + An instance of `pysegcnn.core.dataset.ImageDataset`. + """ # instanciate the dataset dataset = self.dataset_class( root_dir=str(self.root_dir), @@ -102,6 +176,32 @@ class DatasetConfig(BaseConfig): @dataclasses.dataclass class SplitConfig(BaseConfig): + """Dataset split configuration class. + + Parameters + ---------- + split_mode : `str` + The mode to split the dataset. + ttratio : `float` + The ratio of training and validation data to test data, e.g. + ``ttratio`` = 0.6 means 60% for training and validation, 40% for + testing. + tvratio : `float` + The ratio of training data to validation data, e.g. ``tvratio`` = 0.8 + means 80% training, 20% validation. + date : `str`, optional + A date. Used if ``split_mode`` = 'date'. The default is 'yyyymmdd'. + dateformat : `str`, optional + The format of ``date``. ``dateformat`` is used by + `datetime.datetime.strptime` to parse ``date`` to a `datetime.datetime` + object. The default is '%Y%m%d'. + drop : `float`, optional + Whether to drop samples (during training only) with a fraction of + pixels equal to the constant padding value >= ``drop``. ``drop`` = 0 + means, do not drop any samples. The default is 0. + + """ + split_mode: str ttratio: float tvratio: float @@ -110,17 +210,46 @@ class SplitConfig(BaseConfig): drop: float = 0 def __post_init__(self): + """Check the type of each argument. + + Raises + ------ + ValueError + Raised if ``split_mode`` is not supported. + + Returns + ------- + None.
+ + """ # check input types super().__post_init__() # check if the split mode is valid self.split_class = item_in_enum(self.split_mode, SupportedSplits) - # function to drop samples with a fraction of pixels equal to the constant - # padding value self.cval >= self.drop @staticmethod def _drop_samples(ds, drop_threshold=1): - + """Drop samples with a fraction of pixels equal to the padding value. + + Parameters + ---------- + ds : `pysegcnn.core.split.RandomSubset` or + `pysegcnn.core.split.SceneSubset`. + An instance of `pysegcnn.core.split.RandomSubset` or + `pysegcnn.core.split.SceneSubset`. + drop_threshold : `float`, optional + The threshold above which samples are dropped. ``drop_threshold`` = + 1 means a sample is dropped, if all pixels are equal to the padding + value. ``drop_threshold`` = 0.8 means, drop a sample if 80% of the + pixels are equal to the padding value, etc. The default is 1. + + Returns + ------- + dropped : `list` [`dict`] + List of the dropped samples. + + """ # iterate over the scenes returned by self.compose_scenes() dropped = [] for pos, i in enumerate(ds.indices): @@ -145,7 +274,32 @@ class SplitConfig(BaseConfig): return dropped def train_val_test_split(self, ds): - + """Split ``ds`` into training, validation and test set. + + Parameters + ---------- + ds : `pysegcnn.core.dataset.ImageDataset` + An instance of `pysegcnn.core.dataset.ImageDataset`. + + Raises + ------ + TypeError + Raised if ``ds`` is not an instance of + `pysegcnn.core.dataset.ImageDataset`. + + Returns + ------- + train_ds : `pysegcnn.core.split.RandomSubset` or + `pysegcnn.core.split.SceneSubset`. + The training set. + valid_ds : `pysegcnn.core.split.RandomSubset` or + `pysegcnn.core.split.SceneSubset`. + The validation set. + test_ds : `pysegcnn.core.split.RandomSubset` or + `pysegcnn.core.split.SceneSubset`. + The test set. + + """ if not isinstance(ds, ImageDataset): raise TypeError('Expected "ds" to be {}.' 
.format('.'.join([ImageDataset.__module__, @@ -172,6 +326,31 @@ class SplitConfig(BaseConfig): @staticmethod def dataloaders(*args, **kwargs): + """Build `torch.utils.data.DataLoader` instances. + + Parameters + ---------- + *args : `list` [`torch.utils.data.Dataset`] + List of instances of `torch.utils.data.Dataset`. + **kwargs + Additional keyword arguments passed to + `torch.utils.data.DataLoader`. + + Raises + ------ + TypeError + Raised if not each item in ``args`` is an instance of + `torch.utils.data.Dataset`. + + Returns + ------- + loaders : `list` [`torch.utils.data.DataLoader`] + List of instances of `torch.utils.data.DataLoader`. If an instance + of `torch.utils.data.Dataset` in ``args`` is empty, `None` is + appended to ``loaders`` instead of an instance of + `torch.utils.data.DataLoader`. + + """ # check whether each dataset in args has the correct type loaders = [] for ds in args: @@ -192,6 +371,72 @@ class SplitConfig(BaseConfig): @dataclasses.dataclass class ModelConfig(BaseConfig): + """Model configuration class. + + Parameters + ---------- + model_name : `str` + The name of the model. + filters : `list` [`int`] + List of input channels to the convolutional layers. + torch_seed : `int` + The random seed to initialize the model weights. + Useful for reproducibility. + optim_name : `str` + The name of the optimizer to update the model weights. + loss_name : `str` + The name of the loss function measuring the model error. + skip_connection : `bool`, optional + Whether to apply skip connections. The default is True. + kwargs : `dict`, optional + The configuration for each convolution in the model. The default is + {'kernel_size': 3, 'stride': 1, 'dilation': 1}. + batch_size : `int`, optional + The model batch size. Determines the number of samples to process + before updating the model weights. The default is 64. + checkpoint : `bool`, optional + Whether to resume training from an existing model checkpoint. The + default is False.
+ transfer : `bool`, optional + Whether to use a model for transfer learning on a new dataset. If True, + the model architecture of ``pretrained_model`` is adjusted to a new + dataset. The default is False. + pretrained_model : `str`, optional + The name of the pretrained model to use for transfer learning. + The default is ''. + lr : `float`, optional + The learning rate used by the gradient descent algorithm. + The default is 0.001. + early_stop : `bool`, optional + Whether to apply `early stopping`_. The default is False. + mode : `str`, optional + The mode of the early stopping. Depends on the metric measuring + performance. When using model loss as metric, use ``mode`` = 'min', + however, when using accuracy as metric, use ``mode`` = 'max'. For now, + only ``mode`` = 'max' is supported. Only used if ``early_stop`` = True. + The default is 'max'. + delta : `float`, optional + Minimum change in early stopping metric to be considered as an + improvement. Only used if ``early_stop`` = True. The default is 0. + patience : `int`, optional + The number of epochs to wait for an improvement in the early stopping + metric. If the model does not improve over more than ``patience`` + epochs, quit training. Only used if ``early_stop`` = True. + The default is 10. + epochs : `int`, optional + The maximum number of epochs to train. The default is 50. + nthreads : `int`, optional + The number of cpu threads to use during training. The default is + torch.get_num_threads(). + save : `bool`, optional + Whether to save the model state to disk. Model states are saved in + pysegcnn/main/_models. The default is True. + + .. _early stopping: + https://en.wikipedia.org/wiki/Early_stopping + + """ + model_name: str filters: list torch_seed: int @@ -214,6 +459,19 @@ class ModelConfig(BaseConfig): save: bool = True def __post_init__(self): + """Check the type of each argument. 
+ + Raises + ------ + ValueError + Raised if the model ``model_name``, the optimizer ``optim_name`` or + the loss function ``loss_name`` is not supported. + + Returns + ------- + None. + + """ # check input types super().__post_init__() @@ -233,7 +491,19 @@ class ModelConfig(BaseConfig): self.pretrained_path = self.state_path.joinpath(self.pretrained_model) def init_optimizer(self, model): + """Instantiate the optimizer. + Parameters + ---------- + model : `torch.nn.Module` + An instance of `torch.nn.Module`. + + Returns + ------- + optimizer : `torch.optim.Optimizer` + An instance of `torch.optim.Optimizer`. + + """ LOGGER.info('Optimizer: {}.'.format(repr(self.optim_class))) # initialize the optimizer for the specified model @@ -242,7 +512,14 @@ class ModelConfig(BaseConfig): return optimizer def init_loss_function(self): + """Instantiate the loss function. + + Returns + ------- + loss_function : `torch.nn.Module` + An instance of `torch.nn.Module`. + """ LOGGER.info('Loss function: {}.'.format(repr(self.loss_class))) # instanciate the loss function @@ -251,7 +528,38 @@ class ModelConfig(BaseConfig): return loss_function def init_model(self, ds, state_file): - + """Instantiate the model and the optimizer. + + If the model checkpoint ``state_file`` exists, the pretrained model and + optimizer states are loaded, otherwise the model and the optimizer are + initialized from scratch. + + Parameters + ---------- + ds : `pysegcnn.core.dataset.ImageDataset` + An instance of `pysegcnn.core.dataset.ImageDataset`. + state_file : `pathlib.Path` + Path to a model checkpoint. + + Returns + ------- + model : `pysegcnn.core.models.Network` + An instance of `pysegcnn.core.models.Network`. + optimizer : `torch.optim.Optimizer` + An instance of `torch.optim.Optimizer`. + checkpoint_state : `dict` + If the model checkpoint ``state_file`` exists, ``checkpoint_state`` + has keys: + ``'ta'`` + The accuracy on the training set (`numpy.ndarray`).
+ ``'tl'`` + The loss on the training set (`numpy.ndarray`). + ``'va'`` + The accuracy on the validation set (`numpy.ndarray`). + ``'vl'`` + The loss on the validation set (`numpy.ndarray`). + + """ # write an initialization string to the log file LogConfig.init_log('{}: Initializing model run. ') @@ -290,7 +598,39 @@ class ModelConfig(BaseConfig): @staticmethod def load_checkpoint(model, optimizer, state_file): - + """Load an existing model checkpoint. + + If the model checkpoint ``state_file`` exists, the pretrained model and + optimizer states are loaded. + + Parameters + ---------- + model : `pysegcnn.core.models.Network` + An instance of `pysegcnn.core.models.Network`. + optimizer : `torch.optim.Optimizer` + An instance of `torch.optim.Optimizer`. + state_file : `pathlib.Path` + Path to the model checkpoint. + + Returns + ------- + model : `pysegcnn.core.models.Network` + An instance of `pysegcnn.core.models.Network`. + optimizer : `torch.optim.Optimizer` + An instance of `torch.optim.Optimizer`. + checkpoint_state : `dict` + If the model checkpoint ``state_file`` exists, ``checkpoint_state`` + has keys: + ``'ta'`` + The accuracy on the training set (`numpy.ndarray`). + ``'tl'`` + The loss on the training set (`numpy.ndarray`). + ``'va'`` + The accuracy on the validation set (`numpy.ndarray`). + ``'vl'`` + The loss on the validation set (`numpy.ndarray`). + + """ # whether to resume training from an existing model checkpoint checkpoint_state = {} @@ -315,7 +655,36 @@ class ModelConfig(BaseConfig): @staticmethod def transfer_model(state_file, ds): - + """Adjust a pretrained model to a new dataset. + + The classification layer of the pretrained model in ``state_file`` is + initialized from scratch with the classes of the new dataset ``ds``. + + The remaining model weights are preserved. + + Parameters + ---------- + state_file : `pathlib.Path` + Path to a pretrained model.
+ ds : `pysegcnn.core.dataset.ImageDataset` + An instance of `pysegcnn.core.dataset.ImageDataset`. + + Raises + ------ + TypeError + Raised if ``ds`` is not an instance of + `pysegcnn.core.dataset.ImageDataset`. + ValueError + Raised if the bands of ``ds`` do not match the bands of the dataset + the pretrained model was trained with. + + Returns + ------- + model : `pysegcnn.core.models.Network` + An instance of `pysegcnn.core.models.Network`. The pretrained model + adjusted to the new dataset. + + """ # check input type if not isinstance(ds, ImageDataset): raise TypeError('Expected "ds" to be {}.' @@ -346,8 +715,7 @@ class ModelConfig(BaseConfig): .format(', '.join('({}, {})'.format(k, v['label']) for k, v in ds.labels.items()))) - # adjust the classification layer to the number of classes of the - # current dataset + # adjust the classification layer to the classes of the new dataset model.classifier = Conv2dSame(in_channels=filters[0], out_channels=model.nclasses, kernel_size=1) @@ -757,7 +1125,7 @@ class EarlyStopping(object): # whether to check for an increase or a decrease in a given metric self.is_better = self.decreased if mode == 'min' else self.increased - # minimum change in metric to be classified as an improvement + # minimum change in metric to be considered as an improvement self.min_delta = min_delta # number of epochs to wait for improvement -- GitLab