diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index f7ac066361124da8852f053216ef504c79ef3a30..77b0b25daeeda80e76d706825bc9221b84b8730a 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -35,23 +35,25 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from torch.optim import Optimizer -import matplotlib.pyplot as plt # locals from pysegcnn.core.dataset import SupportedDatasets, ImageDataset from pysegcnn.core.transforms import Augment from pysegcnn.core.utils import (img2np, item_in_enum, accuracy_function, - reconstruct_scene, check_filename_length) -from pysegcnn.core.split import SupportedSplits, CustomSubset, SceneSubset + reconstruct_scene, check_filename_length, + array_replace) +from pysegcnn.core.split import SupportedSplits from pysegcnn.core.models import (SupportedModels, SupportedOptimizers, Network) from pysegcnn.core.uda import SupportedUdaMethods, CoralLoss, UDA_POSITIONS from pysegcnn.core.layers import Conv2dSame from pysegcnn.core.logging import log_conf from pysegcnn.core.graphics import (plot_loss, plot_confusion_matrix, - plot_sample, Animate) + plot_sample) +from pysegcnn.core.constants import map_labels from pysegcnn.main.config import HERE, DRIVE_PATH + # module level logger LOGGER = logging.getLogger(__name__) @@ -2105,42 +2107,44 @@ class EarlyStopping(object): class NetworkInference(BaseConfig): """Model inference configuration. - Evaluate a model. + Evaluate model(s) trained by an instance of + :py:class:`pysegcnn.core.trainer.DomainAdaptationTrainer` on an instance of + a :py:class:`pysegcnn.core.dataset.ImageDataset` dataset. Attributes ---------- - state_file : :py:class:`pathlib.Path` - Path to the model to evaluate. + state_files : `list´ [:py:class:`pathlib.Path`] + Path to the model(s) to evaluate. implicit : `bool` Whether to evaluate the model on the datasets defined at training time. + The default is `True`. domain : `str` Whether to evaluate on the source domain (``domain='src'``), i.e. the - domain the model in ``state_file`` was trained on, or a target domain - (``domain='trg'``). + domain a model was trained on, or a target domain (``domain='trg'``). + The default is `'src'`. test : `bool` or `None` - Whether to evaluate the model on the training(``test=None``), the - validation (``test=False``) or the test set (``test=True``). - map_labels : `bool` - Whether to map the model labels from the model source domain to the - defined ``domain`` in case the domain class labels differ. + Whether to evaluate the model on the training (``test=None``), the + validation (``test=False``) or the test set (``test=True``). The + default is `False`. ds : `dict` The dataset configuration dictionary passed to :py:class:`pysegcnn.core.trainer.DatasetConfig` when evaluating on - an explicitly defined dataset, i.e. ``implicit=False``. + an explicitly defined dataset, i.e. ``implicit=False``. The default is + `{}`. ds_split : `dict` The dataset split configuration dictionary passed to :py:class:`pysegcnn.core.trainer.SplitConfig` when evaluating on - an explicitly defined dataset, i.e. ``implicit=False``. + an explicitly defined dataset, i.e. ``implicit=False``. The default is + `{}`. + map_labels : `bool` + Whether to map the model labels from the model source domain to the + defined ``domain`` in case the domain class labels differ. The default + is `False`. predict_scene : `bool` - The model prediction order. If False, the samples (tiles) of a dataset - are predicted in any order and the scenes are not reconstructed. - If True, the samples (tiles) are ordered according to the scene they - belong to and a model prediction for each entire reconstructed scene is - returned. The default is `False`. - plot_samples : `bool` - Whether to save a plot of false color composite, ground truth and model - prediction for each sample (tile). Only used if ``predict_scene=False`` - . The default is `False`. + The model prediction order. If ``predict_scene=False``, the samples of + a dataset are predicted in any order.If ``predict_scene=True``, the + samples are ordered according to their scene and a model prediction for + each entire reconstructed scene is returned. The default is `True`. plot_scenes : `bool` Whether to save a plot of false color composite, ground truth and model prediction for each entire scene. Only used if ``predict_scene=True``. @@ -2149,8 +2153,7 @@ class NetworkInference(BaseConfig): The bands to build the false color composite. The default is `['nir', 'red', 'green']`. cm : `bool` - Whether to compute and plot the confusion matrix. The default is `True` - . + Whether to compute the confusion matrix. The default is `True`. figsize : `tuple` The figure size in centimeters. The default is `(10, 10)`. alpha : `int` @@ -2173,8 +2176,8 @@ class NetworkInference(BaseConfig): animtn_path : :py:class:`pathlib.Path` Path to store animations. models_path : :py:class:`pathlib.Path` - Path to search for model state files, i.e. pretrained models. - kwargs : `dict` + Path to search for model state files ``state_files``. + plot_kwargs : `dict` Keyword arguments for :py:func:`pysegcnn.core.graphics.plot_sample` basename : `str` Base filename for each plot. @@ -2197,15 +2200,14 @@ class NetworkInference(BaseConfig): """ - state_file: pathlib.Path - implicit: bool - domain: str - test: object - map_labels: bool + state_files: list + implicit: bool = True + domain: str = 'src' + test: object = False ds: dict = dataclasses.field(default_factory={}) ds_split: dict = dataclasses.field(default_factory={}) - predict_scene: bool = False - plot_samples: bool = False + map_labels: bool = False + predict_scene: bool = True plot_scenes: bool = False plot_bands: list = dataclasses.field( default_factory=lambda: ['nir', 'red', 'green']) @@ -2240,8 +2242,8 @@ class NetworkInference(BaseConfig): .format(self.domain)) # the device to compute on, use gpu if available - self.device = torch.device("cuda:0" if torch.cuda.is_available() else - "cpu") + self.device = torch.device( + 'cuda:' if torch.cuda.is_available() else 'cpu') # the output paths for the different graphics self.base_path = pathlib.Path(HERE) @@ -2252,87 +2254,13 @@ class NetworkInference(BaseConfig): # input path for model state files self.models_path = self.base_path.joinpath('_models') - self.state_file = self.models_path.joinpath(self.state_file) - - # initialize logging - log = LogConfig(self.state_file) - dictConfig(log_conf(log.log_file)) - log.init_log('{}: ' + 'Evaluating model: {}.' - .format(self.state_file.name)) + self.state_files = [self.models_path.joinpath(s) for s in + self.state_files] # plotting keyword arguments - self.kwargs = {'bands': self.plot_bands, - 'alpha': self.alpha, - 'figsize': self.figsize} - - # base filename for each plot - self.basename = self.state_file.stem - - # load the model state - self.model, _, self.model_state = Network.load(self.state_file) - - # load the target dataset: dataset to evaluate the model on - self.trg_ds = self.load_dataset() - - # load the source dataset: dataset the model was trained on - self.src_ds = self.model_state['src_train_dl'].dataset.dataset - - # create a figure to use for plotting - self.fig, _ = plt.subplots(1, 3, figsize=self.kwargs['figsize']) - - # check if the animate parameter is correctly specified - if self.animate: - if not self.plot: - LOGGER.warning('animate requires plot_scenes=True or ' - 'plot_samples=True. Can not create animation.') - else: - # check whether the output path is valid - if not self.anim_path.exists(): - # create output path - self.anim_path.mkdir(parents=True, exist_ok=True) - self.anim = Animate(self.anim_path) - - @staticmethod - def get_scene_tiles(ds, scene_id): - """Return the tiles of the scene with id ``scene_id``. - - Parameters - ---------- - ds : :py:class:`pysegcnn.core.split.CustomSubset` - A instance of a subclass of - :py:class:`pysegcnn.core.split.CustomSubset`. - scene_id : `str` - A valid scene identifier. - - Raises - ------ - ValueError - Raised if ``scene_id`` is not a valid scene identifier for the - dataset ``ds``. - - Returns - ------- - indices : `list` [`int`] - List of indices of the tiles of the scene with id ``scene_id`` in - ``ds``. - date : :py:class:`datetime.datetime` - The date of the scene with id ``scene_id``. - - """ - # check if the scene id is valid - scene_meta = ds.dataset.parse_scene_id(scene_id) - if scene_meta is None: - raise ValueError('{} is not a valid scene identifier' - .format(scene_id)) - - # iterate over the scenes of the dataset - indices = [] - for i, scene in enumerate(ds.scenes): - # if the scene id matches a given id, save the index of the scene - if scene['id'] == scene_id: - indices.append(i) - - return indices, scene_meta['date'] + self.plot_kwargs = {'bands': self.plot_bands, + 'alpha': self.alpha, + 'figsize': self.figsize} @staticmethod def replace_dataset_path(ds, drive_path): @@ -2366,16 +2294,6 @@ class NetworkInference(BaseConfig): :py:class:`pysegcnn.core.dataset.ImageDataset`. """ - # check input type - if isinstance(ds, CustomSubset): - # check the type of the dataset - if not isinstance(ds.dataset, ImageDataset): - raise TypeError('ds should be a subset created from a {}.' - .format(repr(ImageDataset))) - else: - raise TypeError('ds should be an instance of {}.' - .format(repr(CustomSubset))) - # iterate over the scenes of the dataset for scene in ds.dataset.scenes: for k, v in scene.items(): @@ -2392,7 +2310,7 @@ class NetworkInference(BaseConfig): if dpath != drive_path: scene[k] = v.replace(str(dpath), drive_path) - def load_dataset(self): + def load_dataset(self, state, implicit=True, test=False, domain='src'): """Load the defined dataset. Raises @@ -2401,57 +2319,50 @@ class NetworkInference(BaseConfig): Raised if the requested dataset was not available at training time, if ``implicit=True``. - Raised if the dataset ``ds`` does not have the same spectral bands - as the model to evaluate, if ``implicit=False``. - Returns ------- ds : :py:class:`pysegcnn.core.split.CustomSubset` The dataset to evaluate the model on. """ + # load model state + model_state = Network.load(state) + # check whether to evaluate on the datasets defined at training time - if self.implicit: + if implicit: # check whether to evaluate the model on the training, validation # or test set - if self.test is None: + if test is None: ds_set = 'train' else: - ds_set = 'test' if self.test else 'valid' + ds_set = 'test' if test else 'valid' # the dataset to evaluate the model on - ds = self.model_state[ - self.domain + '_{}_dl'.format(ds_set)].dataset + ds = model_state[domain + '_{}_dl'.format(ds_set)].dataset if ds is None: raise ValueError('Requested dataset "{}" is not available.' - .format(self.domain + '_{}_dl'.format(ds_set)) + .format(domain + '_{}_dl'.format(ds_set)) ) # log dataset representation LOGGER.info('Evaluating on {} set of the {} domain defined at ' - 'training time.'.format(ds_set, self.domain)) + 'training time.'.format(ds_set, domain)) else: # explicitly defined dataset ds = DatasetConfig(**self.ds).init_dataset() - # check if the spectral bands match - if ds.use_bands != self.model_state['bands']: - raise ValueError('The model was trained with bands {}, not ' - 'with bands {}.'.format( - self.model_state['bands'], ds.use_bands)) - # split configuration sc = SplitConfig(**self.ds_split) train_ds, valid_ds, test_ds = sc.train_val_test_split(ds) # check whether to evaluate the model on the training, validation # or test set - if self.test is None: + if test is None: ds = train_ds else: - ds = test_ds if self.test else valid_ds + ds = test_ds if test else valid_ds # log dataset representation LOGGER.info('Evaluating on {} set of explicitly defined dataset: ' @@ -2476,33 +2387,33 @@ class NetworkInference(BaseConfig): @property def target_labels(self): - """Class labels of the dataset to evaluate. + """Class labels of the target domain the model is evaluated on. Returns ------- - target_labels : `dict` [`int`, `dict`] - The class labels of the dataset to evaluate. + target_labels : `dict` [`int`, `dict`] + The class labels of the target domain. """ - return self.trg_ds.dataset.labels + return self.trg_ds.labels - # @property - # def label_map(self): - # """Label mapping from the source to the target domain. + @property + def label_map(self): + """Label mapping dictionary from the source to the target domain. - # See :py:func:`pysegcnn.core.constants.map_labels`. + See :py:class:`pysegcnn.core.constants.LabelMapping`. - # Returns - # ------- - # label_map : `dict` [`int`, `int`] - # Dictionary with source labels as keys and corresponding target - # labels as values. + Returns + ------- + label_map : `dict` [`int`, `int`] + Dictionary with source labels as keys and corresponding target + labels as values. - # """ - # # check whether the source domain labels are the same as the target - # # domain labels - # return map_labels(self.src_ds.get_labels(), - # self.trg_ds.dataset.get_labels()) + """ + # check whether the source domain labels are the same as the target + # domain labels + return map_labels(self.src_ds.get_labels(), + self.trg_ds.dataset.get_labels()) @property def source_is_target(self): @@ -2528,7 +2439,7 @@ class NetworkInference(BaseConfig): requested, `False` otherwise. """ - return not self.source_is_target and self.map_labels + return self.map_labels and not self.source_is_target @property def use_labels(self): @@ -2555,21 +2466,6 @@ class NetworkInference(BaseConfig): """ return self.src_ds.use_bands - @property - def compute_cm(self): - """Whether to compute the confusion matrix. - - Returns - ------- - compute_cm : `bool` - Whether the confusion matrix can be computed. For datasets with - labels different from the source domain labels, the confusion - matrix can not be computed. - - """ - return (False if not self.source_is_target and not self.map_labels - else self.cm) - @property def plot(self): """Whether to save plots of (input, ground truth, prediction). @@ -2581,24 +2477,7 @@ class NetworkInference(BaseConfig): depending on ``self.predict_scene``. """ - return self.plot_scenes if self.predict_scene else self.plot_samples - - @property - def is_scene_subset(self): - """Check the type of the target dataset. - - Whether ``self.trg_ds`` is an instance of - :py:class:`pysegcnn.core.split.SceneSubset`, as required when - ``self.predict_scene=True``. - - Returns - ------- - is_scene_subset : `bool` - Whether ``self.trg_ds`` is an instance of - :py:class:`pysegcnn.core.split.SceneSubset`. - - """ - return isinstance(self.trg_ds, SceneSubset) + return self.plot_scenes if self.predict_scene else False @property def dataloader(self): @@ -2614,52 +2493,6 @@ class NetworkInference(BaseConfig): return DataLoader(self.trg_ds, batch_size=self._batch_size, shuffle=False, drop_last=False) - @property - def _original_source_labels(self): - """Original source domain labels. - - Since PyTorch requires class labels to be an ascending sequence - starting from 0, the actual class labels in the ground truth may differ - from the class labels fed to the model. - - Returns - ------- - original_source_labels : `dict` [`int`, `dict`] - The original class labels of the source domain. - - """ - return self.src_ds._labels - - @property - def _original_target_labels(self): - """Original target domain labels. - - Returns - ------- - original_target_labels : `dict` [`int`, `dict`] - The original class labels of the target domain. - - """ - return self.trg_ds.dataset._labels - - @property - def _label_log(self): - """Log if a label mapping is applied. - - Returns - ------- - log : `str` - Represenation of the label mapping. - - """ - log = 'Retaining model labels ({})'.format( - ', '.join([v['label'] for _, v in self.source_labels.items()])) - if self.apply_label_map: - log = (log.replace('Retaining', 'Mapping') + ' to target labels ' - '({}).'.format(', '.join([v['label'] for _, v in - self.target_labels.items()]))) - return log - @property def _batch_size(self): """Batch size of the inference dataloader. @@ -2673,37 +2506,7 @@ class NetworkInference(BaseConfig): dataset. """ - return (self.trg_ds.dataset.tiles if self.predict_scene and - self.is_scene_subset else 1) - - def _check_long_filename(self, filename): - """Modify filenames that exceed Windows' maximum filename length. - - Parameters - ---------- - filename : `str` - The filename to check. - - Returns - ------- - filename : `str` - The modified filename, in case ``filename`` exceeds 255 characters. - - """ - # check for maximum path component length: Windows allows a maximum - # of 255 characters for each component in a path - if len(filename) >= 255: - # try to parse the scene identifier - scene_metadata = self.trg_ds.dataset.parse_scene_id(filename) - if scene_metadata is not None: - # new, shorter name for the current batch - batch_name = '_'.join([scene_metadata['satellite'], - scene_metadata['tile'], - datetime.datetime.strftime( - scene_metadata['date'], '%Y%m%d')]) - filename = filename.replace(scene_metadata['id'], batch_name) - - return filename + return self.trg_ds.dataset.tiles if self.predict_scene else 1 def map_to_target(self, prd): """Map source domain labels to target domain labels. @@ -2711,7 +2514,7 @@ class NetworkInference(BaseConfig): Parameters ---------- prd : :py:class:`torch.Tensor` - The source domain class labels as predicted by ```self.model``. + The source domain class labels as predicted by the model. Returns ------- @@ -2720,22 +2523,20 @@ class NetworkInference(BaseConfig): """ # map actual source labels to original source labels - for aid, oid in zip(self.source_labels.keys(), - self._original_source_labels.keys()): - prd[torch.where(prd == aid)] = oid + # prd = array_replace(prd, np.array([self.src_ds.labels.keys(), + # self.src_ds._labels.keys()])) # apply the label mapping - for src_label, trg_label in self.label_map.items(): - prd[torch.where(prd == src_label)] = trg_label + # prd = array_replace(prd, self.label_map.to_numpy()) # map original target labels to actual target labels - for oid, aid in zip(self._original_target_labels.keys(), - self.target_labels.keys()): - prd[torch.where(prd == oid)] = aid + # for oid, aid in zip(self.trg_ds._labels.keys(), + # self.target_labels.keys()): + # prd[torch.where(prd == oid)] = aid return prd - def predict(self): + def predict(self, model): """Classify the samples of the target dataset. Returns @@ -2753,6 +2554,11 @@ class NetworkInference(BaseConfig): Model prediction class labels (:py:class:`numpy.ndarray`). """ + # set the model to evaluation mode + LOGGER.info('Setting model to evaluation mode ...') + model.eval() + model.to(self.device) + # iterate over the samples of the target dataset output = {} for batch, (inputs, labels) in enumerate(self.dataloader): @@ -2763,35 +2569,18 @@ class NetworkInference(BaseConfig): # compute model predictions with torch.no_grad(): - prdctn = F.softmax(self.model(inputs), - dim=1).argmax(dim=1).squeeze() - - # map source labels to target dataset labels - if self.apply_label_map: - prdctn = self.map_to_target(prdctn) - - # update confusion matrix - if self.compute_cm: - for ytrue, ypred in zip(labels.view(-1), prdctn.view(-1)): - self.conf_mat[ytrue.long(), ypred.long()] += 1 - - # convert torch tensors to numpy arrays - inputs = inputs.numpy() - labels = labels.numpy() - prdctn = prdctn.numpy() + prdctn = F.softmax( + model(inputs), dim=1).argmax(dim=1).squeeze() # progress string to log progress = 'Sample: {:d}/{:d}'.format(batch + 1, len(self.dataloader)) # check whether to reconstruct the scene - date = None if self.dataloader.batch_size > 1: - # id and date of the current scene + # id of the current scene batch = self.trg_ds.ids[batch] - if self.trg_ds.split_mode == 'date': - date = self.trg_ds.dataset.parse_scene_id(batch)['date'] # modify the progress string progress = progress.replace('Sample', 'Scene') @@ -2802,63 +2591,43 @@ class NetworkInference(BaseConfig): labels = reconstruct_scene(labels) prdctn = reconstruct_scene(prdctn) + # check whether the source and target domain labels differ + if self.apply_label_map: + prdctn = self.map_to_target(prdctn) + # save current batch to output dictionary - output[batch] = {'input': inputs, 'labels': labels, - 'prediction': prdctn} + output[batch] = {'x': inputs, 'y': labels, 'y_pred': prdctn} # filename for the plot of the current batch - batch_name = self.basename + '_{}_{}.pt'.format(self.trg_ds.name, - batch) + batch_name = '_'.join(model.state_file.stem, + '{}_{}.pt'.format(self.trg_ds.name, batch)) # check if the current batch name exceeds the Windows limit of # 255 characters - batch_name = self._check_long_filename(batch_name) - - # in case the source and target class labels are the same or a - # label mapping is applied, the accuracy of the prediction can be - # calculated - if self.source_is_target or self.apply_label_map: - progress += ', Accuracy: {:.2f}'.format( - accuracy_function(prdctn, labels)) + batch_name = check_filename_length(batch_name) + + # calculate the accuracy of the prediction + progress += ', Accuracy: {:.2f}'.format( + accuracy_function(prdctn, labels)) LOGGER.info(progress) # plot current scene if self.plot: - # in case the source and target class labels are the same or a - # label mapping is applied, plot the ground truth, otherwise - # just plot the model prediction - gt = None - if self.source_is_target or self.apply_label_map: - gt = labels - # plot inputs, ground truth and model predictions - plot_sample(inputs.clip(0, 1), - self.bands, - self.use_labels, - y=gt, - y_pred={self.model.__class__.__name__: prdctn}, - accuracy=True, - state=batch_name, - date=date, - fig=self.fig, - plot_path=self.scenes_path, - **self.kwargs) - - # save current figure state as frame for animation - if self.animate: - self.anim.frame(self.fig.axes) - - # save animation - if self.animate: - self.anim.animate(self.fig, interval=1000, repeat=True, blit=True) - self.anim.save(self.basename + '_{}.gif'.format(self.trg_ds.name), - dpi=200) - + _ = plot_sample(inputs.clip(0, 1), + self.bands, + self.source_labels, + y=labels, + y_pred={self.model.__class__.__name__: prdctn}, + accuracy=True, + state=batch_name, + plot_path=self.scenes_path, + **self.kwargs) return output def evaluate(self): - """Evaluate a pretrained model on a defined dataset. + """Evaluate the models on a defined dataset. Returns ------- @@ -2867,37 +2636,75 @@ class NetworkInference(BaseConfig): the samples (``self.predict_scene=False``) or the name of the scenes of the target dataset (``self.predict_scene=True``). The values are dictionaries with keys: - ``'input'`` + ``'x'`` Model input data of the sample (:py:class:`numpy.ndarray`). - ``'labels' + ``'y' Ground truth class labels (:py:class:`numpy.ndarray`). - ``'prediction'`` + ``'y_pred'`` Model prediction class labels (:py:class:`numpy.ndarray`). """ - # plot loss and accuracy - plot_loss(check_filename_length(self.state_file), - outpath=self.perfmc_path) + # iterate over the models to evaluate + inference = [] + for state in self.state_files: - # set the model to evaluation mode - LOGGER.info('Setting model to evaluation mode ...') - self.model.eval() - self.model.to(self.device) + # initialize logging + log = LogConfig(state) + dictConfig(log_conf(log.log_file)) + log.init_log('{}: ' + 'Evaluating model: {}.'.format(state.name)) - # initialize confusion matrix - self.conf_mat = np.zeros(shape=2 * (len(self.use_labels), )) + # plot loss and accuracy + plot_loss(check_filename_length(state), outpath=self.perfmc_path) - # log which labels the model predicts - LOGGER.info(self._label_log) + # load the target dataset to evaluate the model on + self.trg_ds = self.load_dataset( + state, implicit=self.implicit, test=self.test, + domain=self.domain) - # evaluate the model on the target dataset - output = self.predict() + # load the source dataset the model was trained on + self.src_ds = self.load_dataset(state, test=None) - # whether to plot the confusion matrix - if self.compute_cm: - plot_confusion_matrix(self.conf_mat, self.use_labels, - state_file=self.state_file, - subset=self.domain + '_' + self.trg_ds.name, - outpath=self.perfmc_path) + # load the pretrained model + model, _ = Network.load_pretrained_model(state) - return output + # evaluate the model on the target dataset + output = self.predict() + + # check whether to calculate confusion matrix + if self.cm: + + # initialize confusion matrix + conf_mat = np.zeros(shape=2 * (len(self.src_ds.labels), )) + + # calculate confusion matrix + for ytrue, ypred in zip(output['y'].flatten(), + output['y_pred'].flatten()): + # update confusion matrix entries + conf_mat[ytrue.long(), ypred.long()] += 1 + + # add confusion matrix to model output + output['cm'] = conf_mat + + # plot confusion matrix + plot_confusion_matrix( + conf_mat, self.source_labels, state_file=state, + subset=self.domain + '_' + self.trg_ds.name, + outpath=self.perfmc_path) + + # save model predictions to file + torch.save(output, str(state).replace('.pt', '_eval.pt')) + + # save model predictions to list + inference.append(output) + + # check whether to compute an aggregated confusion matrix + if self.aggregate and self.cm: + + # initialize aggregated confusion matrix + cm = np.zeros(shape=2 * (len(self.src_ds.labels), )) + + # iterate over the different model runs + for out in inference: + cm += out['cm'] + + # plot aggregated cm