"""Model configuration and training.
This module provides an end-to-end framework of dataclasses designed to train
segmentation models on image datasets.
See pysegcnn/main/train.py for a complete walkthrough.
License
-------
Copyright (c) 2020 Daniel Frisinghelli
This source code is licensed under the GNU General Public License v3.
See the LICENSE file in the repository's root directory.
"""
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# builtins
import dataclasses
import pathlib
import logging
import datetime
# externals
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import Optimizer
# locals
from pysegcnn.core.dataset import SupportedDatasets, ImageDataset
from pysegcnn.core.transforms import Augment
from pysegcnn.core.utils import img2np, item_in_enum, accuracy_function
from pysegcnn.core.split import SupportedSplits
from pysegcnn.core.models import (SupportedModels, SupportedOptimizers,
SupportedLossFunctions, Network)
from pysegcnn.core.layers import Conv2dSame
from pysegcnn.main.config import HERE
# module level logger
LOGGER = logging.getLogger(__name__)
@dataclasses.dataclass
class BaseConfig:
"""Base `dataclasses.dataclass` for each configuration."""
def __post_init__(self):
"""Check the type of each argument.
Raises
------
TypeError
Raised if the conversion to the specified type of the argument
fails.
Returns
-------
None.
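Example
-------
A minimal sketch of the automatic type coercion; ``MyConfig`` is a
hypothetical subclass for illustration only:
>>> @dataclasses.dataclass
... class MyConfig(BaseConfig):
...     epochs: int
>>> MyConfig(epochs='10').epochs
10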
"""
# check input types
for field in dataclasses.fields(self):
# the value of the current field
value = getattr(self, field.name)
# check whether the value is of the correct type
if not isinstance(value, field.type):
# try to convert the value to the correct type
try:
setattr(self, field.name, field.type(value))
except TypeError:
# raise an exception if the conversion fails
raise TypeError('Expected {} to be {}, got {}.'
.format(field.name, field.type,
type(value)))
@dataclasses.dataclass
class DatasetConfig(BaseConfig):
"""Dataset configuration class.
Instantiate a dataset.
Parameters
----------
dataset_name : `str`
The name of the dataset.
root_dir : `pathlib.Path`
The root directory, path to the dataset.
bands : `list` [`str`]
A list of the spectral bands to use.
tile_size : `int`
The size of the tiles. Each scene is divided into square tiles of shape
(tile_size, tile_size).
gt_pattern : `str`
A pattern to match the ground truth naming convention. All directories
and subdirectories in ``root_dir`` are searched for files matching
``gt_pattern``.
seed : `int`
The random seed. Used to split the dataset into training, validation
and test set. Useful for reproducibility. The default is 0.
sort : `bool`, optional
Whether to chronologically sort the samples. Useful for time series
data. The default is False.
transforms : `list` [`pysegcnn.core.transforms.Augment`], optional
List of `pysegcnn.core.transforms.Augment` instances. Each item in
``transforms`` generates a distinct transformed version of the dataset.
The total dataset is composed of the original untransformed dataset
together with each transformed version of it.
If ``transforms`` = [], only the original dataset is used.
The default is [].
pad : `bool`, optional
Whether to center pad the input image. Set ``pad`` = True, if the
images are not evenly divisible by the ``tile_size``. The image data is
padded with a constant padding value of zero. For each image, the
corresponding ground truth image is padded with a "no data" label.
The default is False.
Returns
-------
None.
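Example
-------
A minimal sketch; ``dataset_name``, ``root_dir`` and ``gt_pattern`` are
placeholders and must match a member of
`pysegcnn.core.dataset.SupportedDatasets`, an existing directory and
the ground truth naming convention, respectively:
>>> dc = DatasetConfig(dataset_name='MyDataset',
...                    root_dir=pathlib.Path('/path/to/dataset'),
...                    bands=['red', 'green', 'blue'],
...                    tile_size=128,
...                    gt_pattern='*gt.tif',
...                    seed=0)
>>> ds = dc.init_dataset()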
"""
dataset_name: str
root_dir: pathlib.Path
bands: list
tile_size: int
gt_pattern: str
seed: int
sort: bool = False
transforms: list = dataclasses.field(default_factory=list)
pad: bool = False
def __post_init__(self):
"""Check the type of each argument.
Raises
------
ValueError
Raised if ``dataset_name`` is not supported.
FileNotFoundError
Raised if ``root_dir`` does not exist.
TypeError
Raised if an item in ``transforms`` is not an instance of
`pysegcnn.core.transforms.Augment` in case ``transforms`` is not empty.
Returns
-------
None.
"""
# check input types
super().__post_init__()
# check whether the dataset is currently supported
self.dataset_class = item_in_enum(self.dataset_name, SupportedDatasets)
# check whether the root directory exists
if not self.root_dir.exists():
raise FileNotFoundError('{} does not exist.'.format(self.root_dir))
# check whether the transformations inherit from the correct class
if not all([isinstance(t, Augment) for t in self.transforms]):
raise TypeError('Each transformation is expected to be an instance'
' of {}.'.format('.'.join([Augment.__module__,
Augment.__name__])))
def init_dataset(self):
"""Instantiate the dataset.
Returns
-------
dataset : `pysegcnn.core.dataset.ImageDataset`
An instance of `pysegcnn.core.dataset.ImageDataset`.
"""
# instantiate the dataset
dataset = self.dataset_class(
root_dir=str(self.root_dir),
use_bands=self.bands,
tile_size=self.tile_size,
seed=self.seed,
sort=self.sort,
transforms=self.transforms,
pad=self.pad,
gt_pattern=self.gt_pattern
)
return dataset
@dataclasses.dataclass
class SplitConfig(BaseConfig):
"""Dataset split configuration class.
Split a dataset into training, validation and test set.
Parameters
----------
split_mode : `str`
The mode to split the dataset.
ttratio : `float`
The ratio of training and validation data to test data, e.g.
``ttratio`` = 0.6 means 60% for training and validation, 40% for
testing.
tvratio : `float`
The ratio of training data to validation data, e.g. ``tvratio`` = 0.8
means 80% training, 20% validation.
date : `str`, optional
A date. Used if ``split_mode`` = 'date'. The default is 'yyyymmdd'.
dateformat : `str`, optional
The format of ``date``. ``dateformat`` is used by
`datetime.datetime.strptime` to parse ``date`` to a `datetime.datetime`
object. The default is '%Y%m%d'.
drop : `float`, optional
The threshold for dropping samples during training: a sample is
dropped if its fraction of pixels equal to the constant padding value
is >= ``drop``. ``drop`` = 0 means no samples are dropped. The default
is 0.
Returns
-------
None.
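Example
-------
A minimal sketch, assuming ``ds`` is an instance of
`pysegcnn.core.dataset.ImageDataset`:
>>> sc = SplitConfig(split_mode='random', ttratio=0.8, tvratio=0.8)
>>> train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)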
"""
split_mode: str
ttratio: float
tvratio: float
date: str = 'yyyymmdd'
dateformat: str = '%Y%m%d'
drop: float = 0
def __post_init__(self):
"""Check the type of each argument.
Raises
------
ValueError
Raised if ``split_mode`` is not supported.
Returns
-------
None.
"""
# check input types
super().__post_init__()
# check if the split mode is valid
self.split_class = item_in_enum(self.split_mode, SupportedSplits)
@staticmethod
def _drop_samples(ds, drop_threshold=1):
"""Drop samples with a fraction of pixels equal to the padding value.
Parameters
----------
ds : `pysegcnn.core.split.RandomSubset` or
`pysegcnn.core.split.SceneSubset`.
An instance of `pysegcnn.core.split.RandomSubset` or
`pysegcnn.core.split.SceneSubset`.
drop_threshold : `float`, optional
The threshold above which samples are dropped. ``drop_threshold`` = 1
means a sample is dropped if all of its pixels are equal to the padding
value; ``drop_threshold`` = 0.8 means a sample is dropped if 80% of its
pixels are equal to the padding value, etc. The default is 1.
Returns
-------
dropped : `list` [`dict`]
List of the dropped samples.
"""
# iterate over the scenes returned by self.compose_scenes()
dropped = []
# iterate in reverse order, so that popping an index does not shift
# the positions of the indices yet to be checked
for pos in reversed(range(len(ds.indices))):
# the current scene
s = ds.dataset.scenes[ds.indices[pos]]
# the current tile in the ground truth
tile_gt = img2np(s['gt'], ds.dataset.tile_size, s['tile'],
ds.dataset.pad, ds.dataset.cval)
# fraction of pixels equal to the constant padding value
npixels = (tile_gt[tile_gt == ds.dataset.cval].size / tile_gt.size)
# drop samples where npixels >= drop_threshold
if npixels >= drop_threshold:
LOGGER.info('Skipping scene {}, tile {}: {:.2f}% padded pixels'
' ...'.format(s['id'], s['tile'], npixels * 100))
dropped.append(s)
_ = ds.indices.pop(pos)
return dropped
def train_val_test_split(self, ds):
"""Split ``ds`` into training, validation and test set.
Parameters
----------
ds : `pysegcnn.core.dataset.ImageDataset`
An instance of `pysegcnn.core.dataset.ImageDataset`.
Raises
------
TypeError
Raised if ``ds`` is not an instance of
`pysegcnn.core.dataset.ImageDataset`.
Returns
-------
train_ds : `pysegcnn.core.split.RandomSubset` or
`pysegcnn.core.split.SceneSubset`.
The training set.
valid_ds : `pysegcnn.core.split.RandomSubset` or
`pysegcnn.core.split.SceneSubset`.
The validation set.
test_ds : `pysegcnn.core.split.RandomSubset` or
`pysegcnn.core.split.SceneSubset`.
The test set.
"""
if not isinstance(ds, ImageDataset):
raise TypeError('Expected "ds" to be {}.'
.format('.'.join([ImageDataset.__module__,
ImageDataset.__name__])))
if self.split_mode == 'random' or self.split_mode == 'scene':
subset = self.split_class(ds,
self.ttratio,
self.tvratio,
ds.seed)
else:
subset = self.split_class(ds, self.date, self.dateformat)
# the training, validation and test dataset
train_ds, valid_ds, test_ds = subset.split()
# whether to drop training samples with a fraction of pixels equal to
# the constant padding value cval >= drop
if ds.pad and self.drop > 0:
self.dropped = self._drop_samples(train_ds, self.drop)
return train_ds, valid_ds, test_ds
@staticmethod
def dataloaders(*args, **kwargs):
"""Build `torch.utils.data.DataLoader` instances.
Parameters
----------
*args : `list` [`torch.utils.data.Dataset`]
List of instances of `torch.utils.data.Dataset`.
**kwargs
Additional keyword arguments passed to
`torch.utils.data.DataLoader`.
Raises
------
TypeError
Raised if not each item in ``args`` is an instance of
`torch.utils.data.Dataset`.
Returns
-------
loaders : `list` [`torch.utils.data.DataLoader`]
List of instances of `torch.utils.data.DataLoader`. If an instance
of `torch.utils.data.Dataset` in ``args`` is empty, `None` is
appended to ``loaders`` instead of an instance of
`torch.utils.data.DataLoader`.
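Example
-------
A minimal sketch, assuming ``train_ds``, ``valid_ds`` and ``test_ds``
are instances of `torch.utils.data.Dataset`:
>>> train_dl, valid_dl, test_dl = SplitConfig.dataloaders(
...     train_ds, valid_ds, test_ds, batch_size=64, shuffle=True)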
"""
# check whether each dataset in args has the correct type
loaders = []
for ds in args:
if not isinstance(ds, Dataset):
raise TypeError('Expected {}, got {}.'
.format(repr(Dataset), type(ds)))
# check if the dataset is not empty
if len(ds) > 0:
# build the dataloader
loader = DataLoader(ds, **kwargs)
else:
loader = None
loaders.append(loader)
return loaders
@dataclasses.dataclass
class ModelConfig(BaseConfig):
"""Model configuration class.
Instantiate a (pretrained) model.
Parameters
----------
model_name : `str`
The name of the model.
filters : `list` [`int`]
List of input channels to the convolutional layers.
torch_seed : `int`
The random seed to initialize the model weights.
Useful for reproducibility.
optim_name : `str`
The name of the optimizer to update the model weights.
loss_name : `str`
The name of the loss function measuring the model error.
skip_connection : `bool`, optional
Whether to apply skip connections. The default is True.
kwargs : `dict`, optional
The configuration for each convolution in the model. The default is
{'kernel_size': 3, 'stride': 1, 'dilation': 1}.
batch_size : `int`, optional
The model batch size. Determines the number of samples to process
before updating the model weights. The default is 64.
checkpoint : `bool`, optional
Whether to resume training from an existing model checkpoint. The
default is False.
transfer : `bool`, optional
Whether to use a model for transfer learning on a new dataset. If True,
the model architecture of ``pretrained_model`` is adjusted to a new
dataset. The default is False.
pretrained_model : `str`, optional
The name of the pretrained model to use for transfer learning.
The default is ''.
lr : `float`, optional
The learning rate used by the gradient descent algorithm.
The default is 0.001.
early_stop : `bool`, optional
Whether to apply `early stopping`_. The default is False.
mode : `str`, optional
The mode of the early stopping. Depends on the metric measuring
performance. When using model loss as metric, use ``mode`` = 'min',
however, when using accuracy as metric, use ``mode`` = 'max'. For now,
only ``mode`` = 'max' is supported. Only used if ``early_stop`` = True.
The default is 'max'.
delta : `float`, optional
Minimum change in early stopping metric to be considered as an
improvement. Only used if ``early_stop`` = True. The default is 0.
patience : `int`, optional
The number of epochs to wait for an improvement in the early stopping
metric. If the model does not improve over more than ``patience``
epochs, quit training. Only used if ``early_stop`` = True.
The default is 10.
epochs : `int`, optional
The maximum number of epochs to train. The default is 50.
nthreads : `int`, optional
The number of cpu threads to use during training. The default is
torch.get_num_threads().
save : `bool`, optional
Whether to save the model state to disk. Model states are saved in
pysegcnn/main/_models. The default is True.
.. _early stopping:
https://en.wikipedia.org/wiki/Early_stopping
Returns
-------
None.
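Example
-------
A minimal sketch; the model, optimizer and loss function names are
placeholders and must match members of
`pysegcnn.core.models.SupportedModels`,
`pysegcnn.core.models.SupportedOptimizers` and
`pysegcnn.core.models.SupportedLossFunctions`, respectively:
>>> mc = ModelConfig(model_name='MyModel',
...                  filters=[32, 64, 128],
...                  torch_seed=0,
...                  optim_name='MyOptimizer',
...                  loss_name='MyLoss')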
"""
model_name: str
filters: list
torch_seed: int
optim_name: str
loss_name: str
skip_connection: bool = True
kwargs: dict = dataclasses.field(
default_factory=lambda: {'kernel_size': 3, 'stride': 1, 'dilation': 1})
batch_size: int = 64
checkpoint: bool = False
transfer: bool = False
pretrained_model: str = ''
lr: float = 0.001
early_stop: bool = False
mode: str = 'max'
delta: float = 0
patience: int = 10
epochs: int = 50
nthreads: int = torch.get_num_threads()
save: bool = True
def __post_init__(self):
"""Check the type of each argument.
Configure path to save model state.
Raises
------
ValueError
Raised if the model ``model_name``, the optimizer ``optim_name`` or
the loss function ``loss_name`` is not supported.
Returns
-------
None.
"""
# check input types
super().__post_init__()
# check whether the model is currently supported
self.model_class = item_in_enum(self.model_name, SupportedModels)
# check whether the optimizer is currently supported
self.optim_class = item_in_enum(self.optim_name, SupportedOptimizers)
# check whether the loss function is currently supported
self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions)
# path to model states
self.state_path = pathlib.Path(HERE).joinpath('_models/')
# path to pretrained model
self.pretrained_path = self.state_path.joinpath(self.pretrained_model)
def init_optimizer(self, model):
"""Instantiate the optimizer.
Parameters
----------
model : `torch.nn.Module`
An instance of `torch.nn.Module`.
Returns
-------
optimizer : `torch.optim.Optimizer`
An instance of `torch.optim.Optimizer`.
"""
LOGGER.info('Optimizer: {}.'.format(repr(self.optim_class)))
# initialize the optimizer for the specified model
optimizer = self.optim_class(model.parameters(), self.lr)
return optimizer
def init_loss_function(self):
"""Instantiate the loss function.
Returns
-------
loss_function : `torch.nn.Module`
An instance of `torch.nn.Module`.
"""
LOGGER.info('Loss function: {}.'.format(repr(self.loss_class)))
# instantiate the loss function
loss_function = self.loss_class()
return loss_function
def init_model(self, ds, state_file):
"""Instantiate the model and the optimizer.
If the model checkpoint ``state_file`` exists, the pretrained model and
optimizer states are loaded, otherwise the model and the optimizer are
initialized from scratch.
Parameters
----------
ds : `pysegcnn.core.dataset.ImageDataset`
An instance of `pysegcnn.core.dataset.ImageDataset`.
state_file : `pathlib.Path`
Path to a model checkpoint.
Returns
-------
model : `pysegcnn.core.models.Network`
An instance of `pysegcnn.core.models.Network`.
optimizer : `torch.optim.Optimizer`
An instance of `torch.optim.Optimizer`.
checkpoint_state : `dict` [`str`, `numpy.ndarray`]
If the model checkpoint ``state_file`` exists, ``checkpoint_state``
has keys:
``'ta'``
The accuracy on the training set (`numpy.ndarray`).
``'tl'``
The loss on the training set (`numpy.ndarray`).
``'va'``
The accuracy on the validation set (`numpy.ndarray`).
``'vl'``
The loss on the validation set (`numpy.ndarray`).
"""
# write an initialization string to the log file
LogConfig.init_log('{}: Initializing model run. ')
# case (1): build a new model
if not self.transfer:
# set the random seed for reproducibility
torch.manual_seed(self.torch_seed)
LOGGER.info('Initializing model: {}'.format(state_file.name))
# instantiate the model
model = self.model_class(
in_channels=len(ds.use_bands),
nclasses=len(ds.labels),
filters=self.filters,
skip=self.skip_connection,
**self.kwargs)
# case (2): load a pretrained model for transfer learning
else:
# load pretrained model
LOGGER.info('Loading pretrained model for transfer learning from: '
'{}'.format(self.pretrained_path))
model = self.transfer_model(self.pretrained_path, ds)
# initialize the optimizer
optimizer = self.init_optimizer(model)
# whether to resume training from an existing model checkpoint
checkpoint_state = {}
if self.checkpoint:
model, optimizer, checkpoint_state = self.load_checkpoint(
model, optimizer, state_file)
return model, optimizer, checkpoint_state
@staticmethod
def load_checkpoint(model, optimizer, state_file):
"""Load an existing model checkpoint.
If the model checkpoint ``state_file`` exists, the pretrained model and
optimizer states are loaded.
Parameters
----------
model : `pysegcnn.core.models.Network`
An instance of `pysegcnn.core.models.Network`.
optimizer : `torch.optim.Optimizer`
An instance of `torch.optim.Optimizer`.
state_file : `pathlib.Path`
Path to the model checkpoint.
Returns
-------
model : `pysegcnn.core.models.Network`
An instance of `pysegcnn.core.models.Network`.
optimizer : `torch.optim.Optimizer`
An instance of `torch.optim.Optimizer`.
checkpoint_state : `dict` [`str`, `numpy.ndarray`]
If the model checkpoint ``state_file`` exists, ``checkpoint_state``
has keys:
``'ta'``
The accuracy on the training set (`numpy.ndarray`).
``'tl'``
The loss on the training set (`numpy.ndarray`).
``'va'``
The accuracy on the validation set (`numpy.ndarray`).
``'vl'``
The loss on the validation set (`numpy.ndarray`).
"""
# whether to resume training from an existing model checkpoint
checkpoint_state = {}
# if no checkpoint exists, log a warning and continue with a model
# initialized from scratch
if not state_file.exists():
LOGGER.warning('Checkpoint for model {} does not exist. '
'Initializing new model.'
.format(state_file.name))
else:
# load model checkpoint
model, optimizer, model_state = Network.load(state_file)
# load model loss and accuracy
# get all non-zero elements, i.e. get number of epochs trained
# before the early stop
checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
for k, v in model_state['state'].items()}
return model, optimizer, checkpoint_state
@staticmethod
def transfer_model(state_file, ds):
"""Adjust a pretrained model to a new dataset.
The classification layer of the pretrained model in ``state_file`` is
initialized from scratch with the classes of the new dataset ``ds``.
The remaining model weights are preserved.
Parameters
----------
state_file : `pathlib.Path`
Path to a pretrained model.
ds : `pysegcnn.core.dataset.ImageDataset`
An instance of `pysegcnn.core.dataset.ImageDataset`.
Raises
------
TypeError
Raised if ``ds`` is not an instance of
`pysegcnn.core.dataset.ImageDataset`.
ValueError
Raised if the bands of ``ds`` do not match the bands of the dataset
the pretrained model was trained with.
Returns
-------
model : `pysegcnn.core.models.Network`
An instance of `pysegcnn.core.models.Network`. The pretrained model
adjusted to the new dataset.
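Example
-------
A minimal sketch, assuming ``state_file`` points to an existing
pretrained model state and ``new_ds`` is the dataset to transfer to:
>>> model = ModelConfig.transfer_model(state_file, new_ds)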
"""
# check input type
if not isinstance(ds, ImageDataset):
raise TypeError('Expected "ds" to be {}.'
.format('.'.join([ImageDataset.__module__,
ImageDataset.__name__])))
# load the pretrained model
model, _, model_state = Network.load(state_file)
LOGGER.info('Configuring model for new dataset: {}.'.format(
ds.__class__.__name__))
# check whether the current dataset uses the correct spectral bands
if ds.use_bands != model_state['bands']:
raise ValueError('The pretrained network was trained with '
'bands {}, not with bands {}.'
.format(model_state['bands'], ds.use_bands))
# get the number of convolutional filters
filters = model_state['params']['filters']
# reset model epoch to 0, since the model is trained on a different
# dataset
model.epoch = 0
# adjust the number of classes in the model
model.nclasses = len(ds.labels)
LOGGER.info('Replacing classification layer to classes: {}.'
.format(', '.join('({}, {})'.format(k, v['label'])
for k, v in ds.labels.items())))
# adjust the classification layer to the classes of the new dataset
model.classifier = Conv2dSame(in_channels=filters[0],
out_channels=model.nclasses,
kernel_size=1)
return model
@dataclasses.dataclass
class StateConfig(BaseConfig):
"""Model state configuration class.
Generate the model state filename according to the following naming
convention:
model_dataset_optimizer_splitmode_splitparams_tilesize_batchsize_bands.pt
Parameters
----------
ds : `pysegcnn.core.dataset.ImageDataset`
An instance of `pysegcnn.core.dataset.ImageDataset`.
sc : `pysegcnn.core.trainer.SplitConfig`
An instance of `pysegcnn.core.trainer.SplitConfig`.
mc : `pysegcnn.core.trainer.ModelConfig`
An instance of `pysegcnn.core.trainer.ModelConfig`.
Returns
-------
None.
"""
ds: ImageDataset
sc: SplitConfig
mc: ModelConfig
def __post_init__(self):
"""Check the type of each argument.
Returns
-------
None.
"""
super().__post_init__()
def init_state(self):
"""Generate the model state filename.
Returns
-------
state : `pathlib.Path`
The path to the model state file.
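Example
-------
A minimal sketch, assuming ``ds``, ``sc`` and ``mc`` are instances of
the respective configuration classes; the returned path encodes the
model, dataset, optimizer, split parameters, tile size, batch size and
bands:
>>> state = StateConfig(ds, sc, mc).init_state()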
"""
# file to save model state to:
# network_dataset_optim_split_splitparams_tilesize_batchsize_bands.pt
# model state filename
state_file = '{}_{}_{}_{}Split_{}_t{}_b{}_{}.pt'
# get the band numbers
bformat = ''.join(band[0] +
str(self.ds.sensor.__members__[band].value) for
band in self.ds.use_bands)
# check which split mode was used
if self.sc.split_mode == 'date':
# store the date that was used to split the dataset
state_file = state_file.format(self.mc.model_name,
self.ds.__class__.__name__,
self.mc.optim_name,
self.sc.split_mode.capitalize(),
self.sc.date,
self.ds.tile_size,
self.mc.batch_size,
bformat)
else:
# store the random split parameters
split_params = 's{}_t{}v{}'.format(
self.ds.seed, str(self.sc.ttratio).replace('.', ''),
str(self.sc.tvratio).replace('.', ''))
# model state filename
state_file = state_file.format(self.mc.model_name,
self.ds.__class__.__name__,
self.mc.optim_name,
self.sc.split_mode.capitalize(),
split_params,
self.ds.tile_size,
self.mc.batch_size,
bformat)
# check whether a pretrained model was used and change state filename
# accordingly
if self.mc.transfer:
# add the configuration of the pretrained model to the state name
state_file = (state_file.replace('.pt', '_') +
'pretrained_' + self.mc.pretrained_model)
# path to model state
state = self.mc.state_path.joinpath(state_file)
return state
@dataclasses.dataclass
class EvalConfig(BaseConfig):
"""Model inference configuration.
Evaluate a model.
Parameters
----------
state_file : `pathlib.Path`
Path to the model to evaluate.
test : `bool` or `None`
Whether to evaluate the model on the training set (``test`` = `None`),
the validation set (``test`` = False) or the test set (``test`` = True).
predict_scene : `bool`, optional
The model prediction order. If False, the samples (tiles) of a dataset
are predicted in any order and the scenes are not reconstructed.
If True, the samples (tiles) are ordered according to the scene they
belong to and a model prediction for each entire reconstructed scene is
returned. The default is False.
plot_samples : `bool`, optional
Whether to save a plot of false color composite, ground truth and model
prediction for each sample (tile). Only used if ``predict_scene`` =
False. The default is False.
plot_scenes : `bool`, optional
Whether to save a plot of false color composite, ground truth and model
prediction for each entire scene. Only used if ``predict_scene`` =
True. The default is False.
plot_bands : `list` [`str`], optional
The bands to build the false color composite. The default is
['nir', 'red', 'green'].
cm : `bool`, optional
Whether to compute and plot the confusion matrix. The default is True.
figsize : `tuple`, optional
The figure size in centimeters. The default is (10, 10).
alpha : `int`, optional
The level of the percentiles for contrast stretching of the false color
composite. ``alpha`` = 0 means no stretching. The default is 5.
Returns
-------
None.
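Example
-------
A minimal sketch; the model state filename is a placeholder and is
resolved relative to pysegcnn/main/_models:
>>> ec = EvalConfig(state_file=pathlib.Path('model.pt'), test=True)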
"""
state_file: pathlib.Path
test: object
predict_scene: bool = False
plot_samples: bool = False
plot_scenes: bool = False
plot_bands: list = dataclasses.field(
default_factory=lambda: ['nir', 'red', 'green'])
cm: bool = True
figsize: tuple = (10, 10)
alpha: int = 5
def __post_init__(self):
"""Check the type of each argument.
Configure figure output paths.
Raises
------
TypeError
Raised if ``test`` is not of type `bool` or `None`.
Returns
-------
None.
"""
super().__post_init__()
# check whether the test input parameter is correctly specified
if self.test not in [None, False, True]:
raise TypeError('Expected "test" to be None, True or False, got '
'{}.'.format(self.test))
# the output paths for the different graphics
self.base_path = pathlib.Path(HERE)
self.sample_path = self.base_path.joinpath('_samples')
self.scenes_path = self.base_path.joinpath('_scenes')
self.perfmc_path = self.base_path.joinpath('_graphics')
# input path for model state files
self.models_path = self.base_path.joinpath('_models')
self.state_file = self.models_path.joinpath(self.state_file)
# write initialization string to log file
LogConfig.init_log('{}: ' + 'Evaluating model: {}.'.format(
self.state_file.name))
@dataclasses.dataclass
class LogConfig(BaseConfig):
"""Logging configuration class.
Generate the model log file.
Parameters
----------
state_file : `pathlib.Path`
Path to a model state file.
"""
state_file: pathlib.Path
def __post_init__(self):
"""Check the type of each argument.
Generate model log file.
Returns
-------
None.
"""
super().__post_init__()
# the path to store model logs
self.log_path = pathlib.Path(HERE).joinpath('_logs')
# the log file of the current model
self.log_file = self.log_path.joinpath(
self.state_file.name.replace('.pt', '.log'))
@staticmethod
def now():
"""Return the current date and time.
Returns
-------
date : `str`
The current date and time, formatted as '%Y-%m-%dT%H:%M:%S'.
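Example
-------
A minimal sketch; the actual value depends on the system clock:
>>> LogConfig.now()  # e.g. '2020-08-01T12:00:00'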
"""
return datetime.datetime.strftime(datetime.datetime.now(),
'%Y-%m-%dT%H:%M:%S')
@staticmethod
def init_log(init_str):
"""Generate a string to identify a new model run.
Parameters
----------
init_str : `str`
The string to write to the model log file.
Returns
-------
None.
"""
LOGGER.info(80 * '-')
LOGGER.info(init_str.format(LogConfig.now()))
LOGGER.info(80 * '-')
@dataclasses.dataclass
class NetworkTrainer(BaseConfig):
"""Model training class.
Generic class to train an instance of `pysegcnn.core.models.Network` on
a dataset of type `pysegcnn.core.dataset.ImageDataset`.
Parameters
----------
model : `pysegcnn.core.models.Network`
The model to train. An instance of `pysegcnn.core.models.Network`.
optimizer : `torch.optim.Optimizer`
The optimizer to update the model weights. An instance of
`torch.optim.Optimizer`.
loss_function : `torch.nn.Module`
The loss function to compute the model error. An instance of
`torch.nn.Module`.
train_dl : `torch.utils.data.DataLoader`
The training `torch.utils.data.DataLoader` instance.
valid_dl : `torch.utils.data.DataLoader`
The validation `torch.utils.data.DataLoader` instance.
test_dl : `torch.utils.data.DataLoader`
The test `torch.utils.data.DataLoader` instance.
state_file : `pathlib.Path`
Path to save the model state.
epochs : `int`, optional
The maximum number of epochs to train. The default is 1.
nthreads : `int`, optional
The number of cpu threads to use during training. The default is
torch.get_num_threads().
early_stop : `bool`, optional
Whether to apply `early stopping`_. The default is False.
mode : `str`, optional
The mode of the early stopping. Depends on the metric measuring
performance. When using model loss as metric, use ``mode`` = 'min',
however, when using accuracy as metric, use ``mode`` = 'max'. For now,
only ``mode`` = 'max' is supported. Only used if ``early_stop`` = True.
The default is 'max'.
delta : `float`, optional
Minimum change in early stopping metric to be considered as an
improvement. Only used if ``early_stop`` = True. The default is 0.
patience : `int`, optional
The number of epochs to wait for an improvement in the early stopping
metric. If the model does not improve over more than ``patience``
epochs, quit training. Only used if ``early_stop`` = True.
The default is 10.
checkpoint_state : `dict` [`str`, `numpy.ndarray`], optional
A model checkpoint for ``model``. If specified, ``checkpoint_state``
should be a dictionary with keys:
``'ta'``
The accuracy on the training set (`numpy.ndarray`).
``'tl'``
The loss on the training set (`numpy.ndarray`).
``'va'``
The accuracy on the validation set (`numpy.ndarray`).
``'vl'``
The loss on the validation set (`numpy.ndarray`).
The default is {}.
save : `bool`, optional
Whether to save the model state to ``state_file``. The default is True.
.. _early stopping:
https://en.wikipedia.org/wiki/Early_stopping
Returns
-------
None.
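Example
-------
A minimal sketch, assuming the model, optimizer, loss function,
dataloaders and state file are instantiated as described above:
>>> trainer = NetworkTrainer(model, optimizer, loss_function,
...                          train_dl, valid_dl, test_dl, state_file,
...                          epochs=50, early_stop=True)
>>> training_state = trainer.train()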
"""
model: Network
optimizer: Optimizer
loss_function: nn.Module
train_dl: DataLoader
valid_dl: DataLoader
test_dl: DataLoader
state_file: pathlib.Path
epochs: int = 1
nthreads: int = torch.get_num_threads()
early_stop: bool = False
mode: str = 'max'
delta: float = 0
patience: int = 10
checkpoint_state: dict = dataclasses.field(default_factory=dict)
save: bool = True
def __post_init__(self):
"""Check the type of each argument.
Configure the device to train the model on, i.e. train on the gpu if
available.
Configure early stopping if required.
Returns
-------
None.
"""
super().__post_init__()
# whether to use the gpu
self.device = torch.device("cuda:0" if torch.cuda.is_available()
else "cpu")
# send the model to the gpu if available
self.model = self.model.to(self.device)
# maximum accuracy on the validation dataset
self.max_accuracy = 0
if self.checkpoint_state:
self.max_accuracy = self.checkpoint_state['va'].mean(
axis=0).max().item()
# whether to use early stopping
self.es = None
if self.early_stop:
self.es = EarlyStopping(self.mode, self.max_accuracy, self.delta,
self.patience)
# log representation
LOGGER.info(repr(self))
def train(self):
"""Train the model.
Returns
-------
training_state : `dict` [`str`, `numpy.ndarray`]
The training state dictionary with keys:
``'ta'``
The accuracy on the training set (`numpy.ndarray`).
``'tl'``
The loss on the training set (`numpy.ndarray`).
``'va'``
The accuracy on the validation set (`numpy.ndarray`).
``'vl'``
The loss on the validation set (`numpy.ndarray`).
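Example
-------
A minimal sketch, assuming ``trainer`` is a configured instance of
`pysegcnn.core.trainer.NetworkTrainer`; each entry of the training
state has shape (mini-batches, epochs):
>>> training_state = trainer.train()
>>> training_state['tl'][:, 0].mean()  # mean training loss, 1st epoch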
"""
LOGGER.info(35 * '-' + ' Training ' + 35 * '-')
# set the number of threads
LOGGER.info('Device: {}'.format(self.device))
LOGGER.info('Number of cpu threads: {}'.format(self.nthreads))
torch.set_num_threads(self.nthreads)
# create dictionary of the observed losses and accuracies on the
# training and validation dataset
tshape = (len(self.train_dl), self.epochs)
vshape = (len(self.valid_dl), self.epochs)
self.training_state = {'tl': np.zeros(shape=tshape),
'ta': np.zeros(shape=tshape),
'vl': np.zeros(shape=vshape),
'va': np.zeros(shape=vshape)
}
# initialize the training: iterate over the entire training dataset
for epoch in range(self.epochs):
# set the model to training mode
LOGGER.info('Setting model to training mode ...')
self.model.train()
# iterate over the dataloader object
for batch, (inputs, labels) in enumerate(self.train_dl):
# send the data to the gpu if available
inputs = inputs.to(self.device)
labels = labels.to(self.device)
# reset the gradients
self.optimizer.zero_grad()
# perform forward pass
outputs = self.model(inputs)
# compute loss
loss = self.loss_function(outputs, labels.long())
# loss.item() works for tensors on both the cpu and the gpu
observed_loss = loss.item()
self.training_state['tl'][batch, epoch] = observed_loss
# compute the gradients of the loss function w.r.t.
# the network weights
loss.backward()
# update the weights
self.optimizer.step()
# calculate predicted class labels
ypred = F.softmax(outputs, dim=1).argmax(dim=1)
# calculate accuracy on current batch
observed_accuracy = accuracy_function(ypred, labels)
self.training_state['ta'][batch, epoch] = observed_accuracy
# print progress
LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
'Loss: {:.2f}, Accuracy: {:.2f}'.format(
epoch + 1,
self.epochs,
batch + 1,
len(self.train_dl),
observed_loss,
observed_accuracy))
# update the number of epochs trained
self.model.epoch += 1
# whether to evaluate model performance on the validation set and
# early stop the training process
if self.early_stop:
# model predictions on the validation set
vacc, vloss = self.predict()
# append observed accuracy and loss to arrays
self.training_state['va'][:, epoch] = vacc.squeeze()
self.training_state['vl'][:, epoch] = vloss.squeeze()
# metric to assess model performance on the validation set
epoch_acc = vacc.squeeze().mean()
# whether the model improved with respect to the previous epoch
if self.es.increased(epoch_acc, self.max_accuracy, self.delta):
self.max_accuracy = epoch_acc
# save model state if the model improved with
# respect to the previous epoch
self.save_state()
# whether the early stopping criterion is met
if self.es.stop(epoch_acc):
break
else:
# if no early stopping is required, the model state is
# saved after each epoch
self.save_state()
return self.training_state
def predict(self):
"""Model inference at training time.
Returns
-------
accuracies : `numpy.ndarray`
The mean model prediction accuracy on each mini-batch in the
validation set.
losses : `numpy.ndarray`
The model loss for each mini-batch in the validation set.
"""
# set the model to evaluation mode
LOGGER.info('Setting model to evaluation mode ...')
self.model.eval()
# create arrays of the observed losses and accuracies
accuracies = np.zeros(shape=(len(self.valid_dl), 1))
losses = np.zeros(shape=(len(self.valid_dl), 1))
# iterate over the validation/test set
LOGGER.info('Calculating accuracy on the validation set ...')
for batch, (inputs, labels) in enumerate(self.valid_dl):
# send the data to the gpu if available
inputs = inputs.to(self.device)
labels = labels.to(self.device)
# calculate network outputs
with torch.no_grad():
outputs = self.model(inputs)
# compute loss
loss = self.loss_function(outputs, labels.long())
losses[batch, 0] = loss.item()
# calculate predicted class labels
pred = F.softmax(outputs, dim=1).argmax(dim=1)
# calculate accuracy on current batch
acc = accuracy_function(pred, labels)
accuracies[batch, 0] = acc
# print progress
LOGGER.info('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}'
.format(batch + 1, len(self.valid_dl), acc))
# calculate overall accuracy on the validation/test set
LOGGER.info('Epoch: {:d}, Mean accuracy: {:.2f}%.'
.format(self.model.epoch, accuracies.mean() * 100))
return accuracies, losses
def save_state(self):
"""Save the model state.
Returns
-------
None.
"""
# whether to save the model state
if self.save:
# append the model performance before the checkpoint to the model
# state, if a checkpoint is passed
if self.checkpoint_state:
# append values from checkpoint to current training state
state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
zip(self.checkpoint_state.items(),
self.training_state.items()) if k1 == k2}
else:
state = self.training_state
# save model state
_ = self.model.save(
self.state_file,
self.optimizer,
bands=self.train_dl.dataset.dataset.use_bands,
train_ds=self.train_dl.dataset,
valid_ds=self.valid_dl.dataset,
test_ds=self.test_dl.dataset,
state=state,
)
def __repr__(self):
"""Representation of `~pysegcnn.core.trainer.NetworkTrainer`.
Returns
-------
fs : `str`
Representation string.
"""
# representation string to print
fs = self.__class__.__name__ + '(\n'
# dataset
fs += ' (dataset):\n '
fs += ''.join(
repr(self.train_dl.dataset.dataset)).replace('\n', '\n ')
# batch size
fs += '\n (batch):\n '
fs += '- batch size: {}\n '.format(self.train_dl.batch_size)
fs += '- mini-batch shape (b, c, h, w): {}'.format(
(self.train_dl.batch_size,
len(self.train_dl.dataset.dataset.use_bands),
self.train_dl.dataset.dataset.tile_size,
self.train_dl.dataset.dataset.tile_size)
)
# dataset split
fs += '\n (split):'
fs += '\n ' + repr(self.train_dl.dataset)
fs += '\n ' + repr(self.valid_dl.dataset)
fs += '\n ' + repr(self.test_dl.dataset)
# model
fs += '\n (model):\n '
fs += ''.join(repr(self.model)).replace('\n', '\n ')
# optimizer
fs += '\n (optimizer):\n '
fs += ''.join(repr(self.optimizer)).replace('\n', '\n ')
# early stopping
fs += '\n (early stop):\n '
fs += ''.join(repr(self.es)).replace('\n', '\n ')
fs += '\n)'
return fs
class EarlyStopping(object):
"""`Early stopping`_ algorithm.
This implementation of the early stopping algorithm advances a counter
each time the metric does not improve over a training epoch. If the
metric does not improve for more than ``patience`` epochs, the early
stopping criterion is met.
See `pysegcnn.core.trainer.NetworkTrainer.train` for an example
implementation.
Parameters
----------
mode : `str`, optional
The mode of the early stopping. Depends on the metric measuring
performance. When using model loss as metric, use ``mode`` = 'min',
however, when using accuracy as metric, use ``mode`` = 'max'. The
default is 'max'.
best : `float`, optional
Threshold indicating the best metric score. At instantiation, set
``best`` to the worst possible score of the metric. ``best`` will be
overwritten during training. The default is 0.
min_delta : `float`, optional
Minimum change in early stopping metric to be considered as an
improvement. The default is 0.
patience : `int`, optional
The number of epochs to wait for an improvement in the early stopping
metric. The default is 10.
Raises
------
ValueError
Raised if ``mode`` is not either 'min' or 'max'.
Returns
-------
None.
.. _Early stopping:
https://en.wikipedia.org/wiki/Early_stopping
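Example
-------
A minimal sketch: the counter resets on improvement and the stopping
criterion is met after ``patience`` consecutive epochs without
improvement:
>>> es = EarlyStopping(mode='max', best=0, min_delta=0, patience=2)
>>> es.stop(0.5)  # 0.5 > 0: improvement, best is now 0.5
False
>>> es.stop(0.5)  # no improvement, counter = 1
False
>>> es.stop(0.5)  # no improvement, counter = 2 >= patience
True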
"""
def __init__(self, mode='max', best=0, min_delta=0, patience=10):
# check if mode is correctly specified
if mode not in ['min', 'max']:
raise ValueError('Mode "{}" not supported. '
'Mode is either "min" (check whether the metric '
'decreased, e.g. loss) or "max" (check whether '
'the metric increased, e.g. accuracy).'
.format(mode))
# mode to determine if metric improved
self.mode = mode
# whether to check for an increase or a decrease in a given metric
self.is_better = self.decreased if mode == 'min' else self.increased
# minimum change in metric to be considered as an improvement
self.min_delta = min_delta
# number of epochs to wait for improvement
self.patience = patience
# initialize best metric
self.best = best
# initialize early stopping flag
self.early_stop = False
# initialize the early stop counter
self.counter = 0
def stop(self, metric):
"""Advance early stopping counter.
Parameters
----------
metric : `float`
The current metric score.
Returns
-------
early_stop : `bool`
Whether the early stopping criterion is met.
"""
# if the metric improved, reset the epochs counter, else, advance
if self.is_better(metric, self.best, self.min_delta):
self.counter = 0
self.best = metric
else:
self.counter += 1
LOGGER.info('Early stopping counter: {}/{}'.format(
self.counter, self.patience))
# if the metric did not improve over the last patience epochs,
# the early stopping criterion is met
if self.counter >= self.patience:
LOGGER.info('Early stopping criterion met, stopping training.')
self.early_stop = True
return self.early_stop
def decreased(self, metric, best, min_delta):
"""Whether a metric decreased with respect to a best score.
Measure improvement for metrics that are considered as 'better' when
they decrease, e.g. model loss, mean squared error, etc.
Parameters
----------
metric : `float`
The current score.
best : `float`
The current best score.
min_delta : `float`
Minimum change to be considered as an improvement.
Returns
-------
`bool`
Whether the metric improved.
"""
return metric < best - min_delta
def increased(self, metric, best, min_delta):
"""Whether a metric increased with respect to a best score.
Measure improvement for metrics that are considered as 'better' when
they increase, e.g. accuracy, precision, recall, etc.
Parameters
----------
metric : `float`
The current score.
best : `float`
The current best score.
min_delta : `float`
Minimum change to be considered as an improvement.
Returns
-------
`bool`
Whether the metric improved.
"""
return metric > best + min_delta
def __repr__(self):
"""Representation of `~pysegcnn.core.trainer.EarlyStopping`.
Returns
-------
fs : `str`
Representation string.
"""
fs = self.__class__.__name__
fs += '(mode={}, best={:.2f}, delta={}, patience={})'.format(
self.mode, self.best, self.min_delta, self.patience)
return fs