diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 641fbec4ddad63a8aca4dfe2caa383928c7ca8a8..a6948c7e809dab3e57983074762262c7bbfe9078 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -44,7 +44,7 @@ from pysegcnn.core.utils import (img2np, item_in_enum, accuracy_function,
                                  reconstruct_scene, check_filename_length)
 from pysegcnn.core.split import SupportedSplits, CustomSubset, SceneSubset
 from pysegcnn.core.models import (SupportedModels, SupportedOptimizers,
-                                  SupportedLossFunctions, Network)
+                                  Network)
 from pysegcnn.core.uda import SupportedUdaMethods, CoralLoss, UDA_POSITIONS
 from pysegcnn.core.layers import Conv2dSame
 from pysegcnn.core.logging import log_conf
@@ -410,22 +410,14 @@ class ModelConfig(BaseConfig):
     ----------
     model_name : `str`
         The name of the model.
-    filters : `list` [`int`]
-        List of input channels to the convolutional layers.
     torch_seed : `int`
         The random seed to initialize the model weights. Useful for
         reproducibility.
     optim_name : `str`
         The name of the optimizer to update the model weights.
-    cla_loss : `str`
-        The name of the loss function measuring the model error.
     uda_loss : `str`
-        The name of the unsupervised domain adaptation loss.
-    skip_connection : `bool`
-        Whether to apply skip connections. The default is `True`.
-    kwargs: `dict`
-        The configuration for each convolution in the model. The default is
-        `{'kernel_size': 3, 'stride': 1, 'dilation': 1}`.
+        The name of the unsupervised domain adaptation loss. The default is
+        `''`, which is equivalent to not using unsupervised domain adaptation.
     batch_size : `int`
         The model batch size. Determines the number of samples to process
         before updating the model weights. The default is `64`.
@@ -448,11 +440,13 @@ class ModelConfig(BaseConfig):
         default is `False`, i.e. train from scratch.
     uda_lambda : `float`
         The weight of the domain adaptation, trading off adaptation with
-        classification accuracy on the source domain.
+        classification accuracy on the source domain. The default is `0.5`.
     uda_pos : `str`
-        The layer where to compute the domain adaptation loss.
+        The layer where to compute the domain adaptation loss. The default is
+        `enc`, which means calculating the adaptation loss after the encoder
+        layers.
     freeze : `bool`
-        Whether to freeze the pretrained weights.
+        Whether to freeze the pretrained weights. The default is `False`.
     lr : `float`
         The learning rate used by the gradient descent algorithm. The default
         is `0.001`.
@@ -487,8 +481,6 @@ class ModelConfig(BaseConfig):
         A subclass of :py:class:`pysegcnn.core.models.Network`.
     optim_class : :py:class:`torch.optim.Optimizer`
         A subclass of :py:class:`torch.optim.Optimizer`.
-    cla_loss_class : :py:class:`torch.nn.Module`
-        A subclass of :py:class:`torch.nn.Module`
     uda_loss_class : :py:class:`torch.nn.Module`
         A subclass of :py:class:`torch.nn.Module`
     state_path : :py:class:`pathlib.Path`
@@ -502,14 +494,9 @@ class ModelConfig(BaseConfig):
     """

     model_name: str
-    filters: list
     torch_seed: int
     optim_name: str
-    cla_loss: str
     uda_loss: str = ''
-    skip_connection: bool = True
-    kwargs: dict = dataclasses.field(
-        default_factory=lambda: {'kernel_size': 3, 'stride': 1, 'dilation': 1})
     batch_size: int = 64
     checkpoint: bool = False
     transfer: bool = False
@@ -539,8 +526,7 @@ class ModelConfig(BaseConfig):
     ------
     ValueError
         Raised if the model ``model_name``, the optimizer ``optim_name``,
-        the loss function ``cla_loss`` or the domain adaptation loss
-        ``uda_loss`` is not supported.
+        or the domain adaptation loss ``uda_loss`` is not supported.

         """
         # check input types
@@ -552,10 +538,6 @@ class ModelConfig(BaseConfig):
         # check whether the optimizer is currently supported
         self.optim_class = item_in_enum(self.optim_name, SupportedOptimizers)

-        # check whether the loss function is currently supported
-        self.cla_loss_class = item_in_enum(self.cla_loss,
-                                           SupportedLossFunctions)
-
         # check whether the domain adaptation loss is currently supported
         if self.transfer and not self.supervised:
             self.uda_loss_class = item_in_enum(self.uda_loss,
@@ -575,15 +557,13 @@ class ModelConfig(BaseConfig):
         # path to pretrained model
         self.pretrained_path = self.state_path.joinpath(self.pretrained_model)

-    def init_optimizer(self, model, **kwargs):
+    def init_optimizer(self, model):
         """Instanciate the optimizer.

         Parameters
         ----------
         model : :py:class:`torch.nn.Module`
             An instance of :py:class:`torch.nn.Module`.
-        **kwargs:
-            Additional keyword arguments passed to ``self.optim_class``.

         Returns
         -------
@@ -594,35 +574,14 @@ class ModelConfig(BaseConfig):
         LOGGER.info('Optimizer: {}.'.format(repr(self.optim_class)))

         # initialize the optimizer for the specified model
-        optimizer = self.optim_class(model.parameters(), self.lr, **kwargs)
+        optimizer = self.optim_class(model.parameters(), self.lr,
+                                     **self.optim_kwargs)

         return optimizer

-    def init_cla_loss_function(self):
-        """Instanciate the classification loss function.
-
-        Returns
-        -------
-        cla_loss_function : :py:class:`torch.nn.Module`
-            An instance of :py:class:`torch.nn.Module`.
-
-        """
-        LOGGER.info('Classification loss function: {}.'
-                    .format(repr(self.cla_loss_class)))
-
-        # instanciate the classification loss function
-        cla_loss_function = self.cla_loss_class()
-
-        return cla_loss_function
-
-    def init_uda_loss_function(self, uda_lambda):
+    def init_uda_loss_function(self):
         """Instanciate the domain adaptation loss function.

-        Parameters
-        ----------
-        uda_lambda : `float`
-            The weight of the domain adaptation.
-
         Returns
         -------
         uda_loss_function : :py:class:`torch.nn.Module`
@@ -633,7 +592,7 @@ class ModelConfig(BaseConfig):
             .format(repr(self.uda_loss_class)))

         # instanciate the loss function
-        uda_loss_function = self.uda_loss_class(uda_lambda)
+        uda_loss_function = self.uda_loss_class(self.uda_lambda)

         return uda_loss_function

@@ -641,8 +600,8 @@ class ModelConfig(BaseConfig):
         """Instanciate the model and the optimizer.

         If ``self.checkpoint`` is set to True, the pretrained model in
-        ``state_file`` is loaded. Otherwise, the model is initiated
-        from scratch on the dataset ``ds``.
+        ``state_file`` is loaded, if it exists. Otherwise, the model is
+        initiated from scratch on the dataset ``ds``.

         If ``self.transfer`` is True, the pretrained model in
         ``self.pretrained_path`` is adjusted to the dataset ``ds``.
@@ -727,7 +686,8 @@ class ModelConfig(BaseConfig):
         Parameters
         ----------
         model_state : `dict`
-            A dictionary containing the model and optimizer state.
+            A dictionary containing the model and optimizer state, as
+            constructed by :py:meth:`~pysegcnn.core.models.Network.save`.

         Returns
         -------
@@ -735,8 +695,6 @@ class ModelConfig(BaseConfig):
             The model checkpoint loss and accuracy time series.

         """
-        # load model loss and accuracy
         # get all non-zero elements, i.e. get number of epochs trained
         # before the early stop
         checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
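Aside on the checkpoint trimming above (illustration only, not part of the patch): `load_checkpoint` relies on the metric arrays being zero-padded beyond the last trained epoch, so dropping the zeros and reshaping recovers a `(mini_batch, epochs_trained)` series. A minimal sketch, assuming strictly positive metric values:

```python
import numpy as np

# metrics logged as (mini_batches, max_epochs), zero-padded after an
# early stop at epoch 2 of 4
v = np.array([[0.81, 0.88, 0.0, 0.0],
              [0.79, 0.90, 0.0, 0.0]])

# keep the non-zero entries and restore the per-batch layout
trimmed = v[np.nonzero(v)].reshape(v.shape[0], -1)
print(trimmed.shape)  # (2, 2) -> two mini-batches, two epochs trained
```

Note that a metric value of exactly zero would also be dropped, so this only works for strictly positive losses and accuracies.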
@@ -745,7 +703,7 @@ class ModelConfig(BaseConfig):
         return checkpoint_state

     @staticmethod
-    def transfer_model(model, bands, ds, freeze=False):
+    def transfer_model(model, ds, freeze=False):
         """Adjust a pretrained model to a new dataset.

         If the number of classes in the pretrained model ``model`` does not
         match the number of classes of the dataset ``ds``, the classification
         layer of ``model`` is adjusted to ``ds``.

         Parameters
         ----------
         model : :py:class:`pysegcnn.core.models.Network`
             An instance of the pretrained model to adjust to the dataset
             ``ds``.
-        bands : `list` [`str`]
-            The spectral bands used to train ``model``.
         ds : :py:class:`pysegcnn.core.dataset.ImageDataset`
             The dataset to which the classification layer of ``model`` is
             adapted.

         Raises
         ------
         TypeError
             Raised if ``ds`` is not an instance of
             :py:class:`pysegcnn.core.dataset.ImageDataset`.
         ValueError
-            Raised if the bands of ``ds`` do not match the bands ``bands`` of
-            the dataset the pretrained model ``model`` was trained with.
+            Raised if the number of input channels in ``ds`` does not match
+            the number of input channels of the dataset the pretrained model
+            ``model`` was trained with.

         Returns
         -------
         model : :py:class:`pysegcnn.core.models.Network`
             The adjusted model.

         """
         # check whether the dataset is an instance of ImageDataset
         if not isinstance(ds, ImageDataset):
             raise TypeError('Expected "ds" to be an instance of {}.'
                             .format('.'.join([ImageDataset.__module__,
                                               ImageDataset.__name__])))

-        # check whether the current dataset uses the correct spectral bands
-        if ds.use_bands != bands:
-            raise ValueError('The model was trained with bands {}, not with '
-                             'bands {}.'.format(bands, ds.use_bands))
+        # check whether the current dataset uses the same number of input
+        # channels as the pretrained model
+        if len(ds.use_bands) != model.in_channels:
+            raise ValueError('The model was trained with {} input channels, '
+                             'which does not match the {} input channels of '
+                             'the specified dataset.'.format(
+                                 model.in_channels, len(ds.use_bands)))

         # configure model for the specified dataset
         LOGGER.info('Configuring model for new dataset: {}.'.format(
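Illustration only (not part of the patch): the relaxed check means a pretrained model can be reused on any dataset with the same *number* of channels, regardless of band names. A sketch with hypothetical values:

```python
in_channels = 6  # hypothetical model.in_channels of the pretrained model
use_bands = ['red', 'green', 'blue', 'nir', 'swir1', 'swir2']

if len(use_bands) != in_channels:
    raise ValueError('The model was trained with {} input channels, '
                     'which does not match the {} input channels of '
                     'the specified dataset.'.format(in_channels,
                                                     len(use_bands)))
```

The trade-off is that a band-order mismatch (e.g. red and nir swapped between the two datasets) is no longer caught by this check.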
""" - # state file name for model trained on the source domain only - state_src = self._format_state_file( - self.state_file, self.src_dc, self.src_sc, self.mc) + # source domain dataset state filename + src_ds_state = self._format_ds_state( + self.ds_state_file, self.src_dc, self.src_sc) + + # model state file name and extension: common to both source and target + # domain + ml_state, ml_ext = self._format_ml_state(self.src_dc, self.mc) + + # state file for models trained only on the source domain + state = '_'.join([ml_state, src_ds_state, ml_ext]) # check whether the model is trained on the source domain only - if not self.mc.transfer: - # state file for models trained only on the source domain - state = state_src - else: - # state file for model trained on target domain - state_trg = self._format_state_file( - self.state_file, self.trg_dc, self.trg_sc, self.mc) + if self.mc.transfer: + + # target domain dataset state filename + trg_ds_state = self._format_ds_state( + self.ds_state_file, self.trg_dc, self.trg_sc) # check whether a pretrained model is used to fine-tune to the # target domain if self.mc.supervised: # state file for models fine-tuned to target domain - state = state_trg.replace('.pt', '_pretrained_{}'.format( - self.mc.pretrained_model)) + # DatasetConfig_PretrainedModel.pt + state = '_'.join([self.mc.pretrained_model, + 'sda_{}'.format(trg_ds_state)]) else: # state file for models trained via unsupervised domain # adaptation - state = state_src.replace('.pt', '_uda_{}{}'.format( - self.mc.uda_pos, state_trg.replace(self.mc.model_name, '')) - ) + state = '_'.join([state.replace( + ml_ext, 'uda_{}'.format(self.mc.uda_pos)), + trg_ds_state, ml_ext]) # check whether unsupervised domain adaptation is initialized # from a pretrained model state if self.mc.uda_from_pretrained: - state = state.replace('.pt', '_pretrained_{}'.format( - self.mc.pretrained_model)) + state = '_'.join(state.replace('.pt', ''), + 'prt_{}'.format( + self.mc.pretrained_model)) # path to model state state = self.mc.state_path.joinpath(state) return state - def _format_state_file(self, state_file, dc, sc, mc): - """Format base model state filename. + def _format_ds_state(self, state_file, dc, sc): + """Format base dataset state filename. Parameters ---------- state_file : `str` - The base model state filename. + The base dataset state filename. dc : :py:class:`pysegcnn.core.trainer.DatasetConfig` - The domain dataset configuration. + The dataset configuration. sc : :py:class:`pysegcnn.core.trainer.SplitConfig` - The domain dataset split configuration. - mc : :py:class:`pysegcnn.core.trainer.ModelConfig` - The model configuration. + The dataset split configuration. Returns ------- file : `str` - The formatted model state filename. + The formatted dataset state filename. """ - # get the band numbers - if dc.bands: - bformat = ''.join(band[0] + - str(dc.dataset_class.get_sensor(). - __members__[band].value) - for band in dc.bands) - else: - bformat = 'all' # check which split mode was used if sc.split_mode == 'date': @@ -971,879 +944,870 @@ class StateConfig(BaseConfig): str(sc.tvratio).replace('.', '')) # model state filename - file = state_file.format(mc.model_name, - dc.dataset_class.__name__ + + file = state_file.format(dc.dataset_class.__name__ + '_m{}'.format(len(dc.merge_labels)), sc.split_mode.capitalize(), split_params, - dc.tile_size, - mc.batch_size, - bformat) + ) return file + def _format_ml_state(self, dc, mc): + """Format base model state filename. 
+    def _format_ml_state(self, dc, mc):
+        """Format base model state filename.
+
+        Parameters
+        ----------
+        dc : :py:class:`pysegcnn.core.trainer.DatasetConfig`
+            The source domain dataset configuration.
+        mc : :py:class:`pysegcnn.core.trainer.ModelConfig`
+            The model configuration.
+
+        Returns
+        -------
+        `tuple` [`str`, `str`]
+            The formatted model state filename and its extension.
+
+        """
+        # get the band numbers
+        if dc.bands:
+            bands = dc.dataset_class.get_sensor().band_dict()
+            bformat = ''.join([(v[0] + str(k)) for k, v in bands.items()])
+        else:
+            bformat = 'all'
+
+        return (self.ml_state_file.format(mc.model_name, mc.optim_name),
+                self.ml_state_ext.format(dc.tile_size, mc.batch_size,
+                                         bformat))

 @dataclasses.dataclass
-class NetworkInference(BaseConfig):
-    """Model inference configuration.
+class LogConfig(BaseConfig):
+    """Logging configuration class.

-    Evaluate a model.
+    Generate the model log file.

     Attributes
     ----------
     state_file : :py:class:`pathlib.Path`
-        Path to the model to evaluate.
-    implicit : `bool`
-        Whether to evaluate the model on the datasets defined at training time.
-    domain : `str`
-        Whether to evaluate on the source domain (``domain='src'``), i.e. the
-        domain the model in ``state_file`` was trained on, or a target domain
-        (``domain='trg'``).
-    test : `bool` or `None`
-        Whether to evaluate the model on the training(``test=None``), the
-        validation (``test=False``) or the test set (``test=True``).
-    map_labels : `bool`
-        Whether to map the model labels from the model source domain to the
-        defined ``domain`` in case the domain class labels differ.
-    ds : `dict`
-        The dataset configuration dictionary passed to
-        :py:class:`pysegcnn.core.trainer.DatasetConfig` when evaluating on
-        an explicitly defined dataset, i.e. ``implicit=False``.
-    ds_split : `dict`
-        The dataset split configuration dictionary passed to
-        :py:class:`pysegcnn.core.trainer.SplitConfig` when evaluating on
-        an explicitly defined dataset, i.e. ``implicit=False``.
-    predict_scene : `bool`
-        The model prediction order. If False, the samples (tiles) of a dataset
-        are predicted in any order and the scenes are not reconstructed.
-        If True, the samples (tiles) are ordered according to the scene they
-        belong to and a model prediction for each entire reconstructed scene is
-        returned. The default is `False`.
-    plot_samples : `bool`
-        Whether to save a plot of false color composite, ground truth and model
-        prediction for each sample (tile). Only used if ``predict_scene=False``
-        . The default is `False`.
-    plot_scenes : `bool`
-        Whether to save a plot of false color composite, ground truth and model
-        prediction for each entire scene. Only used if ``predict_scene=True``.
-        The default is `False`.
-    plot_bands : `list` [`str`]
-        The bands to build the false color composite. The default is
-        `['nir', 'red', 'green']`.
-    cm : `bool`
-        Whether to compute and plot the confusion matrix. The default is `True`
-        .
-    figsize : `tuple`
-        The figure size in centimeters. The default is `(10, 10)`.
-    alpha : `int`
-        The level of the percentiles for contrast stretching of the false color
-        compsite. The default is `0`, i.e. no stretching.
-    animate : `bool`
-        Whether to create an animation of (input, ground truth, prediction) for
-        the scenes in the train/validation/test dataset. Only works if
-        ``predict_scene=True`` and ``plot_scene=True``.
-    device : `str`
-        The device to evaluate the model on, i.e. `cpu` or `cuda`.
-    base_path : :py:class:`pathlib.Path`
-        Root path to store model output.
-    sample_path : :py:class:`pathlib.Path`
-        Path to store plots of model predictions for single samples.
-    scenes_path : :py:class:`pathlib.Path`
-        Path to store plots of model predictions for entire scenes.
-    perfmc_path : :py:class:`pathlib.Path`
-        Path to store plots of model performance, e.g. confusion matrix.
-    animtn_path : :py:class:`pathlib.Path`
-        Path to store animations.
-    models_path : :py:class:`pathlib.Path`
-        Path to search for model state files, i.e. pretrained models.
-    kwargs : `dict`
-        Keyword arguments for :py:func:`pysegcnn.core.graphics.plot_sample`
-    basename : `str`
-        Base filename for each plot.
-    model : :py:class:`pysegcnn.core.models.Network`
-        The model to use for inference.
-    model_state : `dict`
-        The model state as saved by
-        :py:class:`pysegcnn.core.trainer.NetworkTrainer`.
-    trg_ds : :py:class:`pysegcnn.core.split.CustomSubset`
-        The dataset to evaluate ``model`` on.
-    src_ds : :py:class:`pysegcnn.core.split.CustomSubset`
-        The model source domain training dataset.
-    fig : :py:class:`matplotlib.figure.Figure`
-        A :py:class:`matplotlib.figure.Figure` instance to iteratively plot to.
-    anim : :py:class:`pysegcnn.core.graphics.Animate`
-        An instance :py:class:`pysegcnn.core.graphics.Animate` Used to create
-        animations if ``animate=True``.
-    conf_mat : :py:class:`numpy.ndarray`
-        The model confusion matrix.
-
+        Path to a model state file.
+    log_path : :py:class:`pathlib.Path`
+        Path to store model logs.
+    log_file : :py:class:`pathlib.Path`
+        Path to the log file of the model ``state_file``.
     """

     state_file: pathlib.Path
-    implicit: bool
-    domain: str
-    test: object
-    map_labels: bool
-    ds: dict = dataclasses.field(default_factory={})
-    ds_split: dict = dataclasses.field(default_factory={})
-    predict_scene: bool = False
-    plot_samples: bool = False
-    plot_scenes: bool = False
-    plot_bands: list = dataclasses.field(
-        default_factory=lambda: ['nir', 'red', 'green'])
-    cm: bool = True
-    figsize: tuple = (10, 10)
-    alpha: int = 5
-    animate: bool = False

     def __post_init__(self):
         """Check the type of each argument.

-        Configure figure output paths.
-
-        Raises
-        ------
-        TypeError
-            Raised if ``test`` is not of type `bool` or `None`.
-        ValueError
-            Raised if ``domain`` is not 'src' or 'trg'.
+        Generate model log file.

         """
         super().__post_init__()

-        # check whether the test input parameter is correctly specified
-        if self.test not in [None, False, True]:
-            raise TypeError('Expected "test" to be None, True or False, got '
-                            '{}.'.format(self.test))
-
-        # check whether the domain is correctly specified
-        if self.domain not in ['src', 'trg']:
-            raise ValueError('Expected "domain" to be "src" or "trg", got {}.'
-                             .format(self.domain))
-
-        # the device to compute on, use gpu if available
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else
-                                   "cpu")
+        # the path to store model logs
+        self.log_path = pathlib.Path(HERE).joinpath('_logs')

-        # the output paths for the different graphics
-        self.base_path = pathlib.Path(HERE)
-        self.sample_path = self.base_path.joinpath('_samples')
-        self.scenes_path = self.base_path.joinpath('_scenes')
-        self.perfmc_path = self.base_path.joinpath('_graphics')
-        self.animtn_path = self.base_path.joinpath('_animations')
+        # the log file of the current model
+        self.log_file = check_filename_length(self.log_path.joinpath(
+            self.state_file.name.replace('.pt', '.log')))

-        # input path for model state files
-        self.models_path = self.base_path.joinpath('_models')
-        self.state_file = self.models_path.joinpath(self.state_file)
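Illustration only (not part of the patch): the relocated `LogConfig` derives the log file directly from the state file name, for example:

```python
import pathlib

state_file = pathlib.Path('_models/Unet_Adam_SparcsDataset.pt')  # hypothetical
log_file = pathlib.Path('_logs').joinpath(
    state_file.name.replace('.pt', '.log'))
print(log_file)  # _logs/Unet_Adam_SparcsDataset.log
```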
+    @staticmethod
+    def now():
+        """Return the current date and time.

-        # initialize logging
-        log = LogConfig(self.state_file)
-        dictConfig(log_conf(log.log_file))
-        log.init_log('{}: ' + 'Evaluating model: {}.'
-                     .format(self.state_file.name))
+        Returns
+        -------
+        date : :py:class:`datetime.datetime`
+            The current date and time.

-        # plotting keyword arguments
-        self.kwargs = {'bands': self.plot_bands,
-                       'alpha': self.alpha,
-                       'figsize': self.figsize}
+        """
+        return datetime.datetime.strftime(datetime.datetime.now(),
+                                          '%Y-%m-%dT%H:%M:%S')

-        # base filename for each plot
-        self.basename = self.state_file.stem
+    @staticmethod
+    def init_log(init_str):
+        """Generate a string to identify a new model run.

-        # load the model state
-        self.model, _, self.model_state = Network.load(self.state_file)
+        Parameters
+        ----------
+        init_str : `str`
+            The string to write to the model log file.

-        # load the target dataset: dataset to evaluate the model on
-        self.trg_ds = self.load_dataset()
+        """
+        LOGGER.info(80 * '-')
+        LOGGER.info(init_str.format(LogConfig.now()))
+        LOGGER.info(80 * '-')

-        # load the source dataset: dataset the model was trained on
-        self.src_ds = self.model_state['src_train_dl'].dataset.dataset

-        # create a figure to use for plotting
-        self.fig, _ = plt.subplots(1, 3, figsize=self.kwargs['figsize'])
+@dataclasses.dataclass
+class ClassificationNetworkTrainer(BaseConfig):
+    """Base model training class for classification problems.

-        # check if the animate parameter is correctly specified
-        if self.animate:
-            if not self.plot:
-                LOGGER.warning('animate requires plot_scenes=True or '
-                               'plot_samples=True. Can not create animation.')
-            else:
-                # check whether the output path is valid
-                if not self.anim_path.exists():
-                    # create output path
-                    self.anim_path.mkdir(parents=True, exist_ok=True)
-                    self.anim = Animate(self.anim_path)
+    Train an instance of :py:class:`pysegcnn.core.models.Network` on a
+    classification problem. The `categorical cross-entropy loss`_
+    is used as the loss function in combination with the `softmax`_ output
+    layer activation function.

-    @staticmethod
-    def get_scene_tiles(ds, scene_id):
-        """Return the tiles of the scene with id ``scene_id``.
+    In case of a binary classification problem, the categorical cross-entropy
+    loss reduces to the binary cross-entropy loss and the softmax function to
+    the standard `logistic function`_.

-        Parameters
-        ----------
-        ds : :py:class:`pysegcnn.core.split.CustomSubset`
-            A instance of a subclass of
-            :py:class:`pysegcnn.core.split.CustomSubset`.
-        scene_id : `str`
-            A valid scene identifier.
-
-        Raises
-        ------
-        ValueError
-            Raised if ``scene_id`` is not a valid scene identifier for the
-            dataset ``ds``.
-
-        Returns
-        -------
-        indices : `list` [`int`]
-            List of indices of the tiles of the scene with id ``scene_id`` in
-            ``ds``.
-        date : :py:class:`datetime.datetime`
-            The date of the scene with id ``scene_id``.
+    Attributes
+    ----------
+    model : :py:class:`pysegcnn.core.models.Network`
+        The model to train. An instance of
+        :py:class:`pysegcnn.core.models.Network`.
+    optimizer : :py:class:`torch.optim.Optimizer`
+        The optimizer to update the model weights. An instance of
+        :py:class:`torch.optim.Optimizer`.
+    state_file : :py:class:`pathlib.Path`
+        Path to save the model state.
+    src_train_dl : :py:class:`torch.utils.data.DataLoader`
+        The source domain training :py:class:`torch.utils.data.DataLoader`
+        instance built from an instance of
+        :py:class:`pysegcnn.core.split.CustomSubset`.
+    src_valid_dl : :py:class:`torch.utils.data.DataLoader`
+        The source domain validation :py:class:`torch.utils.data.DataLoader`
+        instance built from an instance of
+        :py:class:`pysegcnn.core.split.CustomSubset`.
+    src_test_dl : :py:class:`torch.utils.data.DataLoader`
+        The source domain test :py:class:`torch.utils.data.DataLoader`
+        instance built from an instance of
+        :py:class:`pysegcnn.core.split.CustomSubset`.
+    epochs : `int`
+        The maximum number of epochs to train. The default is `1`.
+    nthreads : `int`
+        The number of cpu threads to use during training. The default is
+        :py:func:`torch.get_num_threads()`.
+    early_stop : `bool`
+        Whether to apply `Early Stopping`_. The default is `False`.
+    mode : `str`
+        The early stopping mode. Depends on the metric measuring
+        performance. When using model loss as metric, use ``mode='min'``,
+        however, when using accuracy as metric, use ``mode='max'``. For now,
+        only ``mode='max'`` is supported. Only used if ``early_stop=True``.
+        The default is `'max'`.
+    delta : `float`
+        Minimum change in early stopping metric to be considered as an
+        improvement. Only used if ``early_stop=True``. The default is `0`.
+    patience : `int`
+        The number of epochs to wait for an improvement in the early stopping
+        metric. If the model does not improve over more than ``patience``
+        epochs, quit training. Only used if ``early_stop=True``. The default is
+        `10`.
+    checkpoint_state : `dict` [`str`, :py:class:`numpy.ndarray`]
+        A model checkpoint for ``model``. If specified, ``checkpoint_state``
+        should be a dictionary with keys describing the training metric.
+        The default is `{}`.
+    save : `bool`
+        Whether to save the model state to ``state_file``. The default is
+        `True`.
+    device : `str`
+        The device to train the model on, i.e. `cpu` or `cuda`.
+    cla_loss_function : :py:class:`torch.nn.Module`
+        The classification loss function to compute the model error. An
+        instance of :py:class:`torch.nn.CrossEntropyLoss`.
+    tracker : :py:class:`pysegcnn.core.trainer.MetricTracker`
+        A :py:class:`pysegcnn.core.trainer.MetricTracker` instance tracking
+        training metrics, i.e. loss and accuracy.
+    max_accuracy : `float`
+        Maximum accuracy of ``model`` on the validation dataset.
+    es : `None` or :py:class:`pysegcnn.core.trainer.EarlyStopping`
+        The early stopping instance if ``early_stop=True``, else `None`.
+    tmbatch : `int`
+        Number of mini-batches in the training dataset.
+    vmbatch : `int`
+        Number of mini-batches in the validation dataset.
+    training_state : `dict` [`str`, :py:class:`numpy.ndarray`]
+        The training state dictionary. The keys describe the type of the
+        training metric.
+    params_to_save : `dict`
+        The parameters to save in the model ``state_file``, in addition to the
+        model and optimizer weights.

-        """
-        # check if the scene id is valid
-        scene_meta = ds.dataset.parse_scene_id(scene_id)
-        if scene_meta is None:
-            raise ValueError('{} is not a valid scene identifier'
-                             .format(scene_id))
+    .. _Early Stopping:
+        https://en.wikipedia.org/wiki/Early_stopping

-        # iterate over the scenes of the dataset
-        indices = []
-        for i, scene in enumerate(ds.scenes):
-            # if the scene id matches a given id, save the index of the scene
-            if scene['id'] == scene_id:
-                indices.append(i)
+    .. _categorical cross-entropy loss:
+        https://gombru.github.io/2018/05/23/cross_entropy_loss/

-        return indices, scene_meta['date']
+    .. _softmax:
+        https://peterroelants.github.io/posts/cross-entropy-softmax/

-    @staticmethod
-    def replace_dataset_path(ds, drive_path):
-        """Replace the path to the datasets.
+    .. _logistic function:
+        https://en.wikipedia.org/wiki/Logistic_function

-        Useful to evaluate models on machines, that are different from the
-        machine the model was trained on.
+    """

-        .. important::
+    model: Network
+    optimizer: Optimizer
+    state_file: pathlib.Path
+    src_train_dl: DataLoader
+    src_valid_dl: DataLoader
+    src_test_dl: DataLoader
+    epochs: int = 1
+    nthreads: int = torch.get_num_threads()
+    early_stop: bool = False
+    mode: str = 'max'
+    delta: float = 0
+    patience: int = 10
+    checkpoint_state: dict = dataclasses.field(default_factory=dict)
+    save: bool = True

-            This function assumes that the datasets are stored in a directory
-            named "Datasets" on each machine.
+    def __post_init__(self):
+        """Check the type of each argument.

-        See ``DRIVE_PATH`` in :py:mod:`pysegcnn.main.config`.
+        Configure the device to train the model on, i.e. train on the gpu if
+        available.

-        Parameters
-        ----------
-        ds : :py:class:`pysegcnn.core.split.CustomSubset`
-            A subset of an instance of
-            :py:class:`pysegcnn.core.dataset.ImageDataset`.
-        drive_path : `str`
-            Base path to the datasets on the current machine. ``drive_path``
-            should end with `'Datasets'`.
+        Configure early stopping if required.

-        Raises
-        ------
-        TypeError
-            Raised if ``ds`` is not an instance of
-            :py:class:`pysegcnn.core.split.CustomSubset` and if ``ds`` is not
-            a subset of an instance of
-            :py:class:`pysegcnn.core.dataset.ImageDataset`.
+        Initialize training metric tracking.

         """
-        # check input type
-        if isinstance(ds, CustomSubset):
-            # check the type of the dataset
-            if not isinstance(ds.dataset, ImageDataset):
-                raise TypeError('ds should be a subset created from a {}.'
-                                .format(repr(ImageDataset)))
-        else:
-            raise TypeError('ds should be an instance of {}.'
-                            .format(repr(CustomSubset)))
+        super().__post_init__()

-        # iterate over the scenes of the dataset
-        for scene in ds.dataset.scenes:
-            for k, v in scene.items():
-                # do only look for paths
-                if isinstance(v, str) and k != 'id':
+        # the device to train the model on
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else
+                                   'cpu')
+        # set the number of threads
+        torch.set_num_threads(self.nthreads)

-                    # drive path: match path before "Datasets"
-                    # dpath = re.search('^(.*)(?=(/.*Datasets))', v)
+        # send the model to the gpu if available
+        self.model = self.model.to(self.device)

-                    # drive path: match path up to "Datasets"
-                    dpath = re.search('^(.*?Datasets)', v)[0]
+        # instantiate multiclass classification loss function: multi-class
+        # cross-entropy loss function
+        self.cla_loss_function = nn.CrossEntropyLoss()
+        LOGGER.info('Classification loss function: {}.'
+                    .format(repr(self.cla_loss_function)))

-                    # replace drive path
-                    if dpath != drive_path:
-                        scene[k] = v.replace(str(dpath), drive_path)
+        # instantiate metric tracker
+        self.tracker = MetricTracker(
+            train_metrics=['train_loss', 'train_accu'],
+            valid_metrics=['valid_loss', 'valid_accu'])

-    def load_dataset(self):
-        """Load the defined dataset.
+        # initialize metric tracker
+        self.tracker.initialize()

-        Raises
-        ------
-        ValueError
-            Raised if the requested dataset was not available at training time,
-            if ``implicit=True``.
+        # maximum accuracy on the validation set
+        self.max_accuracy = 0
+        if self.checkpoint_state:
+            self.max_accuracy = self.checkpoint_state['valid_accu'].mean(
+                axis=0).max().item()

-            Raised if the dataset ``ds`` does not have the same spectral bands
-            as the model to evaluate, if ``implicit=False``.
+        # whether to use early stopping
+        self.es = None
+        if self.early_stop:
+            self.es = EarlyStopping(self.mode, self.max_accuracy, self.delta,
+                                    self.patience)

-        Returns
-        -------
-        ds : :py:class:`pysegcnn.core.split.CustomSubset`
-            The dataset to evaluate the model on.
+        # number of mini-batches in the training and validation sets
+        self.tmbatch = len(self.src_train_dl)
+        self.vmbatch = len(self.src_valid_dl)

-        """
-        # check whether to evaluate on the datasets defined at training time
-        if self.implicit:
+        # log representation
+        LOGGER.info(repr(self))

-            # check whether to evaluate the model on the training, validation
-            # or test set
-            if self.test is None:
-                ds_set = 'train'
-            else:
-                ds_set = 'test' if self.test else 'valid'
+        # initialize training log
+        LOGGER.info(35 * '-' + ' Training ' + 35 * '-')

-            # the dataset to evaluate the model on
-            ds = self.model_state[
-                self.domain + '_{}_dl'.format(ds_set)].dataset
-            if ds is None:
-                raise ValueError('Requested dataset "{}" is not available.'
-                                 .format(self.domain + '_{}_dl'.format(ds_set))
-                                 )
+        # log the device and number of threads
+        LOGGER.info('Device: {}'.format(self.device))
+        LOGGER.info('Number of cpu threads: {}'.format(self.nthreads))

-            # log dataset representation
-            LOGGER.info('Evaluating on {} set of the {} domain defined at '
-                        'training time.'.format(ds_set, self.domain))
+    def train_source_domain(self, epoch):
+        """Train a model for a single epoch on the source domain.

-        else:
-            # explicitly defined dataset
-            ds = DatasetConfig(**self.ds).init_dataset()
+        Parameters
+        ----------
+        epoch : `int`
+            The current epoch.

-            # check if the spectral bands match
-            if ds.use_bands != self.model_state['bands']:
-                raise ValueError('The model was trained with bands {}, not '
-                                 'with bands {}.'.format(
-                                     self.model_state['bands'], ds.use_bands))
+        """
+        # iterate over the dataloader object
+        for batch, (inputs, labels) in enumerate(self.src_train_dl):

-            # split configuration
-            sc = SplitConfig(**self.ds_split)
-            train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
+            # send the data to the gpu if available
+            inputs = inputs.to(self.device)
+            labels = labels.to(self.device)

-            # check whether to evaluate the model on the training, validation
-            # or test set
-            if self.test is None:
-                ds = train_ds
-            else:
-                ds = test_ds if self.test else valid_ds
+            # reset the gradients
+            self.optimizer.zero_grad()

-            # log dataset representation
-            LOGGER.info('Evaluating on {} set of explicitly defined dataset: '
-                        '\n {}'.format(ds.name, repr(ds.dataset)))
+            # perform forward pass
+            outputs = self.model(inputs)

-        # check the dataset path: replace by path on current machine
-        self.replace_dataset_path(ds, DRIVE_PATH)
+            # compute loss
+            loss = self.cla_loss_function(outputs, labels.long())

-        return ds
+            # compute the gradients of the loss function w.r.t.
+            # the network weights
+            loss.backward()

-    @property
-    def source_labels(self):
-        """Class labels of the source domain the model was trained on.
+            # update the weights
+            self.optimizer.step()

-        Returns
-        -------
-        source_labels : `dict` [`int`, `dict`]
-            The class labels of the source domain.
+            # calculate predicted class labels
+            ypred = F.softmax(outputs, dim=1).argmax(dim=1)
+
+            # calculate accuracy on current batch
+            acc = accuracy_function(ypred, labels)
+
+            # print progress
+            LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
+                        'Loss: {:.2f}, Accuracy: {:.2f}'
+                        .format(epoch + 1, self.epochs, batch + 1,
+                                self.tmbatch, loss.item(), acc))
+
+            # update training metrics
+            self.tracker.batch_update(self.tracker.train_metrics,
+                                      [loss.item(), acc])
+
+    def train_epoch(self, epoch):
+        """Train a model for a single epoch on the source domain.
+
+        Parameters
+        ----------
+        epoch : `int`
+            The current epoch.

         """
-        return self.src_ds.labels
+        self.train_source_domain(epoch)

-    @property
-    def target_labels(self):
-        """Class labels of the dataset to evaluate.
+    def train(self):
+        """Train the model.

         Returns
         -------
-        target_labels : `dict` [`int`, `dict`]
-            The class labels of the dataset to evaluate.
+        training_state : `dict` [`str`, :py:class:`numpy.ndarray`]
+            The training state dictionary. The keys describe the type of the
+            training metric. See
+            :py:meth:`~pysegcnn.core.trainer.ClassificationNetworkTrainer.training_state`.

         """
-        return self.trg_ds.dataset.labels
+        # initialize the training: iterate over the entire training dataset
+        for epoch in range(self.epochs):

-    # @property
-    # def label_map(self):
-    #     """Label mapping from the source to the target domain.
+            # set the model to training mode
+            LOGGER.info('Setting model to training mode ...')
+            self.model.train()

-    #     See :py:func:`pysegcnn.core.constants.map_labels`.
+            # train model for a single epoch
+            self.train_epoch(epoch)

-    #     Returns
-    #     -------
-    #     label_map : `dict` [`int`, `int`]
-    #         Dictionary with source labels as keys and corresponding target
-    #         labels as values.
+            # update the number of epochs trained
+            self.model.epoch += 1

-    #     """
-    #     # check whether the source domain labels are the same as the target
-    #     # domain labels
-    #     return map_labels(self.src_ds.get_labels(),
-    #                       self.trg_ds.dataset.get_labels())
+            # whether to evaluate model performance on the validation set and
+            # early stop the training process
+            if self.early_stop:

-    @property
-    def source_is_target(self):
-        """Whether the source and target domain labels are the same.
+                # model predictions on the validation set
+                valid_accu, valid_loss = self.predict(self.src_valid_dl)

-        Returns
-        -------
-        source_is_target : `bool`
-            `True` if the source and target domain labels are the same, `False`
-            if not.
+                # update validation metrics
+                self.tracker.batch_update(self.tracker.valid_metrics,
+                                          [valid_loss, valid_accu])

-        """
-        return self.label_map is None
+                # metric to assess model performance on the validation set
+                epoch_acc = np.mean(valid_accu)

-    @property
-    def apply_label_map(self):
-        """Whether to map source labels to target labels.
+                # whether the model improved with respect to the previous epoch
+                if self.es.increased(epoch_acc, self.max_accuracy, self.delta):
+                    self.max_accuracy = epoch_acc

-        Returns
-        -------
-        apply_label_map : `bool`
-            `True` if source and target labels differ and label mapping is
-            requested, `False` otherwise.
+                    # save model state if the model improved with
+                    # respect to the previous epoch
+                    if self.save:
+                        self.save_state()

-        """
-        return not self.source_is_target and self.map_labels
+                # whether the early stopping criterion is met
+                if self.es.stop(epoch_acc):
+                    break

-    @property
-    def use_labels(self):
-        """Labels to be predicted.
+            else:
+                # if no early stopping is required, the model state is
+                # saved after each epoch
+                if self.save:
+                    self.save_state()

-        Returns
-        -------
-        use_labels : `dict` [`int`, `dict`]
-            The labels of the classes to be predicted.
+        return self.training_state
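Illustration only (not part of the patch): the control flow of `train()` with early stopping, sketched without the model-specific parts; `increased()` and `stop()` mimic the assumed behaviour of `pysegcnn.core.trainer.EarlyStopping` with `mode='max'`.

```python
max_accuracy, delta, patience = 0.0, 0.0, 10
history = [0.61, 0.64, 0.63, 0.65, 0.64]  # mean validation accuracy per epoch

no_improvement = 0
for epoch_acc in history:
    if epoch_acc > max_accuracy + delta:  # EarlyStopping.increased()
        max_accuracy = epoch_acc
        no_improvement = 0  # an improved model would be saved here
    else:
        no_improvement += 1
    if no_improvement > patience:  # EarlyStopping.stop()
        break
```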
-        """
-        return (self.target_labels if self.apply_label_map else
-                self.source_labels)
+    def predict(self, dataloader):
+        """Model inference at training time.

-    @property
-    def bands(self):
-        """Spectral bands the model was trained with.
+        Parameters
+        ----------
+        dataloader : :py:class:`torch.utils.data.DataLoader`
+            The validation dataloader to evaluate the model predictions.

         Returns
         -------
-        bands : `list` [`str`]
-            A list of the named spectral bands used to train the model.
+        accuracy : :py:class:`numpy.ndarray`
+            The mean model prediction accuracy on each mini-batch in the
+            validation set.
+        loss : :py:class:`numpy.ndarray`
+            The model loss for each mini-batch in the validation set.

         """
-        return self.src_ds.use_bands
+        # set the model to evaluation mode
+        LOGGER.info('Setting model to evaluation mode ...')
+        self.model.eval()
+
+        # create arrays of the observed loss and accuracy
+        accuracy = []
+        loss = []
+
+        # iterate over the validation/test set
+        LOGGER.info('Calculating accuracy on the validation set ...')
+        for batch, (inputs, labels) in enumerate(dataloader):
+
+            # send the data to the gpu if available
+            inputs = inputs.to(self.device)
+            labels = labels.to(self.device)
+
+            # calculate network outputs
+            with torch.no_grad():
+                outputs = self.model(inputs)
+
+            # compute loss
+            cla_loss = self.cla_loss_function(outputs, labels.long())
+            loss.append(cla_loss.item())
+
+            # calculate predicted class labels
+            pred = F.softmax(outputs, dim=1).argmax(dim=1)
+
+            # calculate accuracy on current batch
+            acc = accuracy_function(pred, labels)
+            accuracy.append(acc)
+
+            # print progress
+            LOGGER.info('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}'
+                        .format(batch + 1, len(dataloader), acc))
+
+        # calculate overall accuracy on the validation/test set
+        LOGGER.info('Epoch: {:d}, Mean accuracy: {:.2f}%.'
+                    .format(self.model.epoch, np.mean(accuracy) * 100))
+
+        return accuracy, loss
+
+    def save_state(self):
+        """Save the model state."""
+        _ = self.model.save(self.state_file,
+                            self.optimizer,
+                            state=self.training_state,
+                            **self.params_to_save)

     @property
-    def compute_cm(self):
-        """Whether to compute the confusion matrix.
+    def training_state(self):
+        """Model training metrics.

         Returns
         -------
-        compute_cm : `bool`
-            Whether the confusion matrix can be computed. For datasets with
-            labels different from the source domain labels, the confusion
-            matrix can not be computed.
+        state : `dict` [`str`, :py:class:`numpy.ndarray`]
+            The training state dictionary. The keys describe the type of the
+            training metric and the values are :py:class:`numpy.ndarray`'s of
+            the corresponding metric observed during training with
+            shape=(mini_batch, epoch).

         """
-        return (False if not self.source_is_target and not self.map_labels
-                else self.cm)
+        # current training state
+        state = self.tracker.np_state(self.tmbatch, self.vmbatch)
+
+        # optional: training state of the model checkpoint
+        if self.checkpoint_state:
+            # prepend values from checkpoint to current training state
+            state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in
+                     zip(self.checkpoint_state.items(), state.items())
+                     if k1 == k2}
+
+        return state

     @property
-    def plot(self):
-        """Whether to save plots of (input, ground truth, prediction).
+    def params_to_save(self):
+        """The parameters and variables to save in the model state file."""
+        return {'src_train_dl': self.src_train_dl,
+                'src_valid_dl': self.src_valid_dl,
+                'src_test_dl': self.src_test_dl}
+
+    def _build_model_repr_(self):
+        """Build the model representation.

         Returns
         -------
-        plot : `bool`
-            Save plots for each sample or for each scene of the target dataset,
-            depending on ``self.predict_scene``.
+        fs : `str`
+            Representation string.

         """
-        return self.plot_scenes if self.predict_scene else self.plot_samples
+        # model
+        fs = '\n    (model):' + '\n' + 8 * ' '
+        fs += ''.join(repr(self.model)).replace('\n', '\n' + 8 * ' ')

-    @property
-    def is_scene_subset(self):
-        """Check the type of the target dataset.
+        # optimizer
+        fs += '\n    (optimizer):' + '\n' + 8 * ' '
+        fs += ''.join(repr(self.optimizer)).replace('\n', '\n' + 8 * ' ')

-        Whether ``self.trg_ds`` is an instance of
-        :py:class:`pysegcnn.core.split.SceneSubset`, as required when
-        ``self.predict_scene=True``.
+        # loss function
+        fs += '\n    (loss function):' + '\n' + 8 * ' '
+        fs += ''.join(repr(self.cla_loss_function)).replace('\n',
+                                                            '\n' + 8 * ' ')

-        Returns
-        -------
-        is_scene_subset : `bool`
-            Whether ``self.trg_ds`` is an instance of
-            :py:class:`pysegcnn.core.split.SceneSubset`.
+        # early stopping
+        fs += '\n    (early stop):' + '\n' + 8 * ' '
+        fs += ''.join(repr(self.es)).replace('\n', '\n' + 8 * ' ')
+
+        return fs
+
+    def __repr__(self):
+        """Representation.

+        Returns
+        -------
+        fs : `str`
+            Representation string.
+
         """
-        return isinstance(self.trg_ds, SceneSubset)
+        # representation string to print
+        fs = self.__class__.__name__ + '(\n'

-    @property
-    def dataloader(self):
-        """Dataloader instance for model inference.
+        # model configuration
+        fs += self._build_model_repr_()

-        Returns
-        -------
-        dataloader : :py:class:`torch.utils.data.DataLoader`
-            The dataset for model inference.
+        fs += '\n)'
+        return fs

-        """
-        # build the dataloader for model inference
-        return DataLoader(self.trg_ds, batch_size=self._batch_size,
-                          shuffle=False, drop_last=False)

-    @property
-    def _original_source_labels(self):
-        """Original source domain labels.
+@dataclasses.dataclass
+class MultispectralImageSegmentationTrainer(ClassificationNetworkTrainer):
+    """Model training class for multispectral image segmentation.

-        Since PyTorch requires class labels to be an ascending sequence
-        starting from 0, the actual class labels in the ground truth may differ
-        from the class labels fed to the model.
+    Train an instance of :py:class:`pysegcnn.core.models.EncoderDecoderNetwork`
+    on an instance of :py:class:`pysegcnn.core.dataset.ImageDataset`.

-        Returns
-        -------
-        original_source_labels : `dict` [`int`, `dict`]
-            The original class labels of the source domain.
+    Attributes
+    ----------
+    trg_train_dl : :py:class:`torch.utils.data.DataLoader`
+        The target domain training :py:class:`torch.utils.data.DataLoader`
+        instance built from an instance of
+        :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty
+        :py:class:`torch.utils.data.DataLoader`.
+    trg_valid_dl : :py:class:`torch.utils.data.DataLoader`
+        The target domain validation :py:class:`torch.utils.data.DataLoader`
+        instance built from an instance of
+        :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty
+        :py:class:`torch.utils.data.DataLoader`.
+    trg_test_dl : :py:class:`torch.utils.data.DataLoader`
+        The target domain test :py:class:`torch.utils.data.DataLoader`
+        instance built from an instance of
+        :py:class:`pysegcnn.core.split.CustomSubset`.
+        The default is an empty
+        :py:class:`torch.utils.data.DataLoader`.
+    uda_loss_function : :py:class:`torch.nn.Module`
+        The domain adaptation loss function. An instance of
+        :py:class:`torch.nn.Module`.
+        The default is :py:class:`pysegcnn.core.uda.CoralLoss`.
+    uda_lambda : `float`
+        The weight of the domain adaptation, trading off adaptation with
+        classification accuracy on the source domain. The default is `0`.
+    uda_pos : `str`
+        The layer where to compute the domain adaptation loss. The default
+        is `'enc'`, i.e. compute the domain adaptation loss using the output of
+        the model encoder.
+    uda : `bool`
+        Whether to train using deep domain adaptation.

-        """
-        return self.src_ds._labels
+    """

-    @property
-    def _original_target_labels(self):
-        """Original target domain labels.
+    trg_train_dl: DataLoader = DataLoader(None)
+    trg_valid_dl: DataLoader = DataLoader(None)
+    trg_test_dl: DataLoader = DataLoader(None)
+    uda_loss_function: nn.Module = CoralLoss(uda_lambda=0)
+    uda_lambda: float = 0
+    uda_pos: str = 'enc'

-        Returns
-        -------
-        original_target_labels : `dict` [`int`, `dict`]
-            The original class labels of the target domain.
+    def __post_init__(self):
+        """Check the type of each argument.

-        """
-        return self.trg_ds.dataset._labels
+        Configure the device to train the model on, i.e. train on the gpu if
+        available.

-    @property
-    def _label_log(self):
-        """Log if a label mapping is applied.
+        Configure early stopping if required.

-        Returns
-        -------
-        log : `str`
-            Represenation of the label mapping.
+        Initialize training metric tracking.

         """
-        log = 'Retaining model labels ({})'.format(
-            ', '.join([v['label'] for _, v in self.source_labels.items()]))
-        if self.apply_label_map:
-            log = (log.replace('Retaining', 'Mapping') + ' to target labels '
-                   '({}).'.format(', '.join([v['label'] for _, v in
-                                             self.target_labels.items()])))
-        return log
-
-    @property
-    def _batch_size(self):
-        """Batch size of the inference dataloader.
+        super().__post_init__()

-        Returns
-        -------
-        batch_size : `int`
-            The batch size of the dataloader used for model inference. Depends
-            on whether to predict each sample of the target dataset
-            individually or whether to reconstruct each scene in the target
-            dataset.
+        # whether to train using supervised transfer learning or
+        # deep domain adaptation

-        """
-        return (self.trg_ds.dataset.tiles if self.predict_scene and
-                self.is_scene_subset else 1)
+        # dummy variables for easy model evaluation
+        self.uda = False
+        if self.trg_train_dl.dataset is not None and self.uda_lambda > 0:

-    def _check_long_filename(self, filename):
-        """Modify filenames that exceed Windows' maximum filename length.
+            # set the device for computing domain adaptation loss
+            self.uda_loss_function.device = self.device

-        Parameters
-        ----------
-        filename : `str`
-            The filename to check.
+            # adjust metrics and initialize metric tracker
+            self.tracker.train_metrics.extend(['cla_loss', 'uda_loss'])

-        Returns
-        -------
-        filename : `str`
-            The modified filename, in case ``filename`` exceeds 255 characters.
+            # train using deep domain adaptation
+            self.uda = True

-        """
-        # check for maximum path component length: Windows allows a maximum
-        # of 255 characters for each component in a path
-        if len(filename) >= 255:
-            # try to parse the scene identifier
-            scene_metadata = self.trg_ds.dataset.parse_scene_id(filename)
-            if scene_metadata is not None:
-                # new, shorter name for the current batch
-                batch_name = '_'.join([scene_metadata['satellite'],
-                                       scene_metadata['tile'],
-                                       datetime.datetime.strftime(
-                                           scene_metadata['date'], '%Y%m%d')])
-                filename = filename.replace(scene_metadata['id'], batch_name)
+    def _inp_uda(self, src_input, trg_input):
+        """Domain adaptation at input feature level."""

-        return filename
+        # perform forward pass: classified source domain features
+        src_prdctn = self.model(src_input)

-    def map_to_target(self, prd):
-        """Map source domain labels to target domain labels.
+        return src_input, trg_input, src_prdctn

-        Parameters
-        ----------
-        prd : :py:class:`torch.Tensor`
-            The source domain class labels as predicted by ```self.model``.
+    def _enc_uda(self, src_input, trg_input):
+        """Domain adaptation at encoder feature level."""

-        Returns
-        -------
-        prd : :py:class:`torch.Tensor`
-            The predicted target domain labels.
+        # perform forward pass: encoded source domain features
+        src_feature = self.model.encoder(src_input)
+        src_dec_feature = self.model.decoder(src_feature,
+                                             self.model.encoder.cache)
+        # model logits on source domain
+        src_prdctn = self.model.classifier(src_dec_feature)
+        del self.model.encoder.cache  # clear intermediate encoder outputs

-        """
-        # map actual source labels to original source labels
-        for aid, oid in zip(self.source_labels.keys(),
-                            self._original_source_labels.keys()):
-            prd[torch.where(prd == aid)] = oid
+        # perform forward pass: encoded target domain features
+        trg_feature = self.model.encoder(trg_input)

-        # apply the label mapping
-        for src_label, trg_label in self.label_map.items():
-            prd[torch.where(prd == src_label)] = trg_label
+        return src_feature, trg_feature, src_prdctn

-        # map original target labels to actual target labels
-        for oid, aid in zip(self._original_target_labels.keys(),
-                            self.target_labels.keys()):
-            prd[torch.where(prd == oid)] = aid
+    def _dec_uda(self, src_input, trg_input):
+        """Domain adaptation at decoder feature level."""

-        return prd
+        # perform forward pass: decoded source domain features
+        src_feature = self.model.encoder(src_input)
+        src_feature = self.model.decoder(src_feature,
+                                         self.model.encoder.cache)
+        # model logits on source domain
+        src_prdctn = self.model.classifier(src_feature)
+        del self.model.encoder.cache  # clear intermediate encoder outputs

-    def predict(self):
-        """Classify the samples of the target dataset.
+        # perform forward pass: decoded target domain features
+        trg_feature = self.model.encoder(trg_input)
+        trg_feature = self.model.decoder(trg_feature,
+                                         self.model.encoder.cache)
+        del self.model.encoder.cache

-        Returns
-        -------
-        output : `dict` [`str`, `dict`]
-            The inference output dictionary. The keys are either the number of
-            the samples (``self.predict_scene=False``) or the name of the
-            scenes of the target dataset (``self.predict_scene=True``). The
-            values are dictionaries with keys:
-                ``'input'``
-                    Model input data of the sample (:py:class:`numpy.ndarray`).
-                ``'labels'
-                    Ground truth class labels (:py:class:`numpy.ndarray`).
-                ``'prediction'``
-                    Model prediction class labels (:py:class:`numpy.ndarray`).
+        return src_feature, trg_feature, src_prdctn

-        """
-        # iterate over the samples of the target dataset
-        output = {}
-        for batch, (inputs, labels) in enumerate(self.dataloader):
+    def _cla_uda(self, src_input, trg_input):
+        """Domain adaptation at classifier feature level."""

-            # send inputs and labels to device
-            inputs = inputs.to(self.device)
-            labels = labels.to(self.device)
+        # perform forward pass: classified source domain features
+        src_feature = self.model(src_input)

-            # compute model predictions
-            with torch.no_grad():
-                prdctn = F.softmax(self.model(inputs),
-                                   dim=1).argmax(dim=1).squeeze()
+        # perform forward pass: target domain features
+        trg_feature = self.model(trg_input)

-            # map source labels to target dataset labels
-            if self.apply_label_map:
-                prdctn = self.map_to_target(prdctn)
+        return src_feature, trg_feature, src_feature

-            # update confusion matrix
-            if self.compute_cm:
-                for ytrue, ypred in zip(labels.view(-1), prdctn.view(-1)):
-                    self.conf_mat[ytrue.long(), ypred.long()] += 1
+    def uda_frwd(self, src_input, trg_input):
+        """Forward function for deep domain adaptation.

-            # convert torch tensors to numpy arrays
-            inputs = inputs.numpy()
-            labels = labels.numpy()
-            prdctn = prdctn.numpy()
+        Parameters
+        ----------
+        src_input : :py:class:`torch.Tensor`
+            Source domain input features.
+        trg_input : :py:class:`torch.Tensor`
+            Target domain input features.

-            # progress string to log
-            progress = 'Sample: {:d}/{:d}'.format(batch + 1,
-                                                  len(self.dataloader))
+        Returns
+        -------
+        src_feature, trg_feature, src_prdctn : :py:class:`torch.Tensor`
+            The source and target domain features and the source domain
+            model logits.

-            # check whether to reconstruct the scene
-            date = None
-            if self.dataloader.batch_size > 1:
+        """
+        if self.uda_pos == 'inp':
+            return self._inp_uda(src_input, trg_input)

-                # id and date of the current scene
-                batch = self.trg_ds.ids[batch]
-                if self.trg_ds.split_mode == 'date':
-                    date = self.trg_ds.dataset.parse_scene_id(batch)['date']
+        if self.uda_pos == 'enc':
+            return self._enc_uda(src_input, trg_input)

-                # modify the progress string
-                progress = progress.replace('Sample', 'Scene')
-                progress += ' Id: {}'.format(batch)
+        if self.uda_pos == 'dec':
+            return self._dec_uda(src_input, trg_input)

+        if self.uda_pos == 'cla':
+            return self._cla_uda(src_input, trg_input)

-            # reconstruct the entire scene
-            inputs = reconstruct_scene(inputs)
-            labels = reconstruct_scene(labels)
-            prdctn = reconstruct_scene(prdctn)
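Illustration only (not part of the patch): `uda_loss_function` defaults to `pysegcnn.core.uda.CoralLoss`. A minimal CORAL-style loss for 2-D feature matrices looks roughly like this; the actual implementation may differ in detail.

```python
import torch

def coral_loss(src, trg):
    # src, trg: (n_samples, n_features) feature matrices
    d = src.size(1)
    # covariance matrices of the source and target features
    cs = torch.cov(src.T)
    ct = torch.cov(trg.T)
    # squared Frobenius norm of the covariance difference
    return ((cs - ct) ** 2).sum() / (4 * d * d)
```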
- # in case the source and target class labels are the same or a - # label mapping is applied, the accuracy of the prediction can be - # calculated - if self.source_is_target or self.apply_label_map: - progress += ', Accuracy: {:.2f}'.format( - accuracy_function(prdctn, labels)) - LOGGER.info(progress) + """ + # create target domain iterator + target = iter(self.trg_train_dl) - # plot current scene - if self.plot: + # increase domain adaptation weight with increasing epochs + uda_lambda = self.uda_lambda * ((epoch + 1) / self.epochs) - # in case the source and target class labels are the same or a - # label mapping is applied, plot the ground truth, otherwise - # just plot the model prediction - gt = None - if self.source_is_target or self.apply_label_map: - gt = labels + # iterate over the number of samples + for batch, (src_input, src_label) in enumerate(self.src_train_dl): - # plot inputs, ground truth and model predictions - plot_sample(inputs.clip(0, 1), - self.bands, - self.use_labels, - y=gt, - y_pred={self.model.__class__.__name__: prdctn}, - accuracy=True, - state=batch_name, - date=date, - fig=self.fig, - plot_path=self.scenes_path, - **self.kwargs) + # get the target domain input data + try: + trg_input, _ = target.next() + # in case the iterator is finished, re-instanciate it + except StopIteration: + target = iter(self.trg_train_dl) + trg_input, _ = target.next() - # save current figure state as frame for animation - if self.animate: - self.anim.frame(self.fig.axes) + # send the data to the gpu if available + src_input, src_label = (src_input.to(self.device), + src_label.to(self.device)) + trg_input = trg_input.to(self.device) - # save animation - if self.animate: - self.anim.animate(self.fig, interval=1000, repeat=True, blit=True) - self.anim.save(self.basename + '_{}.gif'.format(self.trg_ds.name), - dpi=200) + # reset the gradients + self.optimizer.zero_grad() - return output + # forward pass + src_feature, trg_feature, src_prdctn = self.uda_forward(src_input, + trg_input) - def evaluate(self): - """Evaluate a pretrained model on a defined dataset. + # compute classification loss + cla_loss = self.cla_loss_function(src_prdctn, src_label.long()) - Returns - ------- - output : `dict` [`str`, `dict`] - The inference output dictionary. The keys are either the number of - the samples (``self.predict_scene=False``) or the name of the - scenes of the target dataset (``self.predict_scene=True``). The - values are dictionaries with keys: - ``'input'`` - Model input data of the sample (:py:class:`numpy.ndarray`). - ``'labels' - Ground truth class labels (:py:class:`numpy.ndarray`). - ``'prediction'`` - Model prediction class labels (:py:class:`numpy.ndarray`). + # compute domain adaptation loss: + # the difference between source and target domain is computed + # from the compressed representation of the model encoder + uda_loss = self.uda_loss_function(src_feature, trg_feature) - """ - # plot loss and accuracy - plot_loss(check_filename_length(self.state_file), - outpath=self.perfmc_path) + # total loss + tot_loss = cla_loss + uda_lambda * uda_loss - # set the model to evaluation mode - LOGGER.info('Setting model to evaluation mode ...') - self.model.eval() - self.model.to(self.device) + # compute the gradients of the loss function w.r.t. 
+    def train_epoch(self, epoch):
+        """Train a model for a single epoch.

-@dataclasses.dataclass
-class LogConfig(BaseConfig):
-    """Logging configuration class.
+        Depends on whether to apply deep domain adaptation.

-    Generate the model log file.
+        Parameters
+        ----------
+        epoch : `int`
+            The current epoch.

-    Attributes
-    ----------
-    state_file : :py:class:`pathlib.Path`
-        Path to a model state file.
-    log_path : :py:class:`pathlib.Path`
-        Path to store model logs.
-    log_file : :py:class:`pathlib.Path`
-        Path to the log file of the model ``state_file``.
-    """

-    state_file: pathlib.Path
+        """
+        if self.uda:
+            self.train_domain_adaptation(epoch)
+        else:
+            self.train_source_domain(epoch)

-    def __post_init__(self):
-        """Check the type of each argument.
+    @property
+    def params_to_save(self):
+        """The parameters and variables to save in the model state file."""
+        return {'src_train_dl': self.src_train_dl,
+                'src_valid_dl': self.src_valid_dl,
+                'src_test_dl': self.src_test_dl,
+                'trg_train_dl': self.trg_train_dl,
+                'trg_valid_dl': self.trg_valid_dl,
+                'trg_test_dl': self.trg_test_dl,
+                'uda': self.uda,
+                'uda_pos': self.uda_pos,
+                'uda_lambda': self.uda_lambda}
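# Illustrative sketch, not part of the patched module: how params_to_save is
# meant to be consumed when checkpointing. The dictionary is unpacked into
# Network.save() so that the dataloaders and the adaptation settings travel
# with the model state file. The `trainer` variable is hypothetical; the
# keyword-based Network.save() call mirrors the save_state() method this
# patch removes.
model_state = trainer.model.save(trainer.state_file, trainer.optimizer,
                                 **trainer.params_to_save)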
+    def _build_ds_repr(self, train_dl, valid_dl, test_dl):
+        """Build the dataset representation.

-        Generate model log file.
+        Returns
+        -------
+        fs : `str`
+            Representation string.

         """
-        super().__post_init__()
-
-        # the path to store model logs
-        self.log_path = pathlib.Path(HERE).joinpath('_logs')
+        # dataset configuration
+        fs = ' (dataset):\n '
+        fs += ''.join(repr(train_dl.dataset.dataset)).replace('\n',
+                                                              '\n' + 8 * ' ')
+        fs += '\n (batch):\n '
+        fs += '- batch size: {}\n '.format(train_dl.batch_size)
+        fs += '- mini-batch shape (b, c, h, w): {}'.format(
+            ((train_dl.batch_size, len(train_dl.dataset.dataset.use_bands),)
+             + 2 * (train_dl.dataset.dataset.tile_size,)))
+
+        # dataset split
+        fs += '\n (split):'
+        for dl in [train_dl, valid_dl, test_dl]:
+            if dl.dataset is not None:
+                fs += '\n' + 8 * ' ' + repr(dl.dataset)

-        # the log file of the current model
-        self.log_file = check_filename_length(self.log_path.joinpath(
-            self.state_file.name.replace('.pt', '.log')))
+        return fs

-    @staticmethod
-    def now():
-        """Return the current date and time.
+    def __repr__(self):
+        """Representation.

         Returns
         -------
-        date : :py:class:`datetime.datetime`
-            The current date and time.
+        fs : `str`
+            Representation string.

         """
-        return datetime.datetime.strftime(datetime.datetime.now(),
-                                          '%Y-%m-%dT%H:%M:%S')
+        # representation string to print
+        fs = self.__class__.__name__ + '(\n'

-    @staticmethod
-    def init_log(init_str):
-        """Generate a string to identify a new model run.
+        # source domain
+        fs += ' (source domain)\n '
+        fs += self._build_ds_repr(
+            self.src_train_dl, self.src_valid_dl, self.src_test_dl).replace(
+                '\n', '\n' + 4 * ' ')

-        Parameters
-        ----------
-        init_str : `str`
-            The string to write to the model log file.
+        # target domain
+        if self.uda:
+            fs += '\n (target domain)\n '
+            fs += self._build_ds_repr(
+                self.trg_train_dl,
+                self.trg_valid_dl,
+                self.trg_test_dl).replace('\n', '\n' + 4 * ' ')

-        """
-        LOGGER.info(80 * '-')
-        LOGGER.info(init_str.format(LogConfig.now()))
-        LOGGER.info(80 * '-')
+        # model configuration
+        fs += self._build_model_repr_()
+
+        # domain adaptation
+        if self.uda:
+            fs += '\n (adaptation)' + '\n' + 8 * ' '
+            fs += repr(self.uda_loss_function).replace('\n', '\n' + 8 * ' ')
+
+        fs += '\n)'
+        return fs


 @dataclasses.dataclass
@@ -1950,1017 +1914,1157 @@ class MetricTracker(BaseConfig):
                     for k in self.valid_metrics}}


-@dataclasses.dataclass
-class NetworkTrainer(BaseConfig):
-    """Model training class.
+class EarlyStopping(object):
+    """`Early Stopping`_ algorithm.
+
+    This implementation of the early stopping algorithm advances a counter each
+    time a metric does not improve over a training epoch. If the metric does not
+    improve over more than ``patience`` epochs, the early stopping criterion is
+    met.

-    Train an instance of :py:class:`pysegcnn.core.models.Network` on a dataset
-    of type :py:class:`pysegcnn.core.dataset.ImageDataset`.
+    See the :py:meth:`pysegcnn.core.trainer.NetworkTrainer.train` method for an
+    example implementation.

-    Supports training a model on a single source domain only and on a source
-    and target domain using deep domain adaptation.
+    .. _Early Stopping:
+        https://en.wikipedia.org/wiki/Early_stopping

     Attributes
     ----------
-    model : :py:class:`pysegcnn.core.models.Network`
-        The model to train. An instance of
-        :py:class:`pysegcnn.core.models.Network`.
-    optimizer : :py:class:`torch.optim.Optimizer`
-        The optimizer to update the model weights. An instance of
-        :py:class:`torch.optim.Optimizer`.
-    state_file : :py:class:`pathlib.Path`
-        Path to save the model state.
-    bands : `list` [`str`]
-        The spectral bands used to train ``model``.
-    src_train_dl : :py:class:`torch.utils.data.DataLoader`
-        The source domain training :py:class:`torch.utils.data.DataLoader`
-        instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`.
-    src_valid_dl : :py:class:`torch.utils.data.DataLoader`
-        The source domain validation :py:class:`torch.utils.data.DataLoader`
-        instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`.
-    src_test_dl : :py:class:`torch.utils.data.DataLoader`
-        The source domain test :py:class:`torch.utils.data.DataLoader`
-        instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`.
-    cla_loss_function : :py:class:`torch.nn.Module`
-        The classification loss function to compute the model error. An
-        instance of :py:class:`torch.nn.Module`.
-    trg_train_dl : :py:class:`torch.utils.data.DataLoader`
-        The target domain training :py:class:`torch.utils.data.DataLoader`
-        instance build from an instance of
-        :py:class:`pysegcnn.core.split.CustomSubset`.
The default is an empty - :py:class:`torch.utils.data.DataLoader`. - trg_valid_dl : :py:class:`torch.utils.data.DataLoader` - The target domain validation :py:class:`torch.utils.data.DataLoader` - instance build from an instance of - :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty - :py:class:`torch.utils.data.DataLoader`. - trg_test_dl : :py:class:`torch.utils.data.DataLoader` - The target domain test :py:class:`torch.utils.data.DataLoader` - instance build from an instance of - :py:class:`pysegcnn.core.split.CustomSubset`. The default is an empty - :py:class:`torch.utils.data.DataLoader`. - uda_loss_function : :py:class:`torch.nn.Module` - The domain adaptation loss function. An instance of - :py:class:`torch.nn.Module`. - The default is :py:class:`pysegcnn.core.uda.CoralLoss`. - uda_lambda : `float` - The weight of the domain adaptation, trading off adaptation with - classification accuracy on the source domain. The default is `0`. - uda_pos : `str` - The layer where to compute the domain adaptation loss. The default - is `'enc'`, i.e. compute the domain adaptation loss using the output of - the model encoder. - epochs : `int` - The maximum number of epochs to train. The default is `1`. - nthreads : `int` - The number of cpu threads to use during training. The default is - :py:func:`torch.get_num_threads()`. - early_stop : `bool` - Whether to apply `Early Stopping`_. The default is `False`. mode : `str` - The early stopping mode. Depends on the metric measuring - performance. When using model loss as metric, use ``mode='min'``, - however, when using accuracy as metric, use ``mode='max'``. For now, - only ``mode='max'`` is supported. Only used if ``early_stop=True``. - The default is `'max'`. - delta : `float` + The early stopping mode. + best : `float` + Best metric score. + min_delta : `float` Minimum change in early stopping metric to be considered as an - improvement. Only used if ``early_stop=True``. The default is `0`. + improvement. patience : `int` - The number of epochs to wait for an improvement in the early stopping - metric. If the model does not improve over more than ``patience`` - epochs, quit training. Only used if ``early_stop=True``. The default is - `10`. - checkpoint_state : `dict` [`str`, :py:class:`numpy.ndarray`] - A model checkpoint for ``model``. If specified, ``checkpoint_state`` - should be a dictionary with keys describing the training metric. - The default is `{}`. - save : `bool` - Whether to save the model state to ``state_file``. The default is - `True`. - device : `str` - The device to train the model on, i.e. `cpu` or `cuda`. - tracker : :py:class:`pysegcnn.core.trainer.MetricTracker` - A :py:class:`pysegcnn.core.trainer.MetricTracker` instance tracking - training metrics, i.e. loss and accuracy. - uda : `bool` - Whether to apply deep domain adaptation. - max_accuracy : `float` - Maximum accuracy of ``model`` on the validation dataset. - es : `None` or :py:class:`pysegcnn.core.trainer.EarlyStopping` - The early stopping instance if ``early_stop=True``, else `None`. - tmbatch : `int` - Number of mini-batches in the training dataset. - vmbatch : `int` - Number of mini-batches in the validation dataset. - training_state : `dict` [`str`, :py:class:`numpy.ndarray`] - The training state dictionary. The keys describe the type of the - training metric. - - .. _Early Stopping: - https://en.wikipedia.org/wiki/Early_stopping + The number of epochs to wait for an improvement. 
+ is_better : `function` + Function indicating whether the metric improved. + early_stop : `bool` + Whether the early stopping criterion is met. + counter : `int` + The counter advancing each time a metric does not improve. """ - model: Network - optimizer: Optimizer - state_file: pathlib.Path - bands: list - src_train_dl: DataLoader - src_valid_dl: DataLoader - src_test_dl: DataLoader - cla_loss_function: nn.Module - trg_train_dl: DataLoader = DataLoader(None) - trg_valid_dl: DataLoader = DataLoader(None) - trg_test_dl: DataLoader = DataLoader(None) - uda_loss_function: nn.Module = CoralLoss(uda_lambda=0) - uda_lambda: float = 0 - uda_pos: str = 'enc' - epochs: int = 1 - nthreads: int = torch.get_num_threads() - early_stop: bool = False - mode: str = 'max' - delta: float = 0 - patience: int = 10 - checkpoint_state: dict = dataclasses.field(default_factory={}) - save: bool = True + def __init__(self, mode='max', best=0, min_delta=0, patience=10): + """Initialize. + + Parameters + ---------- + mode : `str`, optional + The early stopping mode. Depends on the metric measuring + performance. When using model loss as metric, use ``mode='min'``, + however, when using accuracy as metric, use ``mode='max'``. For + now, only ``mode='max'`` is supported. Only used if + ``early_stop=True``. The default is `'max'`. + best : `float`, optional + Threshold indicating the best metric score. At instanciation, set + ``best`` to the worst possible score of the metric. ``best`` will + be overwritten during training. The default is `0`. + min_delta : `float`, optional + Minimum change in early stopping metric to be considered as an + improvement. Only used if ``early_stop=True``. The default is `0`. + patience : `int`, optional + The number of epochs to wait for an improvement in the early + stopping metric. If the model does not improve over more than + ``patience`` epochs, quit training. Only used if + ``early_stop=True``. The default is `10`. + + Raises + ------ + ValueError + Raised if ``mode`` is not either 'min' or 'max'. + + """ + # check if mode is correctly specified + if mode not in ['min', 'max']: + raise ValueError('Mode "{}" not supported. ' + 'Mode is either "min" (check whether the metric ' + 'decreased, e.g. loss) or "max" (check whether ' + 'the metric increased, e.g. accuracy).' + .format(mode)) + + # mode to determine if metric improved + self.mode = mode + + # whether to check for an increase or a decrease in a given metric + self.is_better = self.decreased if mode == 'min' else self.increased + + # minimum change in metric to be considered as an improvement + self.min_delta = min_delta + + # number of epochs to wait for improvement + self.patience = patience + + # initialize best metric + self.best = best + + # initialize early stopping flag + self.early_stop = False + + # initialize the early stop counter + self.counter = 0 + + def stop(self, metric): + """Advance early stopping counter. + + Parameters + ---------- + metric : `float` + The current metric score. + + Returns + ------- + early_stop : `bool` + Whether the early stopping criterion is met. 
+ + """ + # if the metric improved, reset the epochs counter, else, advance + if self.is_better(metric, self.best, self.min_delta): + self.counter = 0 + self.best = metric + else: + self.counter += 1 + LOGGER.info('Early stopping counter: {}/{}'.format( + self.counter, self.patience)) + + # if the metric did not improve over the last patience epochs, + # the early stopping criterion is met + if self.counter >= self.patience: + LOGGER.info('Early stopping criterion met, stopping training.') + self.early_stop = True + + return self.early_stop + + def decreased(self, metric, best, min_delta): + """Whether a metric decreased with respect to a best score. + + Measure improvement for metrics that are considered as 'better' when + they decrease, e.g. model loss, mean squared error, etc. + + Parameters + ---------- + metric : `float` + The current score. + best : `float` + The current best score. + min_delta : `float` + Minimum change to be considered as an improvement. + + Returns + ------- + `bool` + Whether the metric improved. + + """ + return metric < best - min_delta + + def increased(self, metric, best, min_delta): + """Whether a metric increased with respect to a best score. + + Measure improvement for metrics that are considered as 'better' when + they increase, e.g. accuracy, precision, recall, etc. + + Parameters + ---------- + metric : `float` + The current score. + best : `float` + The current best score. + min_delta : `float` + Minimum change to be considered as an improvement. + + Returns + ------- + `bool` + Whether the metric improved. + + """ + return metric > best + min_delta + + def __repr__(self): + """Representation. + + Returns + ------- + fs : `str` + Representation string. + + """ + fs = self.__class__.__name__ + fs += '(mode={}, best={:.2f}, delta={}, patience={})'.format( + self.mode, self.best, self.min_delta, self.patience) + + return fs + + +@dataclasses.dataclass +class NetworkInference(BaseConfig): + """Model inference configuration. + + Evaluate a model. + + Attributes + ---------- + state_file : :py:class:`pathlib.Path` + Path to the model to evaluate. + implicit : `bool` + Whether to evaluate the model on the datasets defined at training time. + domain : `str` + Whether to evaluate on the source domain (``domain='src'``), i.e. the + domain the model in ``state_file`` was trained on, or a target domain + (``domain='trg'``). + test : `bool` or `None` + Whether to evaluate the model on the training(``test=None``), the + validation (``test=False``) or the test set (``test=True``). + map_labels : `bool` + Whether to map the model labels from the model source domain to the + defined ``domain`` in case the domain class labels differ. + ds : `dict` + The dataset configuration dictionary passed to + :py:class:`pysegcnn.core.trainer.DatasetConfig` when evaluating on + an explicitly defined dataset, i.e. ``implicit=False``. + ds_split : `dict` + The dataset split configuration dictionary passed to + :py:class:`pysegcnn.core.trainer.SplitConfig` when evaluating on + an explicitly defined dataset, i.e. ``implicit=False``. + predict_scene : `bool` + The model prediction order. If False, the samples (tiles) of a dataset + are predicted in any order and the scenes are not reconstructed. + If True, the samples (tiles) are ordered according to the scene they + belong to and a model prediction for each entire reconstructed scene is + returned. The default is `False`. 
+    plot_samples : `bool`
+        Whether to save a plot of false color composite, ground truth and
+        model prediction for each sample (tile). Only used if
+        ``predict_scene=False``. The default is `False`.
+    plot_scenes : `bool`
+        Whether to save a plot of false color composite, ground truth and
+        model prediction for each entire scene. Only used if
+        ``predict_scene=True``. The default is `False`.
+    plot_bands : `list` [`str`]
+        The bands to build the false color composite. The default is
+        `['nir', 'red', 'green']`.
+    cm : `bool`
+        Whether to compute and plot the confusion matrix. The default is
+        `True`.
+    figsize : `tuple`
+        The figure size in inches, as expected by matplotlib. The default is
+        `(10, 10)`.
+    alpha : `int`
+        The level of the percentiles for contrast stretching of the false
+        color composite. The default is `5`.
+    animate : `bool`
+        Whether to create an animation of (input, ground truth, prediction)
+        for the scenes in the train/validation/test dataset. Only works if
+        ``predict_scene=True`` and ``plot_scenes=True``.
+    device : `str`
+        The device to evaluate the model on, i.e. `cpu` or `cuda`.
+    base_path : :py:class:`pathlib.Path`
+        Root path to store model output.
+    sample_path : :py:class:`pathlib.Path`
+        Path to store plots of model predictions for single samples.
+    scenes_path : :py:class:`pathlib.Path`
+        Path to store plots of model predictions for entire scenes.
+    perfmc_path : :py:class:`pathlib.Path`
+        Path to store plots of model performance, e.g. confusion matrix.
+    animtn_path : :py:class:`pathlib.Path`
+        Path to store animations.
+    models_path : :py:class:`pathlib.Path`
+        Path to search for model state files, i.e. pretrained models.
+    kwargs : `dict`
+        Keyword arguments for :py:func:`pysegcnn.core.graphics.plot_sample`.
+    basename : `str`
+        Base filename for each plot.
+    model : :py:class:`pysegcnn.core.models.Network`
+        The model to use for inference.
+    model_state : `dict`
+        A dictionary containing the model and optimizer state, as
+        constructed by :py:meth:`~pysegcnn.core.Network.save`.
+    trg_ds : :py:class:`pysegcnn.core.split.CustomSubset`
+        The dataset to evaluate ``model`` on.
+    src_ds : :py:class:`pysegcnn.core.split.CustomSubset`
+        The model source domain training dataset.
+    fig : :py:class:`matplotlib.figure.Figure`
+        A :py:class:`matplotlib.figure.Figure` instance to iteratively plot
+        to.
+    anim : :py:class:`pysegcnn.core.graphics.Animate`
+        An instance of :py:class:`pysegcnn.core.graphics.Animate` used to
+        create animations if ``animate=True``.
+    conf_mat : :py:class:`numpy.ndarray`
+        The model confusion matrix.
+
+    """
+
+    state_file: pathlib.Path
+    implicit: bool
+    domain: str
+    test: object
+    map_labels: bool
+    ds: dict = dataclasses.field(default_factory=dict)
+    ds_split: dict = dataclasses.field(default_factory=dict)
+    predict_scene: bool = False
+    plot_samples: bool = False
+    plot_scenes: bool = False
+    plot_bands: list = dataclasses.field(
+        default_factory=lambda: ['nir', 'red', 'green'])
+    cm: bool = True
+    figsize: tuple = (10, 10)
+    alpha: int = 5
+    animate: bool = False

     def __post_init__(self):
         """Check the type of each argument.

-        Configure the device to train the model on, i.e. train on the gpu if
-        available.
-
-        Configure early stopping if required.
+        Configure figure output paths.

-        Initialize training metric tracking.
+        Raises
+        ------
+        TypeError
+            Raised if ``test`` is not of type `bool` or `None`.
+        ValueError
+            Raised if ``domain`` is not 'src' or 'trg'.
""" super().__post_init__() - # the device to train the model on - self.device = torch.device('cuda:0' if torch.cuda.is_available() else - 'cpu') - # set the number of threads - torch.set_num_threads(self.nthreads) - - # send the model to the gpu if available - self.model = self.model.to(self.device) - - # instanciate metric tracker - self.tracker = MetricTracker( - train_metrics=['train_loss', 'train_accu'], - valid_metrics=['valid_loss', 'valid_accu']) - - # whether to train using supervised transfer learning or - # deep domain adaptation + # check whether the test input parameter is correctly specified + if self.test not in [None, False, True]: + raise TypeError('Expected "test" to be None, True or False, got ' + '{}.'.format(self.test)) - # dummy variables for easy model evaluation - self.uda = False - if self.trg_train_dl.dataset is not None and self.uda_lambda > 0: + # check whether the domain is correctly specified + if self.domain not in ['src', 'trg']: + raise ValueError('Expected "domain" to be "src" or "trg", got {}.' + .format(self.domain)) - # set the device for computing domain adaptation loss - self.uda_loss_function.device = self.device + # the device to compute on, use gpu if available + self.device = torch.device("cuda:0" if torch.cuda.is_available() else + "cpu") - # adjust metrics and initialize metric tracker - self.tracker.train_metrics.extend(['cla_loss', 'uda_loss']) + # the output paths for the different graphics + self.base_path = pathlib.Path(HERE) + self.sample_path = self.base_path.joinpath('_samples') + self.scenes_path = self.base_path.joinpath('_scenes') + self.perfmc_path = self.base_path.joinpath('_graphics') + self.animtn_path = self.base_path.joinpath('_animations') - # train using deep domain adaptation - self.uda = True + # input path for model state files + self.models_path = self.base_path.joinpath('_models') + self.state_file = self.models_path.joinpath(self.state_file) - # forward function for deep domain adaptation - self.uda_forward = self._uda_frwd() + # initialize logging + log = LogConfig(self.state_file) + dictConfig(log_conf(log.log_file)) + log.init_log('{}: ' + 'Evaluating model: {}.' 
+                     .format(self.state_file.name))

-        # initialize metric tracker
-        self.tracker.initialize()
+        # plotting keyword arguments
+        self.kwargs = {'bands': self.plot_bands,
+                       'alpha': self.alpha,
+                       'figsize': self.figsize}

-        # maximum accuracy on the validation set
-        self.max_accuracy = 0
-        if self.checkpoint_state:
-            self.max_accuracy = self.checkpoint_state['valid_accu'].mean(
-                axis=0).max().item()
+        # base filename for each plot
+        self.basename = self.state_file.stem

-        # whether to use early stopping
-        self.es = None
-        if self.early_stop:
-            self.es = EarlyStopping(self.mode, self.max_accuracy, self.delta,
-                                    self.patience)
+        # load the model state
+        self.model, _, self.model_state = Network.load(self.state_file)

-        # number of mini-batches in the training and validation sets
-        self.tmbatch = len(self.src_train_dl)
-        self.vmbatch = len(self.src_valid_dl)
+        # load the target dataset: dataset to evaluate the model on
+        self.trg_ds = self.load_dataset()

-        # log representation
-        LOGGER.info(repr(self))
+        # load the source dataset: dataset the model was trained on
+        self.src_ds = self.model_state['src_train_dl'].dataset.dataset

-        # initialize training log
-        LOGGER.info(35 * '-' + ' Training ' + 35 * '-')
+        # create a figure to use for plotting
+        self.fig, _ = plt.subplots(1, 3, figsize=self.kwargs['figsize'])

-        # log the device and number of threads
-        LOGGER.info('Device: {}'.format(self.device))
-        LOGGER.info('Number of cpu threads: {}'.format(self.nthreads))
+        # check if the animate parameter is correctly specified
+        if self.animate:
+            if not self.plot:
+                LOGGER.warning('animate requires plot_scenes=True or '
+                               'plot_samples=True. Cannot create animation.')
+            else:
+                # check whether the output path is valid
+                if not self.animtn_path.exists():
+                    # create output path
+                    self.animtn_path.mkdir(parents=True, exist_ok=True)
+                self.anim = Animate(self.animtn_path)

-    def _train_source_domain(self, epoch):
-        """Train a model for an epoch on the source domain.
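# Illustrative sketch, not part of the patched module: hypothetical usage of
# the NetworkInference configuration defined here, evaluating a trained
# model on the test set of the source domain it was trained on. The state
# file name is made up; the keyword arguments correspond to the dataclass
# fields above.
import pathlib

inference = NetworkInference(
    state_file=pathlib.Path('UNet_pretrained.pt'),  # hypothetical file name
    implicit=True,    # use the datasets defined at training time
    domain='src',     # evaluate on the source domain ...
    test=True,        # ... using its test set
    map_labels=False)
output = inference.evaluate()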
+    @staticmethod
+    def get_scene_tiles(ds, scene_id):
+        """Return the tiles of the scene with id ``scene_id``.

         Parameters
         ----------
-        epoch : `int`
-            The current epoch.
-
-        """
-        # iterate over the dataloader object
-        for batch, (inputs, labels) in enumerate(self.src_train_dl):
-
-            # send the data to the gpu if available
-            inputs = inputs.to(self.device)
-            labels = labels.to(self.device)
-
-            # reset the gradients
-            self.optimizer.zero_grad()
-
-            # perform forward pass
-            outputs = self.model(inputs)
-
-            # compute loss
-            loss = self.cla_loss_function(outputs, labels.long())
-
-            # compute the gradients of the loss function w.r.t.
-            # the network weights
-            loss.backward()
-
-            # update the weights
-            self.optimizer.step()
-
-            # calculate predicted class labels
-            ypred = F.softmax(outputs, dim=1).argmax(dim=1)
-
-            # calculate accuracy on current batch
-            acc = accuracy_function(ypred, labels)
-
-            # print progress
-            LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
-                        'Loss: {:.2f}, Accuracy: {:.2f}'
-                        .format(epoch + 1, self.epochs, batch + 1,
-                                self.tmbatch, loss.item(), acc))
-
-            # update training metrics
-            self.tracker.batch_update(self.tracker.train_metrics,
-                                      [loss.item(), acc])
-
-    def _enc_uda(self, src_input, trg_input):
-
-        # perform forward pass: encoded source domain features
-        src_feature = self.model.encoder(src_input)
-        src_dec_feature = self.model.decoder(src_feature,
-                                             self.model.encoder.cache)
-        # model logits on source domain
-        src_prdctn = self.model.classifier(src_dec_feature)
-        del self.model.encoder.cache  # clear intermediate encoder outputs
-
-        # perform forward pass: encoded target domain features
-        trg_feature = self.model.encoder(trg_input)
-
-        return src_feature, trg_feature, src_prdctn
-
-    def _dec_uda(self, src_input, trg_input):
-
-        # perform forward pass: decoded source domain features
-        src_feature = self.model.encoder(src_input)
-        src_feature = self.model.decoder(src_feature,
-                                         self.model.encoder.cache)
-        # model logits on source domain
-        src_prdctn = self.model.classifier(src_feature)
-        del self.model.encoder.cache  # clear intermediate encoder outputs
-
-        # perform forward pass: decoded target domain features
-        trg_feature = self.model.encoder(trg_input)
-        trg_feature = self.model.decoder(trg_feature,
-                                         self.model.encoder.cache)
-        del self.model.encoder.cache
-
-        return src_feature, trg_feature, src_prdctn
+        ds : :py:class:`pysegcnn.core.split.CustomSubset`
+            An instance of a subclass of
+            :py:class:`pysegcnn.core.split.CustomSubset`.
+        scene_id : `str`
+            A valid scene identifier.

-    def _cla_uda(self, src_input, trg_input):
+        Raises
+        ------
+        ValueError
+            Raised if ``scene_id`` is not a valid scene identifier for the
+            dataset ``ds``.

-        # perform forward pass: classified source domain features
-        src_feature = self.model(src_input)
+        Returns
+        -------
+        indices : `list` [`int`]
+            List of indices of the tiles of the scene with id ``scene_id`` in
+            ``ds``.
+        date : :py:class:`datetime.datetime`
+            The date of the scene with id ``scene_id``.

-        # perform forward pass: target domain features
-        trg_feature = self.model(trg_input)
+        """
+        # check if the scene id is valid
+        scene_meta = ds.dataset.parse_scene_id(scene_id)
+        if scene_meta is None:
+            raise ValueError('{} is not a valid scene identifier'
+                             .format(scene_id))

-        return src_feature, trg_feature, src_feature
+        # iterate over the scenes of the dataset
+        indices = []
+        for i, scene in enumerate(ds.scenes):
+            # if the scene id matches a given id, save the index of the scene
+            if scene['id'] == scene_id:
+                indices.append(i)

-    def _uda_frwd(self):
-        if self.uda_pos == 'enc':
-            forward = self._enc_uda
+        return indices, scene_meta['date']

-        if self.uda_pos == 'dec':
-            forward = self._dec_uda
+    @staticmethod
+    def replace_dataset_path(ds, drive_path):
+        """Replace the path to the datasets.

-        if self.uda_pos == 'cla':
-            forward = self._cla_uda
+        Useful to evaluate models on machines that are different from the
+        machine the model was trained on.

-        return forward
+        .. important::

-    def _train_domain_adaptation(self, epoch):
-        """Train a model for an epoch on the source and target domain.
+ This function assumes that the datasets are stored in a directory + named "Datasets" on each machine. - This function implements deep domain adaptation by extending the - standard classification loss by a "domain adaptation loss" calculated - from unlabelled target domain samples. + See ``DRIVE_PATH`` in :py:mod:`pysegcnn.main.config`. Parameters ---------- - epoch : `int` - The current epoch. + ds : :py:class:`pysegcnn.core.split.CustomSubset` + A subset of an instance of + :py:class:`pysegcnn.core.dataset.ImageDataset`. + drive_path : `str` + Base path to the datasets on the current machine. ``drive_path`` + should end with `'Datasets'`. + + Raises + ------ + TypeError + Raised if ``ds`` is not an instance of + :py:class:`pysegcnn.core.split.CustomSubset` and if ``ds`` is not + a subset of an instance of + :py:class:`pysegcnn.core.dataset.ImageDataset`. """ - # create target domain iterator - target = iter(self.trg_train_dl) + # check input type + if isinstance(ds, CustomSubset): + # check the type of the dataset + if not isinstance(ds.dataset, ImageDataset): + raise TypeError('ds should be a subset created from a {}.' + .format(repr(ImageDataset))) + else: + raise TypeError('ds should be an instance of {}.' + .format(repr(CustomSubset))) - # increase domain adaptation weight with increasing epochs - uda_lambda = self.uda_lambda * ((epoch + 1) / self.epochs) + # iterate over the scenes of the dataset + for scene in ds.dataset.scenes: + for k, v in scene.items(): + # do only look for paths + if isinstance(v, str) and k != 'id': - # iterate over the number of samples - for batch, (src_input, src_label) in enumerate(self.src_train_dl): + # drive path: match path before "Datasets" + # dpath = re.search('^(.*)(?=(/.*Datasets))', v) - # get the target domain input data - try: - trg_input, _ = target.next() - # in case the iterator is finished, re-instanciate it - except StopIteration: - target = iter(self.trg_train_dl) - trg_input, _ = target.next() + # drive path: match path up to "Datasets" + dpath = re.search('^(.*?Datasets)', v)[0] - # send the data to the gpu if available - src_input, src_label = (src_input.to(self.device), - src_label.to(self.device)) - trg_input = trg_input.to(self.device) + # replace drive path + if dpath != drive_path: + scene[k] = v.replace(str(dpath), drive_path) - # reset the gradients - self.optimizer.zero_grad() + def load_dataset(self): + """Load the defined dataset. - # forward pass - src_feature, trg_feature, src_prdctn = self.uda_forward(src_input, - trg_input) + Raises + ------ + ValueError + Raised if the requested dataset was not available at training time, + if ``implicit=True``. - # compute classification loss - cla_loss = self.cla_loss_function(src_prdctn, src_label.long()) + Raised if the dataset ``ds`` does not have the same spectral bands + as the model to evaluate, if ``implicit=False``. - # compute domain adaptation loss: - # the difference between source and target domain is computed - # from the compressed representation of the model encoder - uda_loss = self.uda_loss_function(src_feature, trg_feature) + Returns + ------- + ds : :py:class:`pysegcnn.core.split.CustomSubset` + The dataset to evaluate the model on. - # total loss - tot_loss = cla_loss + uda_lambda * uda_loss + """ + # check whether to evaluate on the datasets defined at training time + if self.implicit: - # compute the gradients of the loss function w.r.t. 
- # the network weights - tot_loss.backward() + # check whether to evaluate the model on the training, validation + # or test set + if self.test is None: + ds_set = 'train' + else: + ds_set = 'test' if self.test else 'valid' - # update the weights - self.optimizer.step() + # the dataset to evaluate the model on + ds = self.model_state[ + self.domain + '_{}_dl'.format(ds_set)].dataset + if ds is None: + raise ValueError('Requested dataset "{}" is not available.' + .format(self.domain + '_{}_dl'.format(ds_set)) + ) - # calculate predicted class labels - ypred = F.softmax(src_prdctn, dim=1).argmax(dim=1) + # log dataset representation + LOGGER.info('Evaluating on {} set of the {} domain defined at ' + 'training time.'.format(ds_set, self.domain)) - # calculate accuracy on current batch - acc = accuracy_function(ypred, src_label) + else: + # explicitly defined dataset + ds = DatasetConfig(**self.ds).init_dataset() - # print progress - LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, ' - 'Cla_loss: {:.2f}, Uda_loss: {:.2f}, ' - 'Tot_loss: {:.2f}, Acc: {:.2f}' - .format(epoch + 1, self.epochs, batch + 1, - self.tmbatch, cla_loss.item(), - uda_loss.item(), tot_loss.item(), acc)) + # check if the spectral bands match + if ds.use_bands != self.model_state['bands']: + raise ValueError('The model was trained with bands {}, not ' + 'with bands {}.'.format( + self.model_state['bands'], ds.use_bands)) - # update training metrics - self.tracker.batch_update(self.tracker.train_metrics, - [tot_loss.item(), acc, - cla_loss.item(), uda_loss.item()]) + # split configuration + sc = SplitConfig(**self.ds_split) + train_ds, valid_ds, test_ds = sc.train_val_test_split(ds) - def train_epoch(self, epoch): - """Wrap the function to train a model for a single epoch. + # check whether to evaluate the model on the training, validation + # or test set + if self.test is None: + ds = train_ds + else: + ds = test_ds if self.test else valid_ds - Depends on whether to apply deep domain adaptation. + # log dataset representation + LOGGER.info('Evaluating on {} set of explicitly defined dataset: ' + '\n {}'.format(ds.name, repr(ds.dataset))) - Parameters - ---------- - epoch : `int` - The current epoch. + # check the dataset path: replace by path on current machine + self.replace_dataset_path(ds, DRIVE_PATH) + + return ds + + @property + def source_labels(self): + """Class labels of the source domain the model was trained on. Returns ------- - `function` - The function to train a model for a single epoch. + source_labels : `dict` [`int`, `dict`] + The class labels of the source domain. """ - if self.uda: - self._train_domain_adaptation(epoch) - else: - self._train_source_domain(epoch) + return self.src_ds.labels - def train(self): - """Train the model. + @property + def target_labels(self): + """Class labels of the dataset to evaluate. Returns ------- - training_state : `dict` [`str`, :py:class:`numpy.ndarray`] - The training state dictionary. The keys describe the type of the - training metric. See - :py:meth:`~pysegcnn.core.trainer.NetworkTrainer.training_state`. + target_labels : `dict` [`int`, `dict`] + The class labels of the dataset to evaluate. 
""" - # initialize the training: iterate over the entire training dataset - for epoch in range(self.epochs): - - # set the model to training mode - LOGGER.info('Setting model to training mode ...') - self.model.train() - - # train model for a single epoch - self.train_epoch(epoch) - - # update the number of epochs trained - self.model.epoch += 1 - - # whether to evaluate model performance on the validation set and - # early stop the training process - if self.early_stop: - - # model predictions on the validation set - valid_accu, valid_loss = self.predict(self.src_valid_dl) - - # update validation metrics - self.tracker.batch_update(self.tracker.valid_metrics, - [valid_loss, valid_accu]) + return self.trg_ds.dataset.labels - # metric to assess model performance on the validation set - epoch_acc = np.mean(valid_accu) + # @property + # def label_map(self): + # """Label mapping from the source to the target domain. - # whether the model improved with respect to the previous epoch - if self.es.increased(epoch_acc, self.max_accuracy, self.delta): - self.max_accuracy = epoch_acc + # See :py:func:`pysegcnn.core.constants.map_labels`. - # save model state if the model improved with - # respect to the previous epoch - self.save_state() + # Returns + # ------- + # label_map : `dict` [`int`, `int`] + # Dictionary with source labels as keys and corresponding target + # labels as values. - # whether the early stopping criterion is met - if self.es.stop(epoch_acc): - break + # """ + # # check whether the source domain labels are the same as the target + # # domain labels + # return map_labels(self.src_ds.get_labels(), + # self.trg_ds.dataset.get_labels()) - else: - # if no early stopping is required, the model state is - # saved after each epoch - self.save_state() + @property + def source_is_target(self): + """Whether the source and target domain labels are the same. - return self.training_state + Returns + ------- + source_is_target : `bool` + `True` if the source and target domain labels are the same, `False` + if not. - def predict(self, dataloader): - """Model inference at training time. + """ + return self.label_map is None - Parameters - ---------- - dataloader : :py:class:`torch.utils.data.DataLoader` - The validation dataloader to evaluate the model predictions. + @property + def apply_label_map(self): + """Whether to map source labels to target labels. Returns ------- - accuracy : :py:class:`numpy.ndarray` - The mean model prediction accuracy on each mini-batch in the - validation set. - loss : :py:class:`numpy.ndarray` - The model loss for each mini-batch in the validation set. + apply_label_map : `bool` + `True` if source and target labels differ and label mapping is + requested, `False` otherwise. """ - # set the model to evaluation mode - LOGGER.info('Setting model to evaluation mode ...') - self.model.eval() - - # create arrays of the observed loss and accuracy - accuracy = [] - loss = [] + return not self.source_is_target and self.map_labels - # iterate over the validation/test set - LOGGER.info('Calculating accuracy on the validation set ...') - for batch, (inputs, labels) in enumerate(dataloader): + @property + def use_labels(self): + """Labels to be predicted. - # send the data to the gpu if available - inputs = inputs.to(self.device) - labels = labels.to(self.device) + Returns + ------- + use_labels : `dict` [`int`, `dict`] + The labels of the classes to be predicted. 
- # calculate network outputs - with torch.no_grad(): - outputs = self.model(inputs) + """ + return (self.target_labels if self.apply_label_map else + self.source_labels) - # compute loss - cla_loss = self.cla_loss_function(outputs, labels.long()) - loss.append(cla_loss.item()) + @property + def bands(self): + """Spectral bands the model was trained with. - # calculate predicted class labels - pred = F.softmax(outputs, dim=1).argmax(dim=1) + Returns + ------- + bands : `list` [`str`] + A list of the named spectral bands used to train the model. - # calculate accuracy on current batch - acc = accuracy_function(pred, labels) - accuracy.append(acc) + """ + return self.src_ds.use_bands - # print progress - LOGGER.info('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}' - .format(batch + 1, len(dataloader), acc)) + @property + def compute_cm(self): + """Whether to compute the confusion matrix. - # calculate overall accuracy on the validation/test set - LOGGER.info('Epoch: {:d}, Mean accuracy: {:.2f}%.' - .format(self.model.epoch, np.mean(accuracy) * 100)) + Returns + ------- + compute_cm : `bool` + Whether the confusion matrix can be computed. For datasets with + labels different from the source domain labels, the confusion + matrix can not be computed. - return accuracy, loss + """ + return (False if not self.source_is_target and not self.map_labels + else self.cm) @property - def training_state(self): - """Model training metrics. + def plot(self): + """Whether to save plots of (input, ground truth, prediction). Returns ------- - state : `dict` [`str`, :py:class:`numpy.ndarray`] - The training state dictionary. The keys describe the type of the - training metric and the values are :py:class:`numpy.ndarray`'s of - the corresponding metric observed during training with - shape=(mini_batch, epoch). + plot : `bool` + Save plots for each sample or for each scene of the target dataset, + depending on ``self.predict_scene``. """ - # current training state - state = self.tracker.np_state(self.tmbatch, self.vmbatch) + return self.plot_scenes if self.predict_scene else self.plot_samples - # optional: training state of the model checkpoint - if self.checkpoint_state: - # prepend values from checkpoint to current training state - state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in - zip(self.checkpoint_state.items(), state.items()) - if k1 == k2} + @property + def is_scene_subset(self): + """Check the type of the target dataset. - return state + Whether ``self.trg_ds`` is an instance of + :py:class:`pysegcnn.core.split.SceneSubset`, as required when + ``self.predict_scene=True``. - def save_state(self): - """Save the model state.""" - if self.save: - _ = self.model.save(self.state_file, - self.optimizer, - bands=self.bands, - nclasses=self.model.nclasses, - src_train_dl=self.src_train_dl, - src_valid_dl=self.src_valid_dl, - src_test_dl=self.src_test_dl, - trg_train_dl=self.trg_train_dl, - trg_valid_dl=self.trg_valid_dl, - trg_test_dl=self.trg_test_dl, - state=self.training_state, - uda_lambda=self.uda_lambda - ) + Returns + ------- + is_scene_subset : `bool` + Whether ``self.trg_ds`` is an instance of + :py:class:`pysegcnn.core.split.SceneSubset`. - @staticmethod - def init_network_trainer(src_ds_config, src_split_config, trg_ds_config, - trg_split_config, model_config): - """Prepare network training. + """ + return isinstance(self.trg_ds, SceneSubset) - Parameters - ---------- - src_ds_config : :py:class:`pysegcnn.core.trainer.DatasetConfig` - The source domain dataset configuration. 
- src_split_config : :py:class:`pysegcnn.core.trainer.SplitConfig` - The source domain dataset split configuration. - trg_ds_config : :py:class:`pysegcnn.core.trainer.DatasetConfig` - The target domain dataset configuration.. - trg_split_config : :py:class:`pysegcnn.core.trainer.SplitConfig` - The target domain dataset split configuration. - model_config : :py:class:`pysegcnn.core.trainer.ModelConfig` - The model configuration. + @property + def dataloader(self): + """Dataloader instance for model inference. Returns ------- - trainer : :py:class:`pysegcnn.core.trainer.NetworkTrainer` - A network trainer instance. - - See :py:mod:`pysegcnn.main.train.py` for an example on how to - instanciate a :py:class:`pysegcnn.core.trainer.NetworkTrainer` - instance. - - """ - # (i) instanciate the source domain configurations - src_dc = DatasetConfig(**src_ds_config) # source domain dataset - src_sc = SplitConfig(**src_split_config) # source domain dataset split - - # (ii) instanciate the target domain configuration - trg_dc = DatasetConfig(**trg_ds_config) # target domain dataset - trg_sc = SplitConfig(**trg_split_config) # target domain dataset split - - # (iii) instanciate the model configuration - mdlcfg = ModelConfig(**model_config) - - # (iv) instanciate the model state file - sttcfg = StateConfig(src_dc, src_sc, trg_dc, trg_sc, mdlcfg) - state_file = sttcfg.init_state() - - # (v) instanciate logging configuration - logcfg = LogConfig(state_file) - dictConfig(log_conf(logcfg.log_file)) - - # (vi) instanciate the source domain dataset - src_ds = src_dc.init_dataset() - - # the spectral bands used to train the model - bands = src_ds.use_bands - - # (vii) instanciate the training, validation and test datasets and - # dataloaders for the source domain - (src_train_ds, - src_valid_ds, - src_test_ds) = src_sc.train_val_test_split(src_ds) - (src_train_dl, - src_valid_dl, - src_test_dl) = src_sc.dataloaders(src_train_ds, - src_valid_ds, - src_test_ds, - batch_size=mdlcfg.batch_size, - shuffle=True, drop_last=False) - - # (viii) instanciate the loss function - cla_loss_function = mdlcfg.init_cla_loss_function() - - # (ix) check whether to apply transfer learning - if mdlcfg.transfer: - - # (a) instanciate the target domain dataset - trg_ds = trg_dc.init_dataset() - - # (b) instanciate the training, validation and test datasets and - # dataloaders for the target domain - (trg_train_ds, - trg_valid_ds, - trg_test_ds) = trg_sc.train_val_test_split(trg_ds) - (trg_train_dl, - trg_valid_dl, - trg_test_dl) = trg_sc.dataloaders(trg_train_ds, - trg_valid_ds, - trg_test_ds, - batch_size=mdlcfg.batch_size, - shuffle=True, drop_last=False) - - # (c) instanciate the model: supervised transfer learning - if mdlcfg.supervised: - model, optimizer, checkpoint_state = mdlcfg.init_model( - trg_ds, state_file) - - # (x) instanciate the network trainer - trainer = NetworkTrainer(model, - optimizer, - state_file, - bands, - trg_train_dl, - trg_valid_dl, - trg_test_dl, - cla_loss_function, - epochs=mdlcfg.epochs, - nthreads=mdlcfg.nthreads, - early_stop=mdlcfg.early_stop, - mode=mdlcfg.mode, - delta=mdlcfg.delta, - patience=mdlcfg.patience, - checkpoint_state=checkpoint_state, - save=mdlcfg.save) - - # (c) instanciate the model: unsupervised transfer learning - else: - model, optimizer, checkpoint_state = mdlcfg.init_model( - src_ds, state_file) - - # (x) instanciate the domain adaptation loss - uda_loss_function = mdlcfg.init_uda_loss_function( - mdlcfg.uda_lambda) - - # (xi) instanciate the network trainer - trainer = 
NetworkTrainer(model, - optimizer, - state_file, - bands, - src_train_dl, - src_valid_dl, - src_test_dl, - cla_loss_function, - trg_train_dl, - trg_valid_dl, - trg_test_dl, - uda_loss_function, - mdlcfg.uda_lambda, - mdlcfg.uda_pos, - mdlcfg.epochs, - mdlcfg.nthreads, - mdlcfg.early_stop, - mdlcfg.mode, - mdlcfg.delta, - mdlcfg.patience, - checkpoint_state, - mdlcfg.save) + dataloader : :py:class:`torch.utils.data.DataLoader` + The dataset for model inference. - else: - # (x) instanciate the model - model, optimizer, checkpoint_state = mdlcfg.init_model( - src_ds, state_file) - - # (xi) instanciate the network trainer - trainer = NetworkTrainer(model, - optimizer, - state_file, - bands, - src_train_dl, - src_valid_dl, - src_test_dl, - cla_loss_function, - epochs=mdlcfg.epochs, - nthreads=mdlcfg.nthreads, - early_stop=mdlcfg.early_stop, - mode=mdlcfg.mode, - delta=mdlcfg.delta, - patience=mdlcfg.patience, - checkpoint_state=checkpoint_state, - save=mdlcfg.save) - - return trainer - - # def _build_ds_repr(self, train_dl, valid_dl, test_dl): - # """Build the dataset representation. + """ + # build the dataloader for model inference + return DataLoader(self.trg_ds, batch_size=self._batch_size, + shuffle=False, drop_last=False) - # Returns - # ------- - # fs : `str` - # Representation string. + @property + def _original_source_labels(self): + """Original source domain labels. - # """ - # # dataset configuration - # fs = ' (dataset):\n ' - # fs += ''.join(repr(train_dl.dataset.dataset)).replace('\n', - # '\n' + 8 * ' ') - # fs += '\n (batch):\n ' - # fs += '- batch size: {}\n '.format(train_dl.batch_size) - # fs += '- mini-batch shape (b, c, h, w): {}'.format( - # ((train_dl.batch_size, len(train_dl.dataset.dataset.use_bands),) + - # 2 * (train_dl.dataset.dataset.tile_size,))) - - # # dataset split - # fs += '\n (split):' - # for dl in [train_dl, valid_dl, test_dl]: - # if dl.dataset is not None: - # fs += '\n' + 8 * ' ' + repr(dl.dataset) - - # return fs - - # def _build_model_repr_(self): - # """Build the model representation. + Since PyTorch requires class labels to be an ascending sequence + starting from 0, the actual class labels in the ground truth may differ + from the class labels fed to the model. - # Returns - # ------- - # fs : `str` - # Representation string. + Returns + ------- + original_source_labels : `dict` [`int`, `dict`] + The original class labels of the source domain. - # """ - # # model - # fs = '\n (model):' + '\n' + 8 * ' ' - # fs += ''.join(repr(self.model)).replace('\n', '\n' + 8 * ' ') + """ + return self.src_ds._labels + + @property + def _original_target_labels(self): + """Original target domain labels. + + Returns + ------- + original_target_labels : `dict` [`int`, `dict`] + The original class labels of the target domain. - # # optimizer - # fs += '\n (optimizer):' + '\n' + 8 * ' ' - # fs += ''.join(repr(self.optimizer)).replace('\n', '\n' + 8 * ' ') + """ + return self.trg_ds.dataset._labels - # # early stopping - # fs += '\n (early stop):' + '\n' + 8 * ' ' - # fs += ''.join(repr(self.es)).replace('\n', '\n' + 8 * ' ') + @property + def _label_log(self): + """Log if a label mapping is applied. - # # domain adaptation - # if self.uda: - # fs += '\n (adaptation)' + '\n' + 8 * ' ' - # fs += repr(self.uda_loss_function).replace('\n', '\n' + 8 * ' ') + Returns + ------- + log : `str` + Represenation of the label mapping. 
- # return fs + """ + log = 'Retaining model labels ({})'.format( + ', '.join([v['label'] for _, v in self.source_labels.items()])) + if self.apply_label_map: + log = (log.replace('Retaining', 'Mapping') + ' to target labels ' + '({}).'.format(', '.join([v['label'] for _, v in + self.target_labels.items()]))) + return log - # def __repr__(self): - # """Representation. + @property + def _batch_size(self): + """Batch size of the inference dataloader. - # Returns - # ------- - # fs : `str` - # Representation string. + Returns + ------- + batch_size : `int` + The batch size of the dataloader used for model inference. Depends + on whether to predict each sample of the target dataset + individually or whether to reconstruct each scene in the target + dataset. - # """ - # # representation string to print - # fs = self.__class__.__name__ + '(\n' + """ + return (self.trg_ds.dataset.tiles if self.predict_scene and + self.is_scene_subset else 1) - # # source domain - # fs += ' (source domain)\n ' - # fs += self._build_ds_repr( - # self.src_train_dl, self.src_valid_dl, self.src_test_dl).replace( - # '\n', '\n' + 4 * ' ') + def _check_long_filename(self, filename): + """Modify filenames that exceed Windows' maximum filename length. - # # target domain - # if self.uda: - # fs += '\n (target domain)\n ' - # fs += self._build_ds_repr( - # self.trg_train_dl, - # self.trg_valid_dl, - # self.trg_test_dl).replace('\n', '\n' + 4 * ' ') + Parameters + ---------- + filename : `str` + The filename to check. - # # model configuration - # fs += self._build_model_repr_() + Returns + ------- + filename : `str` + The modified filename, in case ``filename`` exceeds 255 characters. - # fs += '\n)' - # return fs + """ + # check for maximum path component length: Windows allows a maximum + # of 255 characters for each component in a path + if len(filename) >= 255: + # try to parse the scene identifier + scene_metadata = self.trg_ds.dataset.parse_scene_id(filename) + if scene_metadata is not None: + # new, shorter name for the current batch + batch_name = '_'.join([scene_metadata['satellite'], + scene_metadata['tile'], + datetime.datetime.strftime( + scene_metadata['date'], '%Y%m%d')]) + filename = filename.replace(scene_metadata['id'], batch_name) + return filename -class EarlyStopping(object): - """`Early Stopping`_ algorithm. + def map_to_target(self, prd): + """Map source domain labels to target domain labels. - This implementation of the early stopping algorithm advances a counter each - time a metric did not improve over a training epoch. If the metric does not - improve over more than ``patience`` epochs, the early stopping criterion is - met. + Parameters + ---------- + prd : :py:class:`torch.Tensor` + The source domain class labels as predicted by ```self.model``. - See the :py:meth:`pysegcnn.core.trainer.NetworkTrainer.train` method for an - example implementation. + Returns + ------- + prd : :py:class:`torch.Tensor` + The predicted target domain labels. - .. _Early Stopping: - https://en.wikipedia.org/wiki/Early_stopping + """ + # map actual source labels to original source labels + for aid, oid in zip(self.source_labels.keys(), + self._original_source_labels.keys()): + prd[torch.where(prd == aid)] = oid - Attributes - ---------- - mode : `str` - The early stopping mode. - best : `float` - Best metric score. - min_delta : `float` - Minimum change in early stopping metric to be considered as an - improvement. - patience : `int` - The number of epochs to wait for an improvement. 
- is_better : `function` - Function indicating whether the metric improved. - early_stop : `bool` - Whether the early stopping criterion is met. - counter : `int` - The counter advancing each time a metric does not improve. + # apply the label mapping + for src_label, trg_label in self.label_map.items(): + prd[torch.where(prd == src_label)] = trg_label - """ + # map original target labels to actual target labels + for oid, aid in zip(self._original_target_labels.keys(), + self.target_labels.keys()): + prd[torch.where(prd == oid)] = aid - def __init__(self, mode='max', best=0, min_delta=0, patience=10): - """Initialize. + return prd - Parameters - ---------- - mode : `str`, optional - The early stopping mode. Depends on the metric measuring - performance. When using model loss as metric, use ``mode='min'``, - however, when using accuracy as metric, use ``mode='max'``. For - now, only ``mode='max'`` is supported. Only used if - ``early_stop=True``. The default is `'max'`. - best : `float`, optional - Threshold indicating the best metric score. At instanciation, set - ``best`` to the worst possible score of the metric. ``best`` will - be overwritten during training. The default is `0`. - min_delta : `float`, optional - Minimum change in early stopping metric to be considered as an - improvement. Only used if ``early_stop=True``. The default is `0`. - patience : `int`, optional - The number of epochs to wait for an improvement in the early - stopping metric. If the model does not improve over more than - ``patience`` epochs, quit training. Only used if - ``early_stop=True``. The default is `10`. + def predict(self): + """Classify the samples of the target dataset. - Raises - ------ - ValueError - Raised if ``mode`` is not either 'min' or 'max'. + Returns + ------- + output : `dict` [`str`, `dict`] + The inference output dictionary. The keys are either the number of + the samples (``self.predict_scene=False``) or the name of the + scenes of the target dataset (``self.predict_scene=True``). The + values are dictionaries with keys: + ``'input'`` + Model input data of the sample (:py:class:`numpy.ndarray`). + ``'labels' + Ground truth class labels (:py:class:`numpy.ndarray`). + ``'prediction'`` + Model prediction class labels (:py:class:`numpy.ndarray`). """ - # check if mode is correctly specified - if mode not in ['min', 'max']: - raise ValueError('Mode "{}" not supported. ' - 'Mode is either "min" (check whether the metric ' - 'decreased, e.g. loss) or "max" (check whether ' - 'the metric increased, e.g. accuracy).' 
- .format(mode)) + # iterate over the samples of the target dataset + output = {} + for batch, (inputs, labels) in enumerate(self.dataloader): - # mode to determine if metric improved - self.mode = mode + # send inputs and labels to device + inputs = inputs.to(self.device) + labels = labels.to(self.device) - # whether to check for an increase or a decrease in a given metric - self.is_better = self.decreased if mode == 'min' else self.increased + # compute model predictions + with torch.no_grad(): + prdctn = F.softmax(self.model(inputs), + dim=1).argmax(dim=1).squeeze() - # minimum change in metric to be considered as an improvement - self.min_delta = min_delta + # map source labels to target dataset labels + if self.apply_label_map: + prdctn = self.map_to_target(prdctn) - # number of epochs to wait for improvement - self.patience = patience + # update confusion matrix + if self.compute_cm: + for ytrue, ypred in zip(labels.view(-1), prdctn.view(-1)): + self.conf_mat[ytrue.long(), ypred.long()] += 1 - # initialize best metric - self.best = best + # convert torch tensors to numpy arrays + inputs = inputs.numpy() + labels = labels.numpy() + prdctn = prdctn.numpy() - # initialize early stopping flag - self.early_stop = False + # progress string to log + progress = 'Sample: {:d}/{:d}'.format(batch + 1, + len(self.dataloader)) - # initialize the early stop counter - self.counter = 0 + # check whether to reconstruct the scene + date = None + if self.dataloader.batch_size > 1: - def stop(self, metric): - """Advance early stopping counter. + # id and date of the current scene + batch = self.trg_ds.ids[batch] + if self.trg_ds.split_mode == 'date': + date = self.trg_ds.dataset.parse_scene_id(batch)['date'] - Parameters - ---------- - metric : `float` - The current metric score. + # modify the progress string + progress = progress.replace('Sample', 'Scene') + progress += ' Id: {}'.format(batch) - Returns - ------- - early_stop : `bool` - Whether the early stopping criterion is met. + # reconstruct the entire scene + inputs = reconstruct_scene(inputs) + labels = reconstruct_scene(labels) + prdctn = reconstruct_scene(prdctn) - """ - # if the metric improved, reset the epochs counter, else, advance - if self.is_better(metric, self.best, self.min_delta): - self.counter = 0 - self.best = metric - else: - self.counter += 1 - LOGGER.info('Early stopping counter: {}/{}'.format( - self.counter, self.patience)) + # save current batch to output dictionary + output[batch] = {'input': inputs, 'labels': labels, + 'prediction': prdctn} - # if the metric did not improve over the last patience epochs, - # the early stopping criterion is met - if self.counter >= self.patience: - LOGGER.info('Early stopping criterion met, stopping training.') - self.early_stop = True + # filename for the plot of the current batch + batch_name = self.basename + '_{}_{}.pt'.format(self.trg_ds.name, + batch) - return self.early_stop + # check if the current batch name exceeds the Windows limit of + # 255 characters + batch_name = self._check_long_filename(batch_name) - def decreased(self, metric, best, min_delta): - """Whether a metric decreased with respect to a best score. 
+            # in case the source and target class labels are the same or a
+            # label mapping is applied, the accuracy of the prediction can be
+            # calculated
+            if self.source_is_target or self.apply_label_map:
+                progress += ', Accuracy: {:.2f}'.format(
+                    accuracy_function(prdctn, labels))
+            LOGGER.info(progress)
-        Measure improvement for metrics that are considered as 'better' when
-        they decrease, e.g. model loss, mean squared error, etc.
+            # plot current scene
+            if self.plot:
-        Parameters
-        ----------
-        metric : `float`
-            The current score.
-        best : `float`
-            The current best score.
-        min_delta : `float`
-            Minimum change to be considered as an improvement.
+                # in case the source and target class labels are the same or a
+                # label mapping is applied, plot the ground truth, otherwise
+                # just plot the model prediction
+                gt = None
+                if self.source_is_target or self.apply_label_map:
+                    gt = labels
-        Returns
-        -------
-        `bool`
-            Whether the metric improved.
+                # plot inputs, ground truth and model predictions
+                plot_sample(inputs.clip(0, 1),
+                            self.bands,
+                            self.use_labels,
+                            y=gt,
+                            y_pred={self.model.__class__.__name__: prdctn},
+                            accuracy=True,
+                            state=batch_name,
+                            date=date,
+                            fig=self.fig,
+                            plot_path=self.scenes_path,
+                            **self.kwargs)
-        """
-        return metric < best - min_delta
+                # save current figure state as frame for animation
+                if self.animate:
+                    self.anim.frame(self.fig.axes)
-    def increased(self, metric, best, min_delta):
-        """Whether a metric increased with respect to a best score.
+        # save animation
+        if self.animate:
+            self.anim.animate(self.fig, interval=1000, repeat=True, blit=True)
+            self.anim.save(self.basename + '_{}.gif'.format(self.trg_ds.name),
+                           dpi=200)
-        Measure improvement for metrics that are considered as 'better' when
-        they increase, e.g. accuracy, precision, recall, etc.
+        return output
-        Parameters
-        ----------
-        metric : `float`
-            The current score.
-        best : `float`
-            The current best score.
-        min_delta : `float`
-            Minimum change to be considered as an improvement.
+    def evaluate(self):
+        """Evaluate a pretrained model on a defined dataset.

         Returns
         -------
-        `bool`
-            Whether the metric improved.
+        output : `dict` [`str`, `dict`]
+            The inference output dictionary. The keys are either the number of
+            the samples (``self.predict_scene=False``) or the name of the
+            scenes of the target dataset (``self.predict_scene=True``). The
+            values are dictionaries with keys:
+            ``'input'``
+                Model input data of the sample (:py:class:`numpy.ndarray`).
+            ``'labels'``
+                Ground truth class labels (:py:class:`numpy.ndarray`).
+            ``'prediction'``
+                Model prediction class labels (:py:class:`numpy.ndarray`).

         """
-        return metric > best + min_delta
+        # plot loss and accuracy
+        plot_loss(check_filename_length(self.state_file),
+                  outpath=self.perfmc_path)
-    def __repr__(self):
-        """Representation.
+        # set the model to evaluation mode
+        LOGGER.info('Setting model to evaluation mode ...')
+        self.model.eval()
+        self.model.to(self.device)
-        Returns
-        -------
-        fs : `str`
-            Representation string.
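
Note that ``accuracy_function`` is imported from ``pysegcnn.core.utils`` and
its body is not part of this diff; a plain pixel-wise accuracy consistent
with how it is called above could look like the following sketch (an
assumption, not the actual implementation):

    # hedged sketch: a pixel-wise overall accuracy; the real
    # accuracy_function in pysegcnn.core.utils may differ
    import numpy as np

    def pixel_accuracy(y_pred, y_true):
        """Fraction of pixels where prediction and ground truth agree."""
        y_pred, y_true = np.asarray(y_pred), np.asarray(y_true)
        return (y_pred == y_true).mean()

    print(pixel_accuracy([0, 1, 1], [0, 1, 0]))  # 0.666...
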
+        # initialize confusion matrix
+        self.conf_mat = np.zeros(shape=2 * (len(self.use_labels), ))
-        """
-        fs = self.__class__.__name__
-        fs += '(mode={}, best={:.2f}, delta={}, patience={})'.format(
-            self.mode, self.best, self.min_delta, self.patience)
+        # log which labels the model predicts
+        LOGGER.info(self._label_log)
-        return fs
+        # evaluate the model on the target dataset
+        output = self.predict()
+
+        # plot the confusion matrix, if it was computed
+        if self.compute_cm:
+            plot_confusion_matrix(self.conf_mat, self.use_labels,
+                                  state_file=self.state_file,
+                                  subset=self.domain + '_' + self.trg_ds.name,
+                                  outpath=self.perfmc_path)
+
+        return output
+
+# def init_network_trainer(src_ds_config, src_split_config, trg_ds_config,
+#                          trg_split_config, model_config):
+#     """Prepare network training.
+
+#     Parameters
+#     ----------
+#     src_ds_config : :py:class:`pysegcnn.core.trainer.DatasetConfig`
+#         The source domain dataset configuration.
+#     src_split_config : :py:class:`pysegcnn.core.trainer.SplitConfig`
+#         The source domain dataset split configuration.
+#     trg_ds_config : :py:class:`pysegcnn.core.trainer.DatasetConfig`
+#         The target domain dataset configuration.
+#     trg_split_config : :py:class:`pysegcnn.core.trainer.SplitConfig`
+#         The target domain dataset split configuration.
+#     model_config : :py:class:`pysegcnn.core.trainer.ModelConfig`
+#         The model configuration.
+
+#     Returns
+#     -------
+#     trainer : :py:class:`pysegcnn.core.trainer.NetworkTrainer`
+#         A network trainer instance.
+
+#     See :py:mod:`pysegcnn.main.train.py` for an example on how to
+#     instanciate a :py:class:`pysegcnn.core.trainer.NetworkTrainer`
+#     instance.
+
+#     """
+#     # (i) instanciate the source domain configurations
+#     src_dc = DatasetConfig(**src_ds_config)  # source domain dataset
+#     src_sc = SplitConfig(**src_split_config)  # source domain dataset split
+
+#     # (ii) instanciate the target domain configuration
+#     trg_dc = DatasetConfig(**trg_ds_config)  # target domain dataset
+#     trg_sc = SplitConfig(**trg_split_config)  # target domain dataset split
+
+#     # (iii) instanciate the model configuration
+#     mdlcfg = ModelConfig(**model_config)
+
+#     # (iv) instanciate the model state file
+#     sttcfg = StateConfig(src_dc, src_sc, trg_dc, trg_sc, mdlcfg)
+#     state_file = sttcfg.init_state()
+
+#     # (v) instanciate logging configuration
+#     logcfg = LogConfig(state_file)
+#     dictConfig(log_conf(logcfg.log_file))
+
+#     # (vi) instanciate the source domain dataset
+#     src_ds = src_dc.init_dataset()
+
+#     # the spectral bands used to train the model
+#     bands = src_ds.use_bands
+
+#     # (vii) instanciate the training, validation and test datasets and
+#     # dataloaders for the source domain
+#     (src_train_ds,
+#      src_valid_ds,
+#      src_test_ds) = src_sc.train_val_test_split(src_ds)
+#     (src_train_dl,
+#      src_valid_dl,
+#      src_test_dl) = src_sc.dataloaders(src_train_ds,
+#                                        src_valid_ds,
+#                                        src_test_ds,
+#                                        batch_size=mdlcfg.batch_size,
+#                                        shuffle=True, drop_last=False)
+
+#     # (viii) instanciate the loss function
+#     cla_loss_function = mdlcfg.init_cla_loss_function()
+
+#     # (ix) check whether to apply transfer learning
+#     if mdlcfg.transfer:
+
+#         # (a) instanciate the target domain dataset
+#         trg_ds = trg_dc.init_dataset()
+
+#         # (b) instanciate the training, validation and test datasets and
+#         # dataloaders for the target domain
+#         (trg_train_ds,
+#          trg_valid_ds,
+#          trg_test_ds) = trg_sc.train_val_test_split(trg_ds)
+#         (trg_train_dl,
+#          trg_valid_dl,
+#          trg_test_dl) = trg_sc.dataloaders(trg_train_ds,
+#                                            trg_valid_ds,
+#                                            trg_test_ds,
+#                                            batch_size=mdlcfg.batch_size,
+#                                            shuffle=True, drop_last=False)
+
+#         # (c) instanciate the model: supervised transfer learning
+#         if mdlcfg.supervised:
+#             model, optimizer, checkpoint_state = mdlcfg.init_model(
+#                 trg_ds, state_file)
+
+#             # (x) instanciate the network trainer
+#             trainer = NetworkTrainer(model,
+#                                      optimizer,
+#                                      state_file,
+#                                      bands,
+#                                      trg_train_dl,
+#                                      trg_valid_dl,
+#                                      trg_test_dl,
+#                                      cla_loss_function,
+#                                      epochs=mdlcfg.epochs,
+#                                      nthreads=mdlcfg.nthreads,
+#                                      early_stop=mdlcfg.early_stop,
+#                                      mode=mdlcfg.mode,
+#                                      delta=mdlcfg.delta,
+#                                      patience=mdlcfg.patience,
+#                                      checkpoint_state=checkpoint_state,
+#                                      save=mdlcfg.save)
+
+#         # (c) instanciate the model: unsupervised transfer learning
+#         else:
+#             model, optimizer, checkpoint_state = mdlcfg.init_model(
+#                 src_ds, state_file)
+
+#             # (x) instanciate the domain adaptation loss
+#             uda_loss_function = mdlcfg.init_uda_loss_function(
+#                 mdlcfg.uda_lambda)
+
+#             # (xi) instanciate the network trainer
+#             trainer = NetworkTrainer(model,
+#                                      optimizer,
+#                                      state_file,
+#                                      bands,
+#                                      src_train_dl,
+#                                      src_valid_dl,
+#                                      src_test_dl,
+#                                      cla_loss_function,
+#                                      trg_train_dl,
+#                                      trg_valid_dl,
+#                                      trg_test_dl,
+#                                      uda_loss_function,
+#                                      mdlcfg.uda_lambda,
+#                                      mdlcfg.uda_pos,
+#                                      mdlcfg.epochs,
+#                                      mdlcfg.nthreads,
+#                                      mdlcfg.early_stop,
+#                                      mdlcfg.mode,
+#                                      mdlcfg.delta,
+#                                      mdlcfg.patience,
+#                                      checkpoint_state,
+#                                      mdlcfg.save)
+
+#     else:
+#         # (x) instanciate the model
+#         model, optimizer, checkpoint_state = mdlcfg.init_model(
+#             src_ds, state_file)
+
+#         # (xi) instanciate the network trainer
+#         trainer = NetworkTrainer(model,
+#                                  optimizer,
+#                                  state_file,
+#                                  bands,
+#                                  src_train_dl,
+#                                  src_valid_dl,
+#                                  src_test_dl,
+#                                  cla_loss_function,
+#                                  epochs=mdlcfg.epochs,
+#                                  nthreads=mdlcfg.nthreads,
+#                                  early_stop=mdlcfg.early_stop,
+#                                  mode=mdlcfg.mode,
+#                                  delta=mdlcfg.delta,
+#                                  patience=mdlcfg.patience,
+#                                  checkpoint_state=checkpoint_state,
+#                                  save=mdlcfg.save)
+
+#     return trainer
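
For reference, a usage sketch of the inference output documented in
``predict``/``evaluate`` above; the hand-built dictionary mimics one
reconstructed scene rather than a real model run:

    # hedged sketch: consume the documented output dictionary; the shapes
    # and the scene id are hypothetical placeholders
    import numpy as np

    output = {
        'scene_01': {
            'input': np.zeros((3, 8, 8)),               # model input
            'labels': np.zeros((8, 8), dtype=int),      # ground truth
            'prediction': np.zeros((8, 8), dtype=int),  # model prediction
        },
    }

    for name, sample in output.items():
        acc = (sample['prediction'] == sample['labels']).mean()
        print('{}: accuracy {:.2f}'.format(name, acc))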