Commit a1669e70 authored by Frisinghelli Daniel

Extended generic ImageDataset class; added Cloud95 support

parent 9a1d8772
@@ -28,28 +28,136 @@ from torch.utils.data import Dataset
# generic image dataset class
class ImageDataset(Dataset):
def __init__(self, root_dir):
def __init__(self, root_dir, use_bands, tile_size):
super().__init__()
# the root directory: path to the image dataset
self.root = root_dir
# this function should return the length of the image dataset
# __len__() is used by PyTorch to determine the total number of samples in
# the dataset and has to be implemented by a class inheriting from the
# ImageDataset class
# the size of a scene/patch in the dataset
self.size = self.get_size()
# the available spectral bands in the dataset
self.bands = self.get_bands()
# the class labels
self.labels = self.get_labels()
# check which bands to use
self.use_bands = (use_bands if use_bands else [*self.bands.values()])
# each scene is divided into (tile_size x tile_size) blocks
# each of these blocks is treated as a single sample
self.tile_size = tile_size
# calculate the number of resulting tiles and check whether the images are
# evenly divisible into square tiles of size (tile_size x tile_size)
# (a sketch of is_divisible() follows the constructor)
if self.tile_size is None:
self.tiles = None
else:
self.tiles = self.is_divisible(self.size, self.tile_size)
# the samples of the dataset
self.scenes = []
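# is_divisible() is not shown in this excerpt; the standalone function below
# is only a sketch of what it is assumed to compute, namely the number of
# square tiles an image of a given size decomposes into (the name
# is_divisible_sketch and the error message are hypothetical):
def is_divisible_sketch(img_size, tile_size):
    # raise an error if the image is not evenly divisible into square tiles
    if img_size[0] % tile_size or img_size[1] % tile_size:
        raise ValueError('Image of size {} is not evenly divisible into '
                         'square tiles of size {}.'
                         .format(img_size, tile_size))
    # number of resulting (tile_size x tile_size) tiles
    return (img_size[0] // tile_size) * (img_size[1] // tile_size)
# e.g. is_divisible_sketch((1000, 1000), 125) returns 64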
# the __len__() method returns the number of samples in the dataset
def __len__(self):
raise NotImplementedError('Inherit the ImageDataset class and '
'implement the method.')
# number of (channels x height x width) patches after each scene is
# decomposed into square tiles
return len(self.scenes)
# this function should return a single sample of the dataset given an
# the __getitem__() method returns a single sample of the dataset given an
# index, i.e. an array/tensor of shape (channels x height x width)
# it has to be implemented by a class inheriting from the
# ImageDataset class
def __getitem__(self, idx):
# select a scene
scene = self.read_scene(idx)
# get samples: (tiles x channels x height x width)
data, gt = self.build_samples(scene)
# convert to torch tensors
x = torch.tensor(data, dtype=torch.float32)
y = torch.tensor(gt, dtype=torch.uint8)
return x, y
# the compose_scenes() method has to be implemented by the class inheriting
# from the ImageDataset class
# compose_scenes() should return a list of dictionaries, where each
# dictionary represents one sample of the dataset, i.e. a scene or a tile
# of a scene (an illustrative example follows this method)
# the dictionaries should have the following (key, value) pairs:
# - (band_1, path_to_band_1.tif)
# - (band_2, path_to_band_2.tif)
# - ...
# - (band_n, path_to_band_n.tif)
# - (gt, path_to_ground_truth.tif)
# - (tile, None or int)
def compose_scenes(self, *args, **kwargs):
raise NotImplementedError('Inherit the ImageDataset class and '
'implement the method.')
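# for illustration only: one entry of the list returned by a concrete
# compose_scenes() implementation could look like the dictionary below
# (the file paths and the selected bands are hypothetical):
example_scene = {
    'red': '/path/to/scene/scene_B4.tif',
    'green': '/path/to/scene/scene_B3.tif',
    'blue': '/path/to/scene/scene_B2.tif',
    'gt': '/path/to/scene/scene_mask.tif',
    # tile id within the scene, or None if the scene is not tiled
    'tile': 0
}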
# the get_size() method has to be implemented by the class inheriting
# from the ImageDataset class
# get_size() should return the image size as a tuple (height, width)
def get_size(self, *args, **kwargs):
raise NotImplementedError('Inherit the ImageDataset class and '
'implement the method.')
# the get_bands() method has to be implemented by the class inheriting
# from the ImageDataset class
# get_bands() should return a dictionary with the following
# (key: int, value: str) pairs:
# - (1, band_1_name)
# - (2, band_2_name)
# - ...
# - (n, band_n_name)
def get_bands(self, *args, **kwargs):
raise NotImplementedError('Inherit the ImageDataset class and '
'implement the method.')
# the get_labels() method has to be implemented by the class inheriting
# from the ImageDataset class
# get_labels() should return a dictionary with the following
# (key: int, value: dict) pairs:
# - (0, {'label': label_1_name, 'color': label_1_color})
# - (1, {'label': label_2_name, 'color': label_2_color})
# - ...
# - (n, {'label': label_n_name, 'color': label_n_color})
# where the keys are the pixel values representing the corresponding label
# in the ground truth mask
# the labels in the dictionary determine the classes to be segmented
# (a minimal toy subclass is sketched after this method)
def get_labels(self, *args, **kwargs):
raise NotImplementedError('Inherit the ImageDataset class and '
'implement the method.')
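# a minimal toy subclass (for illustration only, not part of this file; the
# class name, image size, bands and labels are hypothetical) showing how the
# abstract methods above fit together:
class ToyDataset(ImageDataset):

    def __init__(self, root_dir, use_bands=None, tile_size=None):
        super().__init__(root_dir, use_bands, tile_size)
        # populate the list of samples
        self.scenes = self.compose_scenes(self.root)

    def get_size(self):
        return (256, 256)

    def get_bands(self):
        return {1: 'red', 2: 'green', 3: 'blue'}

    def get_labels(self):
        return {0: {'label': 'background', 'color': 'black'},
                1: {'label': 'foreground', 'color': 'white'}}

    def compose_scenes(self, root_dir):
        # a real implementation would list the band files in root_dir here
        return []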
# read_scene() reads all the bands and the ground truth mask in a
# scene/tile to a numpy array and returns a dictionary with
# (key, value) = ('band_name', np.ndarray(band_data))
def read_scene(self, idx):
# select a scene from the root directory
scene = self.scenes[idx]
# read each band of the scene into a numpy array
scene_data = {key: (self.img2np(value, tile_size=self.tile_size,
tile=scene['tile'])
if key != 'tile' else value)
for key, value in scene.items()}
return scene_data
# build_samples() stacks all bands of a scene/tile into a
# numpy array of shape (bands x height x width); a small example follows
# the method
def build_samples(self, scene):
# iterate over the channels to stack
stack = np.stack([scene[band] for band in self.use_bands], axis=0)
gt = scene['gt']
return stack, gt
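# a small, self-contained example of what build_samples() produces, using
# synthetic arrays in place of the data read_scene() would return (the band
# names and the patch size of 125 are assumptions for illustration):
import numpy as np

scene = {band: np.zeros((125, 125), dtype=np.float32)
         for band in ['nir', 'red', 'green']}
scene['gt'] = np.zeros((125, 125), dtype=np.uint8)

stack = np.stack([scene[band] for band in ['nir', 'red', 'green']], axis=0)
print(stack.shape)        # (3, 125, 125), i.e. (bands x height x width)
print(scene['gt'].shape)  # (125, 125)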
# the following functions are utility functions for common image
# manipulation operations
@@ -194,117 +302,103 @@ class ImageDataset(Dataset):
return norm
# plot_sample() plots a false color composite of the scene/tile together
# with the model prediction and the corresponding ground truth
def plot_sample(self, x, y, y_pred=None, figsize=(10, 10),
bands=['red', 'green', 'blue'], stretch=False, **kwargs):
# SparcsDataset class: inherits from the generic ImageDataset class
class SparcsDataset(ImageDataset):
# check whether to apply contrast stretching
if kwargs: stretch = True
func = self.contrast_stretching if stretch else lambda x: x
def __init__(self, root_dir, bands=['red', 'green', 'blue'],
tile_size=None):
super().__init__(root_dir)
# Landsat 8 bands in the SPARCS dataset
self.sparcs_bands = {1: 'violet',
2: 'blue',
3: 'green',
4: 'red',
5: 'nir',
6: 'swir1',
7: 'swir2',
8: 'pan',
9: 'cirrus',
10: 'tir'}
# class labels and corresponding color map
self.labels = {0: 'Shadow',
1: 'Shadow over Water',
2: 'Water',
3: 'Snow',
4: 'Land',
5: 'Cloud',
6: 'Flooded'}
self.colors = {0: 'black',
1: 'darkblue',
2: 'blue',
3: 'lightblue',
4: 'grey',
5: 'white',
6: 'yellow'}
# image size of the SPARCS dataset: height x width
self.size = (1000, 1000)
# create an rgb stack
rgb = np.dstack([func(x[self.use_bands.index(band)],
**kwargs) for band in bands])
# check which bands to use
if bands == -1:
# in case bands=-1, use all bands of the sparcs dataset
self.bands = [*self.sparcs_bands.values()]
else:
self.bands = bands
# get labels and corresponding colors
labels = [label['label'] for label in self.labels.values()]
colors = [label['color'] for label in self.labels.values()]
# each scene is divided into (tile_size x tile_size) blocks
# each of these blocks is treated as a single sample
self.tile_size = tile_size
# create a ListedColormap
cmap = ListedColormap(colors)
boundaries = [*self.labels.keys(), cmap.N]
norm = BoundaryNorm(boundaries, cmap.N)
# calculate the number of resulting tiles and check whether the images are
# evenly divisible into square tiles of size (tile_size x tile_size)
if self.tile_size is None:
self.tiles = None
# create figure: check whether to plot model prediction
if y_pred is not None:
fig, ax = plt.subplots(1, 3, figsize=figsize)
ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm)
ax[2].set_title('Prediction', pad=20)
else:
self.tiles = self.is_divisible(self.size, self.tile_size)
# list of all scenes in the root directory
# each scene is divided into square tiles
self.scenes = []
for scene in os.listdir(root_dir):
self.scenes += self._compose_scenes(os.path.join(root_dir, scene))
# the __len__() method returns the number of samples in the Sparcs dataset
def __len__(self):
# number of (channels x height x width) patches after each scene is
# decomposed into square tiles
return len(self.scenes)
fig, ax = plt.subplots(1, 2, figsize=figsize)
# the __getitem__() method returns a sample of the Sparcs dataset
# __getitem__() is implicitly used by pytorch to draw samples during
# the training process
def __getitem__(self, idx):
# plot false color composite
ax[0].imshow(rgb)
ax[0].set_title('R = {}, G = {}, B = {}'.format(*bands), pad=20)
# select a scene
scene = self._read_scene(idx)
# plot ground truth mask
ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm)
ax[1].set_title('Ground truth', pad=20)
# get samples: (tiles x channels x height x width)
data, gt = self._build_samples(scene)
# create a patch (proxy artist) for every color
patches = [mpatches.Patch(color=c, label=l) for c, l in
zip(colors, labels)]
# convert to torch tensors
x = torch.tensor(data, dtype=torch.float32)
y = torch.tensor(gt, dtype=torch.uint8)
# plot patches as legend
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
frameon=False)
return x, y
return fig, ax
# returns the band number of a preprocessed Sparcs TIFF file
def _get_band_number(self, x):
return int(os.path.basename(x).split('_')[2].replace('B', ''))
# _store_bands() writes the paths to the data of each scene to a dictionary
# only the bands of interest are stored
def _store_bands(self, bands, gt):
# store the bands of interest in a dictionary
scene_data = {}
for i, b in enumerate(bands):
band = self.sparcs_bands[self._get_band_number(b)]
if band in self.bands:
scene_data[band] = b
# SparcsDataset class: inherits from the generic ImageDataset class
class SparcsDataset(ImageDataset):
# store ground truth
scene_data['gt'] = gt
def __init__(self, root_dir, use_bands=['red', 'green', 'blue'],
tile_size=None):
# initialize super class ImageDataset
super().__init__(root_dir, use_bands, tile_size)
return scene_data
# list of all scenes in the root directory
# each scene is divided into square tiles
self.scenes = []
for scene in os.listdir(self.root):
self.scenes += self.compose_scenes(os.path.join(self.root, scene))
# image size of the Sparcs dataset: (height, width)
def get_size(self):
return (1000, 1000)
# Landsat 8 bands of the Sparcs dataset
def get_bands(self):
return {
1: 'violet',
2: 'blue',
3: 'green',
4: 'red',
5: 'nir',
6: 'swir1',
7: 'swir2',
8: 'pan',
9: 'cirrus',
10: 'tir'}
# class labels of the Sparcs dataset
def get_labels(self):
labels = ['Shadow', 'Shadow over Water', 'Water', 'Snow', 'Land',
'Cloud', 'Flooded']
colors = ['black', 'darkblue', 'blue', 'lightblue', 'grey', 'white',
'yellow']
lc = {}
for i, (l, c) in enumerate(zip(labels, colors)):
lc[i] = {'label': l, 'color': c}
return lc
# compose_scenes() creates a list of dictionaries containing the paths
# to the files of each scene
# if the scenes are divided into tiles, each tile has its own entry
# with the corresponding tile id
def _compose_scenes(self, scene):
def compose_scenes(self, scene):
# list the spectral bands of the scene
bands = glob.glob(os.path.join(scene, '*B*.tif'))
@@ -346,84 +440,96 @@ class SparcsDataset(ImageDataset):
return scene_data
# _read_scene() reads all the bands and the ground truth mask in a
# scene/tile to a numpy array and returns a dictionary with
# (key, value) = ('band_name', np.ndarray(band_data))
def _read_scene(self, idx):
# returns the band number of a preprocessed Sparcs TIFF file
def _get_band_number(self, x):
return int(os.path.basename(x).split('_')[2].replace('B', ''))
# select a scene from the root directory
scene = self.scenes[idx]
# _store_bands() writes the paths to the data of each scene to a dictionary
# only the bands of interest are stored
def _store_bands(self, bands, gt):
# read each band of the scene into a numpy array
scene_data = {key: (self.img2np(value, tile_size=self.tile_size,
tile=scene['tile'])
if key != 'tile' else value)
for key, value in scene.items()}
# store the bands of interest in a dictionary
scene_data = {}
for i, b in enumerate(bands):
band = self.bands[self._get_band_number(b)]
if band in self.use_bands:
scene_data[band] = b
# store ground truth
scene_data['gt'] = gt
return scene_data
# _build_samples() stacks all bands of a scene/tile into a
# numpy array of shape (bands x height x width)
def _build_samples(self, scene):
# iterate over the channels to stack
stack = np.stack([scene[band] for band in self.bands], axis=0)
gt = scene['gt']
class Cloud95(ImageDataset):
return stack, gt
def __init__(self, root_dir, use_bands=[], tile_size=None):
# initialize super class ImageDataset
super().__init__(root_dir, use_bands, tile_size)
# plot_sample() plots a false color composite of the scene/tile together
# with the model prediction and the corresponding ground truth
def plot_sample(self, x, y, y_pred=None, figsize=(10, 10),
bands=['nir', 'red', 'green'], stretch=False, **kwargs):
# list of all scenes in the root directory
# each scene is divided into square tiles
self.scenes = self.compose_scenes(self.root)
# check whether to apply contrast stretching
func = self.contrast_stretching if stretch else lambda x: x
# image size of the Cloud-95 dataset: (height, width)
def get_size(self):
return (384, 384)
# create an rgb stack
rgb = np.dstack([func(x[self.bands.index(band)],
**kwargs) for band in bands])
# Landsat 8 bands in the Cloud-95 dataset
def get_bands(self):
return {1: 'red', 2: 'green', 3: 'blue', 4: 'nir'}
# create a ListedColormap
cmap = ListedColormap(self.colors.values())
boundaries = [*self.colors.keys(), cmap.N]
norm = BoundaryNorm(boundaries, cmap.N)
# class labels of the Cloud-95 dataset
def get_labels(self):
return {0: {'label': 'Clear', 'color': 'azure'},
1: {'label': 'Cloud', 'color': 'white'}}
# create figure: check whether to plot model prediction
if y_pred is not None:
fig, ax = plt.subplots(1, 3, figsize=figsize)
ax[2].imshow(y_pred, cmap=cmap, interpolation='nearest', norm=norm)
ax[2].set_title('Prediction', pad=20)
else:
fig, ax = plt.subplots(1, 2, figsize=figsize)
def compose_scenes(self, root_dir):
# plot false color composite
ax[0].imshow(rgb)
ax[0].set_title('R = {}, G = {}, B = {}'.format(*bands), pad=20)
# get the names of the directories containing the TIFF files of
# the bands of interest
band_dirs = {}
for dirpath, dirname, files in os.walk(root_dir):
# check if the current directory path includes the name of a band
# or the name of the ground truth mask
cband = [band for band in self.use_bands + ['gt'] if band in dirpath
and os.path.isdir(dirpath)]
# plot ground truth mask
ax[1].imshow(y, cmap=cmap, interpolation='nearest', norm=norm)
ax[1].set_title('Ground truth', pad=20)
# add path to current band files to dictionary
if cband:
band_dirs[cband.pop()] = dirpath
# create a patch (proxy artist) for every color
patches = [mpatches.Patch(color=c, label=l) for c, l in
zip(self.colors.values(), self.labels.values())]
# create an empty list to store all patches
scenes = []
# plot patches as legend
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
frameon=False)
# iterate over all the files of the first band of interest
biter = self.use_bands[0]
for file in os.listdir(band_dirs[biter]):
return fig, ax
# initialize dictionary to store bands of current patch
scene = {}
# iterate over the bands of interest
for band in band_dirs.keys():
# save path to current band TIFF file to dictionary
scene[band] = os.path.join(band_dirs[band],
file.replace(biter, band))
# append patch to list of all patches
scenes.append(scene)
return scenes
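# the directory walk above assumes a layout in which each band and the
# ground truth mask live in their own folder whose path contains the band
# name, and in which the per-patch file names differ only by the band name;
# an illustrative (hypothetical) layout that would satisfy compose_scenes():
#
#   root_dir/
#       train_red/    red_patch_1.TIF,   red_patch_2.TIF,   ...
#       train_green/  green_patch_1.TIF, green_patch_2.TIF, ...
#       train_blue/   blue_patch_1.TIF,  blue_patch_2.TIF,  ...
#       train_nir/    nir_patch_1.TIF,   nir_patch_2.TIF,   ...
#       train_gt/     gt_patch_1.TIF,    gt_patch_2.TIF,    ...
#
# so that file.replace(biter, band) maps a file of the first band to the
# corresponding file of every other band and of the ground truth mask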
if __name__ == '__main__':
# path to the preprocessed Sparcs dataset
sparcs_path = "C:/Eurac/2020/Tutorial/Datasets/Sparcs"
sparcs_path = "C:/Eurac/2020/_Datasets/Sparcs"
# sparcs_path = "/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/Datasets/Sparcs"
# instantiate the SparcsDataset class
sparcs_dataset = SparcsDataset(sparcs_path, tile_size=None, bands=-1)
sparcs_dataset = SparcsDataset(sparcs_path, tile_size=125,
use_bands=['nir', 'red', 'green'])
# randomly sample an integer from [0, nsamples)
sample = np.random.randint(len(sparcs_dataset), size=1).item()
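# a short usage sketch (illustrative): wrapping the dataset in a PyTorch
# DataLoader draws batches of samples for training; the shapes below assume
# that img2np() returns a single (tile_size x tile_size) tile per sample
from torch.utils.data import DataLoader

loader = DataLoader(sparcs_dataset, batch_size=16, shuffle=True)
x_batch, y_batch = next(iter(loader))
# expected: x_batch of shape (16, 3, 125, 125) and dtype torch.float32
#           y_batch of shape (16, 125, 125) and dtype torch.uint8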