From 04830e41ef0feb1c539d101fc14a52c80e9bedea Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Fri, 22 Jan 2021 17:42:19 +0100
Subject: [PATCH] Working on a more stable generation of model state files.

---
 pysegcnn/core/trainer.py | 347 +++++++++++++--------------------------
 1 file changed, 112 insertions(+), 235 deletions(-)

diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py
index 457671b..123be81 100644
--- a/pysegcnn/core/trainer.py
+++ b/pysegcnn/core/trainer.py
@@ -37,9 +37,9 @@ from torch.utils.data import DataLoader
 from torch.optim import Optimizer
 
 # locals
-from pysegcnn.core.dataset import SupportedDatasets, ImageDataset
+from pysegcnn.core.dataset import SupportedDatasets
 from pysegcnn.core.transforms import Augment
-from pysegcnn.core.utils import (img2np, item_in_enum, accuracy_function,
+from pysegcnn.core.utils import (item_in_enum, accuracy_function,
                                  reconstruct_scene, check_filename_length,
                                  array_replace)
 from pysegcnn.core.split import SupportedSplits
@@ -110,9 +110,6 @@ class DatasetConfig(BaseConfig):
         A regural expression to match the ground truth naming convention.
         All directories and subdirectories in ``root_dir`` are searched for
         files matching ``gt_pattern``.
-    seed : `int`
-        The random seed. Used to split the dataset into training,
-        validation and test set. Useful for reproducibility.
     sort : `bool`
         Whether to chronologically sort the samples. Useful for time series
         data. The default is `False`.
@@ -144,7 +141,6 @@ class DatasetConfig(BaseConfig):
     bands: list
     tile_size: object
     gt_pattern: str
-    seed: int
     sort: bool = False
     transforms: list = dataclasses.field(default_factory=[])
     pad: bool = False
@@ -196,7 +192,6 @@ class DatasetConfig(BaseConfig):
                     root_dir=str(self.root_dir),
                     use_bands=self.bands,
                     tile_size=self.tile_size,
-                    seed=self.seed,
                     sort=self.sort,
                     transforms=self.transforms,
                     pad=self.pad,
@@ -217,36 +212,32 @@ class SplitConfig(BaseConfig):
     ----------
     split_mode : `str`
         The mode to split the dataset.
-    ttratio : `float`
-        The ratio of training and validation data to test data, e.g.
-        ``ttratio= 0.6`` means 60% for training and validation, 40% for
-        testing.
-    tvratio : `float`
+    k_folds: `int`, optional
+        The number of folds.
+    seed : `int`, optional
+        The random seed for reproducibility. The default is `0`.
+    shuffle : `bool`, optional
+        Whether to shuffle the data before splitting into batches. The
+        default is `True`.
+    tvratio : `float`, optional
         The ratio of training data to validation data, e.g. ``tvratio=0.8``
-        means 80% training, 20% validation.
-    date : `str`
-        A date. Used if ``split_mode='date'``. The default is  `yyyymmdd`.
-    dateformat : `str`
-        The format of ``date``. ``dateformat`` is used by
-        :py:func:`datetime.datetime.strptime' to parse ``date`` to a
-        :py:class:`datetime.datetime` object. The default is `'%Y%m%d'`.
-    drop : `float`
-        Whether to drop samples (during training only) with a fraction of
-        pixels equal to the constant padding value >= ``drop``. ``drop=0``
-        means, do not drop any samples. The default is `0`.
-    split_class : :py:class:`pysegcnn.core.split.Split`
-        A subclass of :py:class:`pysegcnn.core.split.Split`.
-    dropped : `list` [`dict`]
-        List of the dropped samples.
+        means 80% training, 20% validation. The default is `0.8`. Used if
+        ``k_folds=1``.
+    ttratio : `float`, optional
+        The ratio of training and validation data to test data, e.g.
+        ``ttratio=0.6`` means 60% for training and validation, 40% for
+        testing. The default is `1`. Used if ``k_folds=1``.
+    split_class : :py:class:`pysegcnn.core.split.RandomSplit`
+        A subclass of :py:class:`pysegcnn.core.split.RandomSplit`.
 
     """
 
     split_mode: str
-    ttratio: float
-    tvratio: float
-    date: str = 'yyyymmdd'
-    dateformat: str = '%Y%m%d'
-    drop: float = 0
+    k_folds: int = 1
+    seed: int = 0
+    shuffle: bool = True
+    tvratio: float = 0.8
+    ttratio: float = 1
 
     def __post_init__(self):
         """Check the type of each argument.
@@ -263,52 +254,6 @@ class SplitConfig(BaseConfig):
         # check if the split mode is valid
         self.split_class = item_in_enum(self.split_mode, SupportedSplits)
 
-        # list of dropped samples
-        self.dropped = []
-
-    @staticmethod
-    def drop_samples(ds, drop_threshold=1):
-        """Drop samples with a fraction of pixels equal to the padding value.
-
-        Parameters
-        ----------
-        ds : :py:class:`pysegcnn.core.split.CustomSubset`
-            An instance of :py:class:`pysegcnn.core.split.CustomSubset`.
-        drop_threshold : `float`, optional
-            The threshold above which samples are dropped. ``drop_threshold=1``
-            means a sample is dropped, if all pixels are equal to the padding
-            value. ``drop_threshold=0.8`` means, drop a sample if 80% of the
-            pixels are equal to the padding value, etc. The default is `1`.
-
-        Returns
-        -------
-        dropped : `list` [`dict`]
-            List of the dropped samples.
-
-        """
-        # iterate over the scenes returned by self.compose_scenes()
-        dropped = []
-        for pos, i in enumerate(ds.indices):
-
-            # the current scene
-            s = ds.dataset.scenes[i]
-
-            # the current tile in the ground truth
-            tile_gt = img2np(s['gt'], ds.dataset.tile_size, s['tile'],
-                             ds.dataset.pad, ds.dataset.cval)
-
-            # percent of pixels equal to the constant padding value
-            npixels = (tile_gt[tile_gt == ds.dataset.cval].size / tile_gt.size)
-
-            # drop samples where npixels >= self.drop
-            if npixels >= drop_threshold:
-                LOGGER.info('Skipping scene {}, tile {}: {:.2f}% padded pixels'
-                            ' ...'.format(s['id'], s['tile'], npixels * 100))
-                dropped.append(s)
-                _ = ds.indices.pop(pos)
-
-        return dropped
-
     def train_val_test_split(self, ds):
         """Split ``ds`` into training, validation and test set.
 
@@ -333,29 +278,14 @@ class SplitConfig(BaseConfig):
             The test set.
 
         """
-        if not isinstance(ds, ImageDataset):
-            raise TypeError('Expected "ds" to be {}.'
-                            .format('.'.join([ImageDataset.__module__,
-                                              ImageDataset.__name__])))
-
-        if self.split_mode == 'random' or self.split_mode == 'scene':
-            subset = self.split_class(ds,
-                                      self.ttratio,
-                                      self.tvratio,
-                                      ds.seed)
-
-        else:
-            subset = self.split_class(ds, self.date, self.dateformat)
+        # instanciate the split class
+        split = self.split_class(ds, self.k_folds, self.seed, self.shuffle,
+                                 self.tvratio, self.ttratio)
 
         # the training, validation and test dataset
-        train_ds, valid_ds, test_ds = subset.split()
-
-        # whether to drop training samples with a fraction of pixels equal to
-        # the constant padding value cval >= drop
-        if ds.pad and self.drop > 0:
-            self.dropped = self.drop_samples(train_ds, self.drop)
+        subsets = split.split()
 
-        return train_ds, valid_ds, test_ds
+        return subsets
 
     @staticmethod
     def dataloaders(*args, **kwargs):
@@ -787,6 +717,7 @@ class StateConfig(BaseConfig):
 
     Generate the model state filename according to the following naming
     conventions:
+
         - For source domain without domain adaptation:
             Model_Optim_SourceDataset_ModelParams.pt
 
@@ -801,30 +732,8 @@ class StateConfig(BaseConfig):
             Model_Optim_SourceDataset_TargetDataset_ModelParams_prt_
             NameOfPretrainedModel.pt
 
-    Attributes
-    ----------
-    src_dc : :py:class:`pysegcnn.core.trainer.DatasetConfig`
-        The source domain dataset configuration.
-    src_sc : :py:class:`pysegcnn.core.trainer.SplitConfig`
-        The source domain dataset split configuration.
-    trg_dc : :py:class:`pysegcnn.core.trainer.DatasetConfig`
-        The target domain dataset configuration.
-    trg_sc : :py:class:`pysegcnn.core.trainer.SplitConfig`
-        The target domain dataset split configuration.
-    mc : :py:class:`pysegcnn.core.trainer.ModelConfig`
-        The model configuration.
-    tc : :py:class:`pysegcnn.core.trainer.TransferLearningConfig`
-        The transfer learning configuration.
-
     """
 
-    src_dc: DatasetConfig
-    src_sc: SplitConfig
-    trg_dc: DatasetConfig
-    trg_sc: SplitConfig
-    mc: ModelConfig
-    tc: TransferLearningConfig
-
     def __post_init__(self):
         """Check the type of each argument.
 
@@ -840,85 +749,112 @@ class StateConfig(BaseConfig):
         # base dataset state filename: Dataset_SplitMode_SplitParams
         self.ds_state_file = '{}_{}_{}'
 
-        # base model state filename: Model_Optim
-        self.ml_state_file = '{}_{}'
+        # base model state filename: Model_Optim_BatchSize
+        self.ml_state_file = '{}_{}_b{}'
 
-        # base model state filename extentsion: TileSize_BatchSize_Bands
-        self.ml_state_ext = 't{}_b{}_{}.pt'
-
-        # check that the spectral bands are the same for both source and target
-        # domains
-        if self.src_dc.bands != self.trg_dc.bands:
-            raise ValueError('Spectral bands of the source and target domain '
-                             'have to be equal: \n source: {} \n target: {}'
-                             .format(', '.join(self.src_dc.bands),
-                                     ', '.join(self.trg_dc.bands))
-                             )
+        # base model state filename extentsion: TileSize_Bands
+        self.ds_state_ext = 't{}_{}.pt'
 
-    def init_state(self):
+    def init_state(self, src_dc, src_sc, mc, trg_dc=None, trg_sc=None, tc=None,
+                   fold=0):
         """Generate the model state filename.
 
+        Parameters
+        ----------
+        src_dc : :py:class:`pysegcnn.core.trainer.DatasetConfig`
+            The source domain dataset configuration.
+        src_sc : :py:class:`pysegcnn.core.trainer.SplitConfig`
+            The source domain dataset split configuration.
+        mc : :py:class:`pysegcnn.core.trainer.ModelConfig`
+            The model configuration.
+        trg_dc : :py:class:`pysegcnn.core.trainer.DatasetConfig`
+            The target domain dataset configuration.
+        trg_sc : :py:class:`pysegcnn.core.trainer.SplitConfig`
+            The target domain dataset split configuration.
+        tc : :py:class:`pysegcnn.core.trainer.TransferLearningConfig`
+            The transfer learning configuration.
+
         Returns
         -------
         state : :py:class:`pathlib.Path`
             The path to the model state file.
 
         """
-        # source domain dataset state filename
-        src_ds_state = self._format_ds_state(
-            self.ds_state_file, self.src_dc, self.src_sc)
+        # source domain dataset state filename and extension
+        src_ds_state, src_ds_ext = self.format_ds_state(
+            src_dc, src_sc, fold)
 
-        # model state file name and extension: common to both source and target
-        # domain
-        ml_state, ml_ext = self._format_ml_state(self.src_dc, self.mc)
+        # model state file name: common to both source and target domain
+        ml_state = self.format_model_state(mc)
 
         # state file for models trained only on the source domain
-        state = '_'.join([ml_state, src_ds_state, ml_ext])
+        state = '_'.join([ml_state, src_ds_state, src_ds_ext])
 
         # check whether the model is trained on the source domain only
-        if self.tc.transfer:
+        if tc is not None:
+
+            # check whether the target domain configurations are correctly
+            # specified
+            if trg_dc is None or trg_sc is None:
+                raise ValueError('Target domain configurations required.')
 
             # target domain dataset state filename
-            trg_ds_state = self._format_ds_state(
-                self.ds_state_file, self.trg_dc, self.trg_sc)
+            trg_ds_state, _ = self.format_ds_state(
+                trg_dc, trg_sc, fold)
 
             # check whether a pretrained model is used to fine-tune to the
             # target domain
-            if self.tc.supervised:
+            if tc.supervised:
                 # state file for models fine-tuned to target domain
                 # DatasetConfig_PretrainedModel.pt
-                state = '_'.join([self.tc.pretrained_model,
+
+                # TODO: Is this correct? Trainer is initialized with source
+                # dataloaders
+                state = '_'.join([tc.pretrained_model,
                                   'sda_{}'.format(trg_ds_state)])
             else:
                 # state file for models trained via unsupervised domain
                 # adaptation
                 state = '_'.join([state.replace(
-                    ml_ext, 'uda_{}'.format(self.tc.uda_pos)),
-                    trg_ds_state, ml_ext])
+                    src_ds_ext, 'uda_{}'.format(tc.uda_pos)),
+                    trg_ds_state, src_ds_ext])
 
                 # check whether unsupervised domain adaptation is initialized
                 # from a pretrained model state
-                if self.tc.uda_from_pretrained:
+                if tc.uda_from_pretrained:
                     state = '_'.join(state.replace('.pt', ''),
                                      'prt_{}'.format(
-                                         self.tc.pretrained_model))
+                                         tc.pretrained_model))
 
         # path to model state
-        state = self.tc.state_path.joinpath(state)
+        state = mc.state_path.joinpath(state)
 
         return state
 
-    def _format_ds_state(self, state_file, dc, sc):
+    def format_model_state(self, mc):
+        """Format base model state filename.
+
+        Parameters
+        ----------
+        mc : :py:class:`pysegcnn.core.trainer.ModelConfig`
+            The model configuration.
+
+        """
+        return self.ml_state_file.format(mc.model_name, mc.optim_name,
+                                         mc.batch_size)
+
+    def format_ds_state(self, dc, sc, fold=None):
         """Format base dataset state filename.
 
         Parameters
         ----------
-        state_file : `str`
-            The base dataset state filename.
         dc : :py:class:`pysegcnn.core.trainer.DatasetConfig`
             The dataset configuration.
         sc : :py:class:`pysegcnn.core.trainer.SplitConfig`
             The dataset split configuration.
+        fold : `int` or `None`, optional
+            The number of the current fold. The default is `None`, which means
+            the fold is not reported in the model state filename.
 
         Returns
         -------
@@ -926,52 +862,39 @@ class StateConfig(BaseConfig):
             The formatted dataset state filename.
 
         """
+        # store the random seed for reproducibility
+        split_params = 's{}'.format(sc.seed)
 
-        # check which split mode was used
-        if sc.split_mode == 'date':
-            # store the date of the split
-            split_params = sc.date
+        # check whether the model is trained via cross validation
+        if sc.k_folds > 1 and fold is not None:
+            split_params += 'f{}'.format(fold)
         else:
-            # store the random split parameters
-            split_params = 's{}t{}v{}'.format(
-                dc.seed, str(sc.ttratio).replace('.', ''),
-                str(sc.tvratio).replace('.', ''))
-
-        # model state filename
-        file = state_file.format(dc.dataset_class.__name__ +
-                                 '_m{}'.format(len(dc.merge_labels)),
-                                 sc.split_mode.capitalize(),
-                                 split_params,
-                                 )
-
-        return file
+            # construct dataset split parameters
+            split_params += 't{}v{}'.format(str(sc.ttratio).replace('.', ''),
+                                            str(sc.tvratio).replace('.', ''))
 
-    def _format_ml_state(self, dc, mc):
-        """Format base model state filename.
-
-        Parameters
-        ----------
-        dc : :py:class:`pysegcnn.core.trainer.DatasetConfig`
-            The source domain dataset configuration.
-        mc : :py:class:`pysegcnn.core.trainer.ModelConfig`
-            The model configuration.
-
-        Returns
-        -------
-        extention : `str`
-            The formatted model state filename.
-
-        """
         # get the band numbers
         if dc.bands:
+            # the spectral bands used to train the model
             bands = dc.dataset_class.get_sensor().band_dict()
             bformat = ''.join([(v[0] + str(k)) for k, v in bands.items() if
                                v in dc.bands])
         else:
+            # all available spectral bands are used to train the model
             bformat = 'all'
 
-        return (self.ml_state_file.format(mc.model_name, mc.optim_name),
-                self.ml_state_ext.format(dc.tile_size, mc.batch_size, bformat))
+        # dataset state filename
+        file = self.ds_state_file.format(dc.dataset_class.__name__ +
+                                         '_m{}'.format(len(dc.merge_labels)),
+                                         sc.split_mode.capitalize(),
+                                         split_params,
+                                         )
+
+        # dataset state filename extension: common to both source and target
+        # domain
+        ext = self.ds_state_ext.format(dc.tile_size, bformat)
+
+        return file, ext
 
 
 @dataclasses.dataclass
@@ -1743,38 +1666,6 @@ class DomainAdaptationTrainer(ClassificationNetworkTrainer):
                 'uda_pos': self.uda_pos,
                 'uda_lambda': self.uda_lambda}
 
-    def _build_ds_repr(self, train_dl, valid_dl, test_dl):
-        """Build the dataset representation.
-
-        Returns
-        -------
-        fs : `str`
-            Representation string.
-
-        """
-        # dataset tile size
-        tile_size = (2 * (train_dl.dataset.dataset.tile_size,) if
-                     train_dl.dataset.dataset.tile_size is not None else
-                     train_dl.dataset.dataset.get_size())
-
-        # dataset configuration
-        fs = '    (dataset):\n        '
-        fs += ''.join(repr(train_dl.dataset.dataset)).replace('\n',
-                                                              '\n' + 8 * ' ')
-        fs += '\n    (batch):\n        '
-        fs += '- batch size: {}\n        '.format(train_dl.batch_size)
-        fs += '- mini-batch shape (b, c, h, w): {}'.format(
-            ((min(train_dl.batch_size, len(train_dl.dataset)),
-              len(train_dl.dataset.dataset.use_bands),) + tile_size))
-
-        # dataset split
-        fs += '\n    (split):'
-        for dl in [train_dl, valid_dl, test_dl]:
-            if dl.dataset is not None:
-                fs += '\n' + 8 * ' ' + repr(dl.dataset)
-
-        return fs
-
     def __repr__(self):
         """Representation.
 
@@ -1787,20 +1678,6 @@ class DomainAdaptationTrainer(ClassificationNetworkTrainer):
         # representation string to print
         fs = self.__class__.__name__ + '(\n'
 
-        # source domain
-        fs += '    (source domain)\n    '
-        fs += self._build_ds_repr(
-            self.src_train_dl, self.src_valid_dl, self.src_test_dl).replace(
-                '\n', '\n' + 4 * ' ')
-
-        # target domain
-        if not self.supervised:
-            fs += '\n    (target domain)\n    '
-            fs += self._build_ds_repr(
-                self.trg_train_dl,
-                self.trg_valid_dl,
-                self.trg_test_dl).replace('\n', '\n' + 4 * ' ')
-
         # model configuration
         fs += self._build_model_repr_()
 
-- 
GitLab