Skip to content
Snippets Groups Projects
Commit 451e1467 authored by Frisinghelli Daniel
Browse files

Added a function to plot the confusion matrix

parent c5887fd3
No related branches found
No related tags found
No related merge requests found
......@@ -14,6 +14,7 @@ your custom dataset.
# builtins
import os
import glob
import itertools
# externals
import gdal
......@@ -22,6 +23,7 @@ import torch
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib import cm as colormap
from torch.utils.data import Dataset
......@@ -76,9 +78,10 @@ class ImageDataset(Dataset):
# get samples: (tiles x channels x height x width)
data, gt = self.build_samples(scene)
# convert to torch tensors
x = torch.tensor(data, dtype=torch.float32)
y = torch.tensor(gt, dtype=torch.uint8)
# preprocess input and return torch tensors of shape:
# x : (bands, height, width)
# y : (height, width)
x, y = self.preprocess(data, gt)
return x, y
......@@ -132,6 +135,16 @@ class ImageDataset(Dataset):
raise NotImplementedError('Inherit the ImageDataset class and '
'implement the method.')
# the preprocess() method has to be implemented by the class inheriting
# the ImageDataset class
# preprocess() should return two torch.tensors:
# - input data: tensor of shape (bands, height, width)
# - ground truth: tensor of shape (height, width)
def preprocess(self, data, gt):
    """Abstract preprocessing hook; always raises ``NotImplementedError``.

    Per the comments accompanying this method, a subclass override is
    expected to return two torch tensors: the input data with shape
    (bands, height, width) and the ground truth with shape
    (height, width).

    :param data: input data of a sample (unused by this stub).
    :param gt: ground truth of a sample (unused by this stub).
    :raises NotImplementedError: always.
    """
    message = 'Inherit the ImageDataset class and implement the method.'
    raise NotImplementedError(message)
# _read_scene() reads all the bands and the ground truth mask in a
# scene/tile to a numpy array and returns a dictionary with
# (key, value) = ('band_name', np.ndarray(band_data))
......@@ -350,6 +363,60 @@ class ImageDataset(Dataset):
return fig, ax
# plot_confusion_matrix() plots the confusion matrix of the validation/test
# set returned by the pytorch.predict function
def plot_confusion_matrix(cm, labels, normalize=True,
                          figsize=(10, 10), cmap='Blues'):
    """Plot a confusion matrix as an annotated heatmap.

    Rows are labelled 'True' and columns 'Predicted'. With ``normalize``
    enabled each row is divided by its sum, so every cell shows the
    fraction of the samples of the row's class.

    :param cm: square confusion matrix; anything ``np.asarray`` accepts.
        Assumed to be of shape (len(labels), len(labels)) -- TODO confirm
        against the caller producing it.
    :param labels: sequence of class names used as tick labels; its length
        determines the number of annotated rows/columns.
    :param normalize: if True, normalize each row by its sum and print the
        values with two decimals; otherwise print the raw values.
    :param figsize: figure size passed to ``plt.subplots``.
    :param cmap: name of a matplotlib colormap (or a Colormap instance).
    :return: tuple ``(fig, ax)`` of the created figure and the axes holding
        the heatmap.
    """
    # work on a numpy array so array-like inputs are accepted as well
    cm = np.asarray(cm)

    # number of classes
    nclasses = len(labels)

    # check whether to normalize the confusion matrix
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # rows without any sample would divide by zero and fill the plot
        # (and the colour threshold below) with NaNs: keep them at 0 instead
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        # string format for fractions
        fmt = '.2f'
    else:
        # the integer format 'd' raises a ValueError for floating point
        # values, hence only use it for integer matrices
        fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'

    # create figure
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # get colormap resampled to 256 colours; plt.get_cmap is used since
    # matplotlib.cm.get_cmap was deprecated and later removed
    cmap = plt.get_cmap(cmap, 256)

    # plot confusion matrix
    im = ax.imshow(cm, cmap=cmap)

    # threshold determining the color of the values
    thresh = (cm.max() + cm.min()) / 2

    # colours at the lower/upper end of the colormap, addressed by the
    # normalized positions 0.0 and 1.0 rather than an out-of-range index
    cmap_min, cmap_max = im.cmap(0.0), im.cmap(1.0)

    # plot values of confusion matrix: cells below the threshold get the
    # colour of the upper end of the colormap, the others the lower end,
    # so the text contrasts with the cell background
    for i, j in itertools.product(range(nclasses), range(nclasses)):
        ax.text(j, i, format(cm[i, j], fmt), ha='center', va='center',
                color=cmap_max if cm[i, j] < thresh else cmap_min)

    # axes properties and labels
    ax.set(xticks=np.arange(nclasses),
           yticks=np.arange(nclasses),
           xticklabels=labels,
           yticklabels=labels,
           ylabel='True',
           xlabel='Predicted')

    # add colorbar axes to the right of the heatmap, with the same height
    pos = ax.get_position()
    cax = fig.add_axes([pos.x1 + 0.025, pos.y0, 0.05, pos.y1 - pos.y0])
    fig.colorbar(im, cax=cax)

    return fig, ax
# SparcsDataset class: inherits from the generic ImageDataset class
class SparcsDataset(ImageDataset):
......@@ -394,6 +461,16 @@ class SparcsDataset(ImageDataset):
lc[i] = {'label': l, 'color': c}
return lc
# preprocessing of the Sparcs dataset
def preprocess(self, data, gt):
    """Convert a Sparcs sample to torch tensors without further processing.

    :param data: input data of a sample, converted to ``torch.float32``.
    :param gt: ground truth of a sample, converted to ``torch.uint8``.
    :return: tuple ``(x, y)`` of the converted input and ground truth.
    """
    # if the preprocessing is not done externally, implement it here

    # plain dtype conversion, no scaling applied
    inputs = torch.tensor(data, dtype=torch.float32)
    target = torch.tensor(gt, dtype=torch.uint8)
    return inputs, target
# _compose_scenes() creates a list of dictionaries containing the paths
# to the files of each scene
# if the scenes are divided into tiles, each tile has its own entry
......@@ -447,7 +524,7 @@ class SparcsDataset(ImageDataset):
return scene_data
class Cloud95(ImageDataset):
class Cloud95Dataset(ImageDataset):
def __init__(self, root_dir, use_bands=[], tile_size=None):
# initialize super class ImageDataset
......@@ -467,9 +544,21 @@ class Cloud95(ImageDataset):
# class labels of the Cloud-95 dataset
def get_labels(self):
    """Return the class labels of the Cloud-95 dataset.

    :return: dictionary mapping each class identifier to a dictionary
        with the keys ``'label'`` (class name) and ``'color'`` (colour
        name used for that class).
    """
    return {0: {'label': 'Clear', 'color': 'skyblue'},
            1: {'label': 'Cloud', 'color': 'white'}}
# preprocess Cloud-95 dataset
def preprocess(self, data, gt):
    """Scale a Cloud-95 sample and convert it to torch tensors.

    The input is divided by 65535 and the ground truth by 255, following
    the normalization of the authors of Cloud-95, i.e. Mohajerani and
    Saeedi (2019, 2020).

    :param data: input data of a sample.
    :param gt: ground truth mask of a sample.
    :return: tuple ``(x, y)``: the scaled input as ``torch.float32`` and
        the scaled ground truth as ``torch.uint8``.
    """
    # normalize the data
    scaled_data = data / 65535
    scaled_gt = gt / 255

    # convert to torch tensors
    return (torch.tensor(scaled_data, dtype=torch.float32),
            torch.tensor(scaled_gt, dtype=torch.uint8))
def compose_scenes(self, root_dir):
# get the names of the directories containing the TIFF files of
......@@ -478,7 +567,7 @@ class Cloud95(ImageDataset):
for dirpath, dirname, files in os.walk(root_dir):
# check if the current directory path includes the name of a band
# or the name of the ground truth mask
cband = [band for band in [*self.bands.values()] + ['gt'] if
cband = [band for band in self.use_bands + ['gt'] if
dirpath.endswith(band) and os.path.isdir(dirpath)]
# add path to current band files to dictionary
......@@ -527,23 +616,30 @@ if __name__ == '__main__':
cloud_path = os.path.join(wd, '_Datasets/Cloud95/Training')
# instantiate the Cloud-95 dataset
cloud_dataset = Cloud95(cloud_path)
cloud_dataset = Cloud95Dataset(cloud_path, tile_size=192)
# instantiate the SparcsDataset class
sparcs_dataset = SparcsDataset(sparcs_path, tile_size=None,
use_bands=['nir', 'red', 'green'])
# randomly sample an integer from [0, nsamples]
sample = np.random.randint(len(sparcs_dataset), size=1).item()
# a sample from the sparcs dataset
sample_x, sample_y = sparcs_dataset[sample]
sample_s = np.random.randint(len(sparcs_dataset), size=1).item()
s_x, s_y = sparcs_dataset[sample_s]
fig, ax = sparcs_dataset.plot_sample(s_x, s_y,
bands=['nir', 'red', 'green'])
# print shape of the sample
print('A sample from the Sparcs dataset:')
print('Shape of input data: {}'.format(sample_x.shape))
print('Shape of ground truth: {}'.format(sample_y.shape))
# a sample from the cloud dataset
sample_c = np.random.randint(len(cloud_dataset), size=1).item()
c_x, c_y = cloud_dataset[sample_c]
fig, ax = cloud_dataset.plot_sample(c_x, c_y,
bands=['nir', 'red', 'green'])
# plot the sample
fig, ax = sparcs_dataset.plot_sample(sample_x, sample_y,
bands=['nir', 'red', 'green'])
# print shape of the sample
for i, l, d in zip([s_x, c_x], [s_y, c_y],
[sparcs_dataset, cloud_dataset]):
print('A sample from the {}:'.format(d.__class__.__name__))
print('Shape of input data: {}'.format(i.shape))
print('Shape of ground truth: {}'.format(l.shape))
# show figures
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment