Skip to content
Snippets Groups Projects
Commit e221bb2f authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Implemented cleaner solution to handle spectral bands and class labels from constants.py

parent 6e5be000
No related branches found
No related tags found
No related merge requests found
...@@ -36,7 +36,8 @@ from torch.utils.data import Dataset ...@@ -36,7 +36,8 @@ from torch.utils.data import Dataset
# locals # locals
from pysegcnn.core.constants import (Landsat8, Sentinel2, Label, SparcsLabels, from pysegcnn.core.constants import (Landsat8, Sentinel2, Label, SparcsLabels,
Cloud95Labels, ProSnowLabels) Cloud95Labels, ProSnowLabels,
MultiSpectralSensor)
from pysegcnn.core.utils import (img2np, is_divisible, tile_topleft_corner, from pysegcnn.core.utils import (img2np, is_divisible, tile_topleft_corner,
parse_landsat_scene, parse_sentinel2_scene) parse_landsat_scene, parse_sentinel2_scene)
...@@ -185,11 +186,12 @@ class ImageDataset(Dataset): ...@@ -185,11 +186,12 @@ class ImageDataset(Dataset):
# the available spectral bands in the dataset # the available spectral bands in the dataset
self.sensor = self.get_sensor() self.sensor = self.get_sensor()
self._assert_get_sensor() self._assert_get_sensor()
self.bands = {band.value: band.name for band in self.sensor} self.bands = self.sensor.band_dict()
# the original class labels # the original class labels
self.label_class = self.get_labels()
self._assert_get_labels() self._assert_get_labels()
self._labels = self._build_labels() self._labels = self.label_class.label_dict()
# check which bands to use # check which bands to use
self.use_bands = (self.use_bands if self.use_bands else self.use_bands = (self.use_bands if self.use_bands else
...@@ -215,7 +217,7 @@ class ImageDataset(Dataset): ...@@ -215,7 +217,7 @@ class ImageDataset(Dataset):
self.transforms = [None] + self.transforms self.transforms = [None] + self.transforms
# when padding, add a new "no data" label to the ground truth # when padding, add a new "no data" label to the ground truth
self.cval = self.get_labels().No_data.id self.cval = self.label_class.No_data.id
if self.pad and sum(self.padding) > 0: if self.pad and sum(self.padding) > 0:
LOGGER.info('Adding label "No data" with value={} to ground truth.' LOGGER.info('Adding label "No data" with value={} to ground truth.'
.format(self.cval)) .format(self.cval))
...@@ -225,7 +227,7 @@ class ImageDataset(Dataset): ...@@ -225,7 +227,7 @@ class ImageDataset(Dataset):
# remove labels to merge from dataset instance labels # remove labels to merge from dataset instance labels
for k, v in self.merge_labels.items(): for k, v in self.merge_labels.items():
self._labels.pop(getattr(self.get_labels(), k).id) self._labels.pop(getattr(self.label_class, k).id)
LOGGER.info('Merging label: {} -> {}.'.format(k, v)) LOGGER.info('Merging label: {} -> {}.'.format(k, v))
# create model class labels # create model class labels
...@@ -236,24 +238,6 @@ class ImageDataset(Dataset): ...@@ -236,24 +238,6 @@ class ImageDataset(Dataset):
# list of ground truth images # list of ground truth images
self.gt = [] self.gt = []
def _build_labels(self):
"""Build the label dictionary.
Returns
-------
labels : `dict` [`int`, `dict`]
The label dictionary. The keys are the values of the class labels
in the ground truth. Each nested `dict` should have keys:
``'color'``
A named color (`str`).
``'label'``
The name of the class label (`str`).
"""
return {band.id: {'label': band.name.replace('_', ' '),
'color': band.color}
for band in self.get_labels()}
def _assert_compose_scenes(self): def _assert_compose_scenes(self):
"""Check whether compose_scenes() is correctly implemented.""" """Check whether compose_scenes() is correctly implemented."""
# list of required keys additional to the spectral bands # list of required keys additional to the spectral bands
...@@ -299,9 +283,10 @@ class ImageDataset(Dataset): ...@@ -299,9 +283,10 @@ class ImageDataset(Dataset):
def _assert_get_sensor(self): def _assert_get_sensor(self):
"""Check whether get_sensor() is correctly implemented.""" """Check whether get_sensor() is correctly implemented."""
if not isinstance(self.sensor, enum.EnumMeta): if not issubclass(self.sensor, MultiSpectralSensor):
raise TypeError('{}.get_sensor() should return an instance of ' raise TypeError('{}.get_sensor() should return an instance of '
'enum.Enum, containing an enumeration of the ' 'pysegcnn.core.constants.MultiSpectralSensor '
'containing an enumeration of the '
'spectral bands of the sensor the dataset is ' 'spectral bands of the sensor the dataset is '
'derived from. Examples can be found in ' 'derived from. Examples can be found in '
'pysegcnn.core.constants.py.' 'pysegcnn.core.constants.py.'
...@@ -309,7 +294,7 @@ class ImageDataset(Dataset): ...@@ -309,7 +294,7 @@ class ImageDataset(Dataset):
def _assert_get_labels(self): def _assert_get_labels(self):
"""Check whether get_labels() is correctly implemented.""" """Check whether get_labels() is correctly implemented."""
if not issubclass(self.get_labels(), Label): if not issubclass(self.label_class, Label):
raise TypeError('{}.get_labels() should return an instance of ' raise TypeError('{}.get_labels() should return an instance of '
'pysegcnn.core.constants.Label, ' 'pysegcnn.core.constants.Label, '
'containing an enumeration of the ' 'containing an enumeration of the '
...@@ -439,7 +424,7 @@ class ImageDataset(Dataset): ...@@ -439,7 +424,7 @@ class ImageDataset(Dataset):
Returns Returns
------- -------
sensor : :py:class:`enum.EnumMeta` sensor : :py:class:`pysegcnn.core.constants.MultiSpectralSensor`
An enumeration of the bands of the sensor. An enumeration of the bands of the sensor.
""" """
...@@ -460,7 +445,7 @@ class ImageDataset(Dataset): ...@@ -460,7 +445,7 @@ class ImageDataset(Dataset):
Returns Returns
------- -------
labels : :py:class:`enum.EnumMeta` labels : :py:class:`pysegcnn.core.constants.Label`
The class labels. The class labels.
""" """
...@@ -916,7 +901,7 @@ class SparcsDataset(StandardEoDataset): ...@@ -916,7 +901,7 @@ class SparcsDataset(StandardEoDataset):
Returns Returns
------- -------
sensor : :py:class:`enum.EnumMeta` sensor : :py:class:`pysegcnn.core.constants.MultiSpectralSensor`
An enumeration of the bands of the sensor. An enumeration of the bands of the sensor.
""" """
...@@ -928,7 +913,7 @@ class SparcsDataset(StandardEoDataset): ...@@ -928,7 +913,7 @@ class SparcsDataset(StandardEoDataset):
Returns Returns
------- -------
labels : :py:class:`enum.EnumMeta` labels : :py:class:`pysegcnn.core.constants.Label`
The class labels. The class labels.
""" """
...@@ -969,7 +954,7 @@ class ProSnowDataset(StandardEoDataset): ...@@ -969,7 +954,7 @@ class ProSnowDataset(StandardEoDataset):
Returns Returns
------- -------
sensor : :py:class:`enum.EnumMeta` sensor : :py:class:`pysegcnn.core.constants.MultiSpectralSensor`
An enumeration of the bands of the sensor. An enumeration of the bands of the sensor.
""" """
...@@ -981,7 +966,7 @@ class ProSnowDataset(StandardEoDataset): ...@@ -981,7 +966,7 @@ class ProSnowDataset(StandardEoDataset):
Returns Returns
------- -------
labels : :py:class:`enum.EnumMeta` labels : :py:class:`pysegcnn.core.constants.Label`
The class labels. The class labels.
""" """
...@@ -1092,7 +1077,7 @@ class Cloud95Dataset(ImageDataset): ...@@ -1092,7 +1077,7 @@ class Cloud95Dataset(ImageDataset):
Returns Returns
------- -------
sensor : :py:class:`enum.EnumMeta` sensor : :py:class:`pysegcnn.core.constants.MultiSpectralSensor`
An enumeration of the bands of the sensor. An enumeration of the bands of the sensor.
""" """
...@@ -1104,7 +1089,7 @@ class Cloud95Dataset(ImageDataset): ...@@ -1104,7 +1089,7 @@ class Cloud95Dataset(ImageDataset):
Returns Returns
------- -------
labels : :py:class:`enum.EnumMeta` labels : :py:class:`pysegcnn.core.constants.Label`
The class labels. The class labels.
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment