Commit 90c93b50 authored by Frisinghelli Daniel

Adding docstrings: part 1

parent c3da3509
"""A collection of enumerations of constant values."""
# !/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 14 10:58:20 2020
@author: Daniel
"""
# builtins
import enum
# Landsat 8 bands
class Landsat8(enum.Enum):
"""The spectral bands of the Landsat 8 sensors.
sensors:
- Operational Land Imager (OLI), (bands 1-9)
- Thermal Infrared Sensor (TIRS), (bands 10, 11)
"""
violet = 1
blue = 2
green = 3
...@@ -25,6 +32,8 @@ class Landsat8(enum.Enum):
# Sentinel 2 bands
class Sentinel2(enum.Enum):
"""The spectral bands of the Sentinel-2 MultiSpectral Instrument (MSI)."""
aerosol = 1
blue = 2
green = 3
...@@ -42,18 +51,28 @@ class Sentinel2(enum.Enum):
# generic class label enumeration class
class Label(enum.Enum):
"""Generic enumeration for class labels."""
@property
def id(self):
"""Return the value of a class in the ground truth."""
return self.value[0]
@property
def color(self):
"""Return the color to plot a class."""
return self.value[1]
# labels of the Sparcs dataset
class SparcsLabels(Label):
"""Class labels of the `Sparcs`_ dataset.
.. _Sparcs:
https://www.usgs.gov/land-resources/nli/landsat/spatial-procedures-automated-removal-cloud-and-shadow-sparcs-validation
"""
Shadow = 0, 'grey'
Shadow_over_water = 1, 'darkblue'
Water = 2, 'blue'
...@@ -65,12 +84,21 @@ class SparcsLabels(Label):
# labels of the Cloud95 dataset
class Cloud95Labels(Label):
"""Class labels of the `Cloud-95`_ dataset.
.. _Cloud-95:
https://github.com/SorourMo/95-Cloud-An-Extension-to-38-Cloud-Dataset
"""
Clear = 0, 'skyblue'
Cloud = 1, 'white'
# labels of the ProSnow dataset
class ProSnowLabels(Label):
"""Class labels of the ProSnow datasets."""
Cloud = 0, 'white'
Snow = 1, 'lightblue'
Snow_free = 2, 'sienna'
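# Usage sketch: the two-element values above are split by the Label
# properties shown earlier (``id`` is the class value in the ground truth,
# ``color`` the plotting color). The import path is assumed to be
# pysegcnn.core.constants, as in this repository.
from pysegcnn.core.constants import SparcsLabels
for label in SparcsLabels:
    print(label.name, label.id, label.color)  # e.g. "Shadow 0 grey"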
"""Functions to plot multispectral image data and model output."""
# !/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 14 11:04:27 2020
@author: Daniel
"""
# builtins
import os
import itertools
...@@ -25,7 +24,21 @@ from pysegcnn.main.config import HERE
# this function applies percentile stretching at the alpha level
# can be used to increase contrast for visualization
def contrast_stretching(image, alpha=5):
"""Apply percentile stretching to an image to increase constrast.
Parameters
----------
image : `numpy.ndarray`
the input image.
alpha : `int`, optional
The level of the percentiles. The default is 5.
Returns
-------
norm : `numpy.ndarray`
the stretched image.
"""
# compute upper and lower percentiles defining the range of the stretch
inf, sup = np.percentile(image, (alpha, 100 - alpha))
...@@ -34,7 +47,7 @@ def contrast_stretching(image, alpha=5):
norm = ((image - inf) * (image.max() - image.min()) /
(sup - inf)) + image.min()
# clip: values < min = min, values > max = max
norm[norm <= image.min()] = image.min()
norm[norm >= image.max()] = image.max()
...@@ -42,6 +55,21 @@
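# Usage sketch for the stretch above: clip a random single-band image to the
# 2nd/98th percentiles (the import path pysegcnn.core.graphics is the module
# this function lives in, per the docstrings further below).
import numpy as np
from pysegcnn.core.graphics import contrast_stretching
band = np.random.default_rng(0).random((64, 64))
stretched = contrast_stretching(band, alpha=2)
print(stretched.min(), stretched.max())  # stays within the original range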
def running_mean(x, w):
"""Compute a running mean of the input sequence.
Parameters
----------
x : array_like
The sequence to compute a running mean on.
w : `int`
The window length of the running mean.
Returns
-------
rm : `numpy.ndarray`
The running mean of the sequence ``x``.
"""
cumsum = np.cumsum(np.insert(x, 0, 0))
return (cumsum[w:] - cumsum[:-w]) / w
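# Quick check of the cumulative-sum trick above: each window mean is the
# difference of two cumulative sums divided by the window length w.
import numpy as np
cumsum = np.cumsum(np.insert([1, 2, 3, 4, 5], 0, 0))
print((cumsum[2:] - cumsum[:-2]) / 2)  # -> [1.5 2.5 3.5 4.5]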
...@@ -51,7 +79,48 @@ def running_mean(x, w):
def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
bands=['nir', 'red', 'green'], state=None,
outpath=os.path.join(HERE, '_samples/'), alpha=0):
"""Plot false color composite (FCC), ground truth and model prediction.
Parameters
----------
x : `numpy.ndarray` or `torch.tensor`, (b, h, w)
Array containing the raw data of the tile, shape=(bands, height, width)
y : `numpy.ndarray` or `torch.tensor`, (h, w)
Array containing the ground truth of tile ``x``, shape=(height, width)
use_bands : `list` of `str`
List describing the order of the bands in ``x``.
labels : `dict` [`int`, `dict`]
The keys are the values of the class labels in the ground truth ``y``.
Each nested `dict` should have keys:
``'color'``
A named color (`str`).
``'label'``
The name of the class label (`str`).
y_pred : `numpy.ndarray` or `None`, optional
Array containing the prediction for tile ``x``, shape=(height, width).
The default is None, i.e. only FCC and ground truth are plotted.
figsize : `tuple`, optional
The figure size in centimeters. The default is (10, 10).
bands : `list` [`str`], optional
The bands to build the FCC. The default is ['nir', 'red', 'green'].
state : `str` or `None`, optional
Filename to save the plot to. ``state`` should be an existing model
state file ending with '.pt'. The default is None, i.e. plot is not
saved to disk.
outpath : `str` or `pathlib.Path`, optional
Output path. The default is os.path.join(HERE, '_samples/').
alpha : `int`, optional
The level of the percentiles to increase contrast in the FCC.
The default is 0, i.e. no stretching.
Returns
-------
fig : `matplotlib.figure.Figure`
The figure handle.
ax : `matplotlib.axes._subplots.AxesSubplot`
The axes handle.
"""
# check whether to apply contrast stretching
rgb = np.dstack([contrast_stretching(x[use_bands.index(band)], alpha)
for band in bands])
...@@ -109,7 +178,40 @@ def plot_sample(x, y, use_bands, labels, y_pred=None, figsize=(10, 10),
def plot_confusion_matrix(cm, labels, normalize=True,
figsize=(10, 10), cmap='Blues', state=None,
outpath=os.path.join(HERE, '_graphics/')):
"""Plot the confusion matrix ``cm``.
Parameters
----------
cm : `numpy.ndarray`
The confusion matrix.
labels : `dict` [`int`, `dict`]
The keys are the values of the class labels in the ground truth ``y``.
Each nested `dict` should have keys:
``'color'``
A named color (`str`).
``'label'``
The name of the class label (`str`).
normalize : `bool`, optional
Whether to normalize the confusion matrix. The default is True.
figsize : `tuple`, optional
The figure size in centimeters. The default is (10, 10).
cmap : `str`, optional
A colormap in `matplotlib.pyplot.colormaps()`. The default is 'Blues'.
state : `str` or `None`, optional
Filename to save the plot to. ``state`` should be an existing model
state file ending with '.pt'. The default is None, i.e. plot is not
saved to disk.
outpath : `str` or `pathlib.Path`, optional
Output path. The default is os.path.join(HERE, '_graphics/').
Returns
-------
fig : `matplotlib.figure.Figure`
The figure handle.
ax : `matplotlib.axes._subplots.AxesSubplot`
The axes handle.
"""
# number of classes
labels = [label['label'] for label in labels.values()]
nclasses = len(labels)
...@@ -177,7 +279,30 @@ def plot_confusion_matrix(cm, labels, normalize=True,
def plot_loss(state_file, figsize=(10, 10), step=5,
colors=['lightgreen', 'green', 'skyblue', 'steelblue'],
outpath=os.path.join(HERE, '_graphics/')):
"""Plot the observed loss and accuracy of a model run.
Parameters
----------
state_file : `str` or `pathlib.Path`
The model state file. Model state files are stored in
pysegcnn/main/_models.
figsize : `tuple`, optional
The figure size in centimeters. The default is (10, 10).
step : `int`, optional
The step of epochs for the x-axis labels. The default is 5, i.e. label
each fifth epoch.
colors : `list` [`str`], optional
A list of four named colors supported by `matplotlib`.
The default is ['lightgreen', 'green', 'skyblue', 'steelblue'].
outpath : `str` or `pathlib.Path`, optional
Output path. The default is os.path.join(HERE, '_graphics/').
Returns
-------
fig : `matplotlib.figure.Figure`
The figure handle.
"""
# load the model state
model_state = torch.load(state_file)
...
"""Layers of a convolutional encoder-decoder network."""
# !/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 26 16:23:36 2020
@author: Daniel
"""
# externals
import torch
...@@ -13,6 +10,34 @@ import torch.nn.functional as F
class Conv2dSame(nn.Conv2d):
"""A convolution preserving the shape of its input.
Given the kernel size, the dilation and a stride of 1, the padding is
calculated such that the output of the convolution has the same spatial
dimensions as the input.
Parameters
----------
*args: `list` [`str`]
positional arguments passed to `torch.nn.Conv2d`:
``'in_channels'``: `int`
Number of input channels
``'kernel_size'``: `int` or `tuple` [`int`]
Size of the convolving kernel
``'out_channels'``: `int`
Number of desired output channels
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to `torch.nn.Conv2d`_.
.. _torch.nn.Conv2d:
https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d
Returns
-------
None.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
...@@ -23,12 +48,45 @@ class Conv2dSame(nn.Conv2d):
self.padding = (y_pad, x_pad)
def same_padding(self, d, k):
"""Calculate the amount of padding.
Parameters
----------
d : `int`
The dilation of the convolution.
k : `int`
The kernel size.
Returns
-------
p : `int`
the amount of padding.
"""
# calculates the padding so that the convolution
# conserves the shape of its input when stride = 1
return int(d * (k - 1) / 2)
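# Sanity check of the padding formula p = d * (k - 1) / 2 with a plain
# torch convolution (stride 1, kernel_size=3, dilation=1 gives p = 1, so the
# spatial dimensions of the input are preserved):
import torch
import torch.nn as nn
conv = nn.Conv2d(3, 8, kernel_size=3, padding=int(1 * (3 - 1) / 2))
print(conv(torch.rand(1, 3, 32, 32)).shape)  # torch.Size([1, 8, 32, 32])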
def conv_bn_relu(in_channels, out_channels, **kwargs):
"""Block of convolution, batch normalization and rectified linear unit.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of desired output channels
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Returns
-------
block : `torch.nn.Sequential`
An instance of `torch.nn.Sequential` containing the different layers.
"""
return nn.Sequential(
Conv2dSame(in_channels, out_channels, **kwargs),
nn.BatchNorm2d(out_channels),
...@@ -40,8 +98,26 @@ def conv_bn_relu(in_channels, out_channels, **kwargs):
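# Illustrative stand-in for the block above built from plain torch layers:
# padding='same' (PyTorch >= 1.9) replaces Conv2dSame, and the elided last
# layer is assumed to be the ReLU the function name suggests.
import torch
import torch.nn as nn
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding='same'),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
)
print(block(torch.rand(1, 3, 32, 32)).shape)  # torch.Size([1, 16, 32, 32])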
class Conv2dPool(nn.Module):
"""Block of convolution, batchnorm, relu and 2x2 max pool.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of desired output channels
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Returns
-------
None.
"""
def __init__(self, in_channels, out_channels, **kwargs):
# initialize nn.Module class
super().__init__()
...@@ -53,7 +129,23 @@ class Conv2dPool(nn.Module):
# defines the forward pass
def forward(self, x):
"""Forward propagation through this block.
Parameters
----------
x : `torch.tensor`
Output of previous layer.
Returns
-------
y : `torch.tensor`
Output of this block.
x : `torch.tensor`
Output before max pooling. Stored for skip connections.
i : `torch.tensor`
Indices of the max pooling operation. Used in unpooling operation.
"""
# output of the convolutional block
x = self.conv(x)
...@@ -64,6 +156,23 @@ class Conv2dPool(nn.Module):
class Conv2dUnpool(nn.Module):
"""Block of convolution, batchnorm, relu and 2x2 max unpool.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of desired output channels
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Returns
-------
None.
"""
def __init__(self, in_channels, out_channels, **kwargs):
super().__init__()
...@@ -76,7 +185,25 @@ class Conv2dUnpool(nn.Module):
# defines the forward pass
def forward(self, x, feature, indices, skip):
"""Forward propagation through this block.
Parameters
----------
x : `torch.tensor`
Output of previous layer.
feature : `torch.tensor`
Encoder feature used for the skip connection.
indices : `torch.tensor`
Indices of the max pooling operation. Used in unpooling operation.
skip : `bool`
Whether to apply the skip connection.
Returns
-------
x : `torch.tensor`
Output of this block.
"""
# upsampling with pooling indices
x = self.upsample(x, indices, output_size=feature.shape)
...@@ -93,6 +220,23 @@ class Conv2dUnpool(nn.Module):
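# Illustration of the pooling/unpooling mechanics Conv2dPool and Conv2dUnpool
# build on: max pooling with return_indices, then unpooling with those
# indices back to the original spatial size (plain torch modules only).
import torch
import torch.nn as nn
pool = nn.MaxPool2d(2, return_indices=True)
unpool = nn.MaxUnpool2d(2)
x = torch.rand(1, 8, 32, 32)
y, i = pool(x)  # y: (1, 8, 16, 16), i: pooling indices
print(unpool(y, i, output_size=x.shape).shape)  # torch.Size([1, 8, 32, 32])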
class Conv2dUpsample(nn.Module):
"""Block of convolution, batchnorm, relu and nearest neighbor upsampling.
Parameters
----------
in_channels : `int`
Number of input channels.
out_channels : `int`
Number of desired output channels
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Returns
-------
None.
"""
def __init__(self, in_channels, out_channels, **kwargs):
super().__init__()
...@@ -105,7 +249,27 @@ class Conv2dUpsample(nn.Module):
# defines the forward pass
def forward(self, x, feature, indices, skip):
"""Forward propagation through this block.
Parameters
----------
x : `torch.tensor`
Output of previous layer.
feature : `torch.tensor`
Encoder feature used for the skip connection.
indices : `torch.tensor`
Indices of the max pooling operation. Used in unpooling operation.
Not used here, but passed to preserve generic interface. Useful in
`pysegcnn.core.layers.Decoder`.
skip : `bool`
Whether to apply skip connection.
Returns
-------
x : `torch.tensor`
Output of this block.
"""
# upsampling with pooling indices
x = self.upsample(x, size=feature.shape[2:], mode='nearest')
...@@ -122,6 +286,24 @@ class Conv2dUpsample(nn.Module):
class Encoder(nn.Module):
"""Generic convolutional encoder.
Parameters
----------
filters : `list` [`int`]
List of input channels to each convolutional block.
block : `torch.nn.Module`
The convolutional block. ``block`` should inherit from
`torch.nn.Module`, e.g. `pysegcnn.core.layers.Conv2dPool`.
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Returns
-------
None.
"""
def __init__(self, filters, block, **kwargs):
super().__init__()
...@@ -144,7 +326,19 @@ class Encoder(nn.Module):
# forward pass through the encoder
def forward(self, x):
"""Forward propagation through the encoder.
Parameters
----------
x : `torch.tensor`
Input image.
Returns
-------
x : `torch.tensor`
Output of the encoder.
"""
# initialize a dictionary that caches the intermediate outputs, i.e.
# features and pooling indices of each block in the encoder
self.cache = {}
...@@ -160,6 +354,26 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
"""Generic convolutional decoder.
Parameters
----------
filters : `list` [`int`]
List of input channels to each convolutional block.
block : `torch.nn.Module`
The convolutional block. ``block`` should inherit from
`torch.nn.Module`, e.g. `pysegcnn.core.layers.Conv2dUnpool`.
skip : `bool`
Whether to apply skip connections from the encoder to the decoder.
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Returns
-------
None.
"""
def __init__(self, filters, block, skip=True, **kwargs):
super().__init__()
...@@ -184,15 +398,33 @@ class Decoder(nn.Module):
# construct decoder layers
self.layers = []
for lyr, lyrp1 in zip(n_in * self.features, self.features[1:]):
self.layers.append(self.block(lyr, lyrp1, **kwargs))
# convert list of layers to ModuleList
self.layers = nn.ModuleList(*[self.layers])
# forward pass through decoder
def forward(self, x, enc_cache):
"""Forward propagation through the decoder.
Parameters
----------
x : `torch.tensor`
Output of the encoder.
enc_cache : `dict`
Cache dictionary with keys:
``'feature'``
Encoder features used for the skip connection.
``'indices'``
The indices of the max pooling operations.
Returns
-------
x : `torch.tensor`
Output of the decoder.
"""
# for each layer, upsample input and apply optional skip connection
for i, layer in enumerate(self.layers):
...
...@@ -10,7 +10,25 @@ import pathlib
# the logging configuration dictionary
def log_conf(logfile):
"""Set basic logging configuration passed to `logging.config.dictConfig`.
See the logging `docs`_ for a detailed description of the configuration
dictionary.
.. _docs:
https://docs.python.org/3/library/logging.config.html#dictionary-schema-details
Parameters
----------
logfile : `str` or `pathlib.Path`
The file to save the logs to.
Returns
-------
LOGGING_CONFIG : `dict`
The logging configuration.
"""
# check if the parent directory of the log file exists
logfile = pathlib.Path(logfile)
if not logfile.parent.is_dir():
...
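# Usage sketch for log_conf: pass the returned dictionary to dictConfig.
# The module path below and the log file name are assumptions; they are not
# shown in this diff.
import logging
import logging.config
from pysegcnn.core.logging import log_conf
logging.config.dictConfig(log_conf('_logs/run.log'))
logging.getLogger(__name__).info('logging configured')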
"""A collection of neural networks for semantic image segmentation."""
# !/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 26 16:31:36 2020
@author: Daniel
"""
# builtins
import os
import enum
import logging
import pathlib
...@@ -19,13 +16,23 @@ import torch.optim as optim
# locals
from pysegcnn.core.layers import (Encoder, Decoder, Conv2dPool, Conv2dUnpool,
Conv2dSame)
# module level logger
LOGGER = logging.getLogger(__name__)
class Network(nn.Module):
"""Generic Network class.
The base class for each model. If you want to implement a new model,
inherit the ``~pysegcnn.core.models.Network`` class.
Returns
-------
None.
"""
def __init__(self):
super().__init__()
...@@ -34,19 +41,63 @@ class Network(nn.Module):
self.state_file = None
def freeze(self):
"""Freeze the weights of a model.
Disables gradient computation: useful when using a pretrained model for
inference.
Returns
-------
None.
"""
for param in self.parameters():
param.requires_grad = False
def unfreeze(self):
"""Unfreeze the weights of a model.
Enables gradient computation: useful when adjusting a pretrained model
to a new dataset.
Returns
-------
None.
"""
for param in self.parameters():
param.requires_grad = True
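# Usage sketch for freeze()/unfreeze(); the UNet constructor arguments are
# illustrative only (kernel_size is assumed to be forwarded to Conv2dSame).
from pysegcnn.core.models import UNet
net = UNet(in_channels=10, nclasses=6, filters=[32, 64, 128], skip=True,
           kernel_size=3)
net.freeze()    # e.g. before running inference with a pretrained model
print(any(p.requires_grad for p in net.parameters()))  # False
net.unfreeze()  # e.g. before fine-tuning on a new dataset
print(all(p.requires_grad for p in net.parameters()))  # True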
def save(self, state_file, optimizer, bands=None, **kwargs):
"""Save the model state.
Saves the model and optimizer states together with the model
construction parameters, to easily re-instantiate the model.
Optional ``kwargs`` are also saved.
Parameters
----------
state_file : `str` or `pathlib.Path`
Path to save the model state.
optimizer : `torch.optim.Optimizer`
The optimizer used to train the model.
bands : `list` [`str`] or `None`, optional
List of bands the model is trained with. The default is None.
**kwargs
Arbitrary keyword arguments. Each keyword argument will be saved
as (key, value) pair in ``state_file``.
Returns
-------
model_state : `dict`
A dictionary containing the model and optimizer state
"""
# check if the output path exists and if not, create it
state_file = pathlib.Path(state_file)
if not state_file.parent.is_dir():
state_file.parent.mkdir(parents=True, exist_ok=True)
# initialize dictionary to store network parameters
model_state = {**kwargs}
...@@ -79,11 +130,41 @@ class Network(nn.Module):
torch.save(model_state, state_file)
LOGGER.info('Network parameters saved in {}'.format(state_file))
return model_state
@staticmethod
def load(state_file, optimizer=None):
"""Load a model state.
Returns the model in ``state_file`` with the pretrained model weights.
If ``optimizer`` is specified, the optimizer parameters are also loaded
from ``state_file``. This is useful when resuming training an existing
model.
Parameters
----------
state_file : `str` or `pathlib.Path`
The model state file. Model state files are stored in
pysegcnn/main/_models.
optimizer : `torch.optim.Optimizer` or `None`, optional
The optimizer used to train the model.
Raises
------
FileNotFoundError
Raised if ``state_file`` does not exist.
Returns
-------
model : `pysegcnn.core.models.Network`
The pretrained model.
optimizer : `torch.optim.Optimizer` or `None`
The optimizer used to train the model.
model_state : `dict`
A dictionary containing the model and optimizer state, as
constructed by `~pysegcnn.core.Network.save`.
"""
# load the pretrained model
state_file = pathlib.Path(state_file)
if not state_file.exists():
...@@ -117,10 +198,42 @@ class Network(nn.Module):
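# Round-trip sketch based on the signatures and docstrings above. The elided
# bookkeeping in save()/load() is assumed to restore the model as documented;
# the path, bands and hyperparameters here are illustrative only.
import torch.optim as optim
from pysegcnn.core.models import UNet, Network
net = UNet(in_channels=10, nclasses=6, filters=[32, 64], skip=True,
           kernel_size=3)
optimizer = optim.Adam(net.parameters(), lr=0.001)
net.save('_models/unet_demo.pt', optimizer, bands=['blue', 'green', 'red'])
net, optimizer, model_state = Network.load('_models/unet_demo.pt', optimizer)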
@property
def state(self):
"""Return the model state file.
Returns
-------
state_file : `pathlib.Path` or `None`
The model state file.
"""
return self.state_file
class UNet(Network):
"""A PyTorch implementation of `U-Net`_.
.. _U-Net:
https://arxiv.org/abs/1505.04597
Parameters
----------
in_channels : `int`
Number of channels of the input images.
nclasses : `int`
Number of classes.
filters : `list` [`int`]
List of input channels to each convolutional block.
skip : `bool`
Whether to apply skip connections from the encoder to the decoder.
**kwargs: 'dict' [`str`]
Additional keyword arguments passed to
`pysegcnn.core.layers.Conv2dSame`.
Returns
-------
None.
"""
def __init__(self, in_channels, nclasses, filters, skip, **kwargs):
super().__init__()
...@@ -158,7 +271,19 @@ class UNet(Network):
kernel_size=1)
def forward(self, x):
"""Forward propagation of U-Net.
Parameters
----------
x : `torch.tensor`
The input image, shape=(batch_size, channels, height, width).
Returns
-------
y : 'torch.tensor'
The classified image, shape=(batch_size, height, width).
"""
# forward pass: encoder
x = self.encoder(x)
...@@ -173,11 +298,18 @@ class UNet(Network):
class SupportedModels(enum.Enum):
"""Names and corresponding classes of the implemented models."""
Unet = UNet
class SupportedOptimizers(enum.Enum):
"""Names and corresponding classes of the tested optimizers."""
Adam = optim.Adam
class SupportedLossFunctions(enum.Enum):
"""Names and corresponding classes of the tested loss functions."""
CrossEntropy = nn.CrossEntropyLoss
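# The Supported* enumerations above map names to the corresponding classes,
# e.g. to resolve entries from a configuration:
from pysegcnn.core.models import SupportedModels, SupportedOptimizers
print(SupportedModels['Unet'].value)      # the UNet class
print(SupportedOptimizers['Adam'].value)  # torch.optim.Adam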
"""A collection of functions for model inference."""
# !/usr/bin/env python
# -*- coding: utf-8 -*-
# builtins
import logging
...@@ -18,7 +23,21 @@ LOGGER = logging.getLogger(__name__)
def _get_scene_tiles(ds, scene_id):
"""Return the tiles of the scene with id = ``scene_id``.
Parameters
----------
ds : `pysegcnn.core.dataset.ImageDataset`
An instance of `~pysegcnn.core.dataset.ImageDataset`.
scene_id : `str`
A valid scene identifier.
Returns
-------
indices : `list` [`int`]
List of indices of the tiles from scene with id ``scene_id`` in ``ds``.
"""
# iterate over the scenes of the dataset
indices = []
for i, scene in enumerate(ds.scenes):
...@@ -30,7 +49,46 @@ def _get_scene_tiles(ds, scene_id):
def predict_samples(ds, model, cm=False, plot=False, **kwargs):
"""Classify each sample in ``ds`` with model ``model``.
Parameters
----------
ds : `pysegcnn.core.split.RandomSubset` or
`pysegcnn.core.split.SceneSubset`
An instance of `~pysegcnn.core.split.RandomSubset` or
`~pysegcnn.core.split.SceneSubset`.
model : `pysegcnn.core.models.Network`
An instance of `~pysegcnn.core.models.Network`.
cm : `bool`, optional
Whether to compute the confusion matrix. The default is False.
plot : `bool`, optional
Whether to plot a false color composite, ground truth and model
prediction for each sample. The default is False.
**kwargs
Additional keyword arguments passed to
`pysegcnn.core.graphics.plot_sample`.
Raises
------
TypeError
Raised if ``ds`` is not an instance of
`~pysegcnn.core.split.RandomSubset` or
`~pysegcnn.core.split.SceneSubset`.
Returns
-------
output : `dict`
Output dictionary with keys:
``'input'``
Model input data
``'labels'``
The ground truth
``'prediction'``
Model prediction
conf_mat : `numpy.ndarray`
The confusion matrix. Note that the confusion matrix ``conf_mat`` is
only computed if ``cm`` = True.
"""
# check whether the dataset is a valid subset, i.e.
# an instance of pysegcnn.core.split.SceneSubset or
# an instance of pysegcnn.core.split.RandomSubset
...@@ -50,7 +108,7 @@ def predict_samples(ds, model, cm=False, plot=False, **kwargs):
fname = model.state_file.name.split('.pt')[0]
# initialize confusion matrix
conf_mat = np.zeros(shape=(model.nclasses, model.nclasses))
# create the dataloader
dataloader = DataLoader(ds, batch_size=1, shuffle=False, drop_last=False)
...@@ -78,7 +136,7 @@ def predict_samples(ds, model, cm=False, plot=False, **kwargs):
# update confusion matrix
if cm:
for ytrue, ypred in zip(labels.view(-1), prd.view(-1)):
conf_mat[ytrue.long(), ypred.long()] += 1
# save plot of current batch to disk
if plot:
...@@ -93,11 +151,49 @@ def predict_samples(ds, model, cm=False, plot=False, **kwargs):
state=sname,
**kwargs)
return output, conf_mat
def predict_scenes(ds, model, scene_id=None, cm=False, plot=False, **kwargs):
"""Classify each scene in ``ds`` with model ``model``.
Parameters
----------
ds : `pysegcnn.core.split.SceneSubset`
An instance of `~pysegcnn.core.split.SceneSubset`.
model : `pysegcnn.core.models.Network`
An instance of `~pysegcnn.core.models.Network`.
scene_id : `str` or `None`
A valid scene identifier.
cm : `bool`, optional
Whether to compute the confusion matrix. The default is False.
plot : `bool`, optional
Whether to plot a false color composite, ground truth and model
prediction for each scene. The default is False.
**kwargs
Additional keyword arguments passed to
`pysegcnn.core.graphics.plot_sample`.
Raises
------
TypeError
Raised if ``ds`` is not an instance of
`~pysegcnn.core.split.SceneSubset`.
Returns
-------
output : `dict`
Output dictionary with keys:
``'input'``
Model input data
``'labels'``
The ground truth
``'prediction'``
Model prediction
conf_mat : `numpy.ndarray`
The confusion matrix. Note that the confusion matrix ``conf_mat`` is
only computed if ``cm`` = True.
"""
# check whether the dataset is a valid subset, i.e. an instance of
# pysegcnn.core.split.SceneSubset
if not isinstance(ds, SceneSubset):
...@@ -116,7 +212,7 @@ def predict_scenes(ds, model, scene_id=None, cm=False, plot=False, **kwargs):
fname = model.state_file.name.split('.pt')[0]
# initialize confusion matrix
conf_mat = np.zeros(shape=(model.nclasses, model.nclasses))
# check whether a scene id is provided
if scene_id is None:
...@@ -130,7 +226,7 @@ def predict_scenes(ds, model, scene_id=None, cm=False, plot=False, **kwargs):
# iterate over the scenes
LOGGER.info('Predicting scenes of the {} dataset ...'.format(ds.name))
output = {}
for i, sid in enumerate(scene_ids):
# filename for the current scene
...@@ -161,7 +257,7 @@ def predict_scenes(ds, model, scene_id=None, cm=False, plot=False, **kwargs):
# update confusion matrix
if cm:
for ytrue, ypred in zip(lab.view(-1), prd.view(-1)):
conf_mat[ytrue.long(), ypred.long()] += 1
# reconstruct the entire scene
inputs = reconstruct_scene(inp, scene_size, nbands=inp.shape[1])
...@@ -173,7 +269,7 @@ def predict_scenes(ds, model, scene_id=None, cm=False, plot=False, **kwargs):
i + 1, len(scene_ids), sid, accuracy_function(prdtcn, labels)))
# save outputs to dictionary
output[sid] = {'input': inputs, 'labels': labels, 'prediction': prdtcn}
# plot current scene
if plot:
...@@ -185,4 +281,4 @@ def predict_scenes(ds, model, scene_id=None, cm=False, plot=False, **kwargs):
state=sname,
**kwargs)
return output, conf_mat
"""Split the dataset to training, validation and test set."""
# !/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 29 12:02:32 2020
@author: Daniel
"""
# builtins
import datetime
import enum
...@@ -15,17 +14,70 @@ from torch.utils.data.dataset import Subset
# the names of the subsets
SUBSET_NAMES = ['train', 'valid', 'test']
# function calculating number of samples in a dataset given a ratio
def _ds_len(ds, ratio):
"""Calcute number of samples in a dataset given a ratio.
Parameters
----------
ds : `collections.Sized`
An object with a __len__ attribute.
ratio : `float`
A ratio to multiply with the length of ``ds``.
Returns
-------
n_samples: `int`
Length of ``ds`` * ``ratio``.
"""
return int(np.round(len(ds) * ratio))
# randomly split the tiles of a dataset across the training, validation and
# test dataset
# for each scene, the tiles can be distributed among the training, validation
# and test set
def random_tile_split(ds, tvratio, ttratio=1, seed=0):
"""Randomly split the tiles of a dataset.
For each scene, the tiles of the scene can be distributed among the
training, validation and test set.
The parameters ``ttratio`` and ``tvratio`` control the size of the
training, validation and test datasets.
Test dataset size : (1 - ``ttratio``) * len(``ds``)
Train dataset size : ``ttratio`` * ``tvratio`` * len(``ds``)
Validation dataset size: ``ttratio`` * (1 - ``tvratio``) * len(``ds``)
Parameters
----------
ds : `pysegcnn.core.dataset.ImageDataset`
An instance of `~pysegcnn.core.dataset.ImageDataset`.
tvratio : `float`
The ratio of training data to validation data, e.g. ``tvratio`` = 0.8
means 80% training, 20% validation.
ttratio : `float`, optional
The ratio of training and validation data to test data, e.g.
``ttratio`` = 0.6 means 60% for training and validation, 40% for
testing. The default is 1.
seed : `int`, optional
The random seed for reproducibility. The default is 0.
Raises
------
AssertionError
Raised if the splits are not pairwise disjoint.
Returns
-------
subsets : `dict`
Subset dictionary with keys:
``'train'``
dictionary containing the training scenes.
``'valid'``
dictionary containing the validation scenes.
``'test'``
dictionary containing the test scenes.
"""
# set the random seed for reproducibility
np.random.seed(seed)
...@@ -64,12 +116,50 @@ def random_tile_split(ds, tvratio, ttratio=1, seed=0):
return subsets
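# Worked example of the split sizes documented above, for a dataset of
# 100 tiles with ttratio=0.8 and tvratio=0.75 (illustrative numbers):
n, ttratio, tvratio = 100, 0.8, 0.75
print(round(ttratio * tvratio * n),        # train: 60
      round(ttratio * (1 - tvratio) * n),  # valid: 20
      round((1 - ttratio) * n))            # test:  20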
# randomly split the tiles of a dataset across the training, validation and
# test dataset
# for each scene, all the tiles of the scene are included in either the
# training set, the validation set or the test set, respectively
def random_scene_split(ds, tvratio, ttratio=1, seed=0):
"""Randomly split the tiles of a dataset.
For each scene, all the tiles of the scene are included in either the
training, validation or test set, respectively.
The parameters ``ttratio`` and ``tvratio`` control the size of the
training, validation and test datasets.
Test dataset size : (1 - ``ttratio``) * len(``ds``)
Train dataset size : ``ttratio`` * ``tvratio`` * len(``ds``)
Validation dataset size: ``ttratio`` * (1 - ``tvratio``) * len(``ds``)
Parameters
----------
ds : `pysegcnn.core.dataset.ImageDataset`
An instance of `~pysegcnn.core.dataset.ImageDataset`.
tvratio : `float`
The ratio of training data to validation data, e.g. ``tvratio`` = 0.8
means 80% training, 20% validation.
ttratio : `float`, optional
The ratio of training and validation data to test data, e.g.
``ttratio`` = 0.6 means 60% for training and validation, 40% for
testing. The default is 1.
seed : `int`, optional
The random seed for reproducibility. The default is 0.
Raises
------
AssertionError
Raised if the splits are not pairwise disjoint.
Returns
-------
subsets : `dict`
Subset dictionary with keys:
``'train'``
dictionary containing the training scenes.
``'valid'``
dictionary containing the validation scenes.
``'test'``
dictionary containing the test scenes.
"""
# set the random seed for reproducibility
np.random.seed(seed)
...@@ -113,7 +203,41 @@
# scenes before date build the training set, scenes after the date build the
# validation set, the test set is empty
def date_scene_split(ds, date, dateformat='%Y%m%d'):
"""Split the dataset based on a date.
Scenes before ``date`` build the training set, scenes after ``date`` build
the validation set, the test set is empty.
Useful for time series data.
Parameters
----------
ds : `pysegcnn.core.dataset.ImageDataset`
An instance of `~pysegcnn.core.dataset.ImageDataset`.
date : `str`
A date.
dateformat : `str`, optional
The format of ``date``. ``dateformat`` is used by
`datetime.datetime.strptime` to parse ``date`` to a `datetime.datetime`
object. The default is '%Y%m%d'.
Raises
------
AssertionError
Raised if the splits are not pairwise disjoint.
Returns
-------
subsets : `dict`
Subset dictionary with keys:
``'train'``
dictionary containing the training scenes.
``'valid'``
dictionary containing the validation scenes.
``'test'``
dictionary containing the test scenes, empty.
"""
# convert date to datetime object
date = datetime.datetime.strptime(date, dateformat)
...@@ -137,15 +261,32 @@ def date_scene_split(ds, date, dateformat='%Y%m%d'):
def pairwise_disjoint(sets):
"""Check if ``sets`` are pairwise disjoint.
Sets are pairwise disjoint if the length of their union equals the sum of
their lengths.
Parameters
----------
sets : `list` [`collections.Sized`]
A list of sized objects.
Returns
-------
disjoint : `bool`
Whether the sets are pairwise disjoint.
"""
union = set().union(*sets)
n = sum(len(u) for u in sets)
return n == len(union)
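# Quick check of pairwise_disjoint: disjoint sets sum to the size of their
# union, overlapping sets do not.
print(pairwise_disjoint([{0, 1}, {2, 3}]))  # True
print(pairwise_disjoint([{0, 1}, {1, 2}]))  # False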
class CustomSubset(Subset):
"""Custom subset inheriting `torch.utils.data.dataset.Subset`."""
def __repr__(self):
"""Representation of ``~pysegcnn.core.split.CustomSubset``."""
# representation string
fs = '- {}: {:d} tiles ({:.2f}%)'.format(
self.name, len(self.scenes), 100 * len(self.scenes) /
...@@ -155,6 +296,26 @@ class CustomSubset(Subset):
class SceneSubset(CustomSubset):
"""A custom subset for dataset splits where the scenes are preserved.
Parameters
----------
ds : `pysegcnn.core.dataset.ImageDataset`
An instance of `~pysegcnn.core.dataset.ImageDataset`.
indices : `list` [`int`]
List of the subset indices to access ``ds``.
name : `str`
Name of the subset.
scenes : `list` [`dict`]
List of the subset tiles.
scene_ids : `list` or `numpy.ndarray`
Container of the scene ids.
Returns
-------
None.
"""
def __init__(self, ds, indices, name, scenes, scene_ids):
super().__init__(dataset=ds, indices=indices)
...@@ -170,6 +331,26 @@ class SceneSubset(CustomSubset):
class RandomSubset(CustomSubset):
"""A custom subset for random dataset splits.
Parameters
----------
ds : `pysegcnn.core.dataset.ImageDataset`
An instance of `~pysegcnn.core.dataset.ImageDataset`.
indices : `list` [`int`]
List of the subset indices to access ``ds``.
name : `str`
Name of the subset.
scenes : `list` [`dict`]
List of the subset tiles.
scene_ids : `list` or `numpy.ndarray`
Container of the scene ids.
Returns
-------
None.
"""
def __init__(self, ds, indices, name, scenes, scene_ids):
super().__init__(dataset=ds, indices=indices)
...@@ -182,6 +363,22 @@ class RandomSubset(CustomSubset):
class Split(object):
"""Generic class handling how ``ds`` is split.
Inherit `~pysegcnn.core.split.Split` and implement the
`~pysegcnn.core.split.Split.subsets` and
`~pysegcnn.core.split.Split.subset_type` methods.
Parameters
----------
ds : `pysegcnn.core.dataset.ImageDataset`
An instance of `~pysegcnn.core.dataset.ImageDataset`.
Returns
-------
None.
"""
def __init__(self, ds):
...@@ -189,7 +386,13 @@ class Split(object):
self.ds = ds
def split(self):
"""Split dataset into training, validation and test set.
`~pysegcnn.core.split.Split.split` works only if
`~pysegcnn.core.split.Split.subsets` and
`~pysegcnn.core.split.Split.subset_type` are implemented.
"""
# build the subsets
ds_split = []
for name, sub in self.subsets().items():
...@@ -204,14 +407,67 @@ class Split(object):
return ds_split
@property
def subsets(self):
"""Define training, validation and test sets.
Wrapper method for
`pysegcnn.core.split.random_tile_split`,
`pysegcnn.core.split.random_scene_split` or
`pysegcnn.core.split.date_scene_split`.
Raises
------
NotImplementedError
Raised if `pysegcnn.core.split.Split` is not inherited.
Returns
-------
None.
"""
raise NotImplementedError
def subset_type(self):
"""Define the type of each subset.
Wrapper method for
`pysegcnn.core.split.RandomSubset` or
`pysegcnn.core.split.SceneSubset`.
Raises
------
NotImplementedError
Raised if `pysegcnn.core.split.Split` is not inherited.
Returns
-------
None.
"""
raise NotImplementedError
class DateSplit(Split):
"""Split the dataset based on a date.
Class wrapper for `pysegcnn.core.split.date_scene_split`.
Parameters
----------
ds : `pysegcnn.core.dataset.ImageDataset`
An instance of `~pysegcnn.core.dataset.ImageDataset`.
date : `str`
A date.
dateformat : `str`, optional
The format of ``date``. ``dateformat`` is used by
`datetime.datetime.strptime` to parse ``date`` to a `datetime.datetime`
object. The default is '%Y%m%d'.
Returns
-------
None.
"""
def __init__(self, ds, date, dateformat):
super().__init__(ds)
...@@ -225,13 +481,58 @@ class DateSplit(Split):
self.dateformat = dateformat
def subsets(self):
"""Wrap `pysegcnn.core.split.Split.date_scene_split`.
Returns
-------
subsets : `dict`
Subset dictionary with keys:
``'train'``
dictionary containing the training scenes.
``'valid'``
dictionary containing the validation scenes.
``'test'``
dictionary containing the test scenes, empty.
"""
return date_scene_split(self.ds, self.date, self.dateformat)
def subset_type(self):
"""Wrap `pysegcnn.core.split.SceneSubset`.
Returns
-------
SceneSubset : `pysegcnn.core.split.SceneSubset`
The subset type.
"""
return SceneSubset
class RandomSplit(Split):
"""Randomly split the dataset.
Generic class for random dataset splits.
Parameters
----------
ds : `pysegcnn.core.dataset.ImageDataset`
An instance of `~pysegcnn.core.dataset.ImageDataset`.
tvratio : `float`
The ratio of training data to validation data, e.g. ``tvratio`` = 0.8
means 80% training, 20% validation.
ttratio : `float`, optional
The ratio of training and validation data to test data, e.g.
``ttratio`` = 0.6 means 60% for training and validation, 40% for
testing. The default is 1.
seed : `int`, optional
The random seed for reproducibility. The default is 0.
Returns
-------
None.
"""
def __init__(self, ds, ttratio, tvratio, seed):
super().__init__(ds)
...@@ -245,32 +546,130 @@ class RandomSplit(Split):
class RandomTileSplit(RandomSplit):
"""Randomly split the dataset.
For each scene, the tiles of the scene can be distributed among the
training, validation and test set.
Class wrapper for `pysegcnn.core.split.random_tile_split`.
Parameters
----------
ds : `pysegcnn.core.dataset.ImageDataset`
An instance of `~pysegcnn.core.dataset.ImageDataset`.
tvratio : `float`
The ratio of training data to validation data, e.g. ``tvratio`` = 0.8
means 80% training, 20% validation.
ttratio : `float`, optional
The ratio of training and validation data to test data, e.g.
``ttratio`` = 0.6 means 60% for training and validation, 40% for
testing. The default is 1.
seed : `int`, optional
The random seed for reproducibility. The default is 0.
Returns
-------
None.
"""
def __init__(self, ds, ttratio, tvratio, seed):
super().__init__(ds, ttratio, tvratio, seed)
def subsets(self):
"""Wrap `pysegcnn.core.split.Split.random_tile_split`.
Returns
-------
subsets : `dict`
Subset dictionary with keys:
``'train'``
dictionary containing the training scenes.
``'valid'``
dictionary containing the validation scenes.
``'test'``
dictionary containing the test scenes.
"""
return random_tile_split(self.ds, self.tvratio, self.ttratio,
self.seed)
def subset_type(self):
"""Wrap `pysegcnn.core.split.RandomSubset`.
Returns
-------
RandomSubset : `pysegcnn.core.split.RandomSubset`
The subset type.
"""
return RandomSubset
class RandomSceneSplit(RandomSplit):
"""Randomly split the dataset.
For each scene, all the tiles of the scene are included in either the
training, validation or test set, respectively.
Class wrapper for `pysegcnn.core.split.random_scene_split`.
Parameters
----------
ds : `pysegcnn.core.dataset.ImageDataset`
An instance of `~pysegcnn.core.dataset.ImageDataset`.
tvratio : `float`
The ratio of training data to validation data, e.g. ``tvratio`` = 0.8
means 80% training, 20% validation.
ttratio : `float`, optional
The ratio of training and validation data to test data, e.g.
``ttratio`` = 0.6 means 60% for training and validation, 40% for
testing. The default is 1.
seed : `int`, optional
The random seed for reproducibility. The default is 0.
Returns
-------
None.
"""
def __init__(self, ds, ttratio, tvratio, seed):
super().__init__(ds, ttratio, tvratio, seed)
def subsets(self):
"""Wrap `pysegcnn.core.split.Split.random_scene_split`.
Returns
-------
subsets : `dict`
Subset dictionary with keys:
``'train'``
dictionary containing the training scenes.
``'valid'``
dictionary containing the validation scenes.
``'test'``
dictionary containing the test scenes.
"""
return random_scene_split(self.ds, self.tvratio, self.ttratio,
self.seed)
def subset_type(self):
"""Wrap `pysegcnn.core.split.SceneSubset`.
Returns
-------
SceneSubset : `pysegcnn.core.split.SceneSubset`
The subset type.
"""
return SceneSubset
class SupportedSplits(enum.Enum):
"""Names and corresponding classes of the implemented split modes."""
random = RandomTileSplit
scene = RandomSceneSplit
date = DateSplit
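# The split mode names above resolve to their classes, e.g.:
from pysegcnn.core.split import SupportedSplits, RandomSceneSplit
assert SupportedSplits['scene'].value is RandomSceneSplit
# a configuration could then instantiate and apply the selected split,
# given an ImageDataset ``ds`` (sketch only, ``ds`` is not defined here):
#   train_ds, valid_ds, test_ds = SupportedSplits['scene'].value(
#       ds, ttratio=1, tvratio=0.8, seed=0).split()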