From c435b1a4c2e6ced7a5cc9c2042b09c6bdb8abef1 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Wed, 12 Aug 2020 17:36:38 +0200 Subject: [PATCH] Working on a refactored version of pysegcnn.core.trainer.py --- pysegcnn/core/initconf.py | 722 ++++++++++++++++++++++++++++++++++++++ pysegcnn/main/train.py | 12 +- 2 files changed, 730 insertions(+), 4 deletions(-) create mode 100644 pysegcnn/core/initconf.py diff --git a/pysegcnn/core/initconf.py b/pysegcnn/core/initconf.py new file mode 100644 index 0000000..7e597ab --- /dev/null +++ b/pysegcnn/core/initconf.py @@ -0,0 +1,722 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 12 10:24:34 2020 + +@author: Daniel +""" +# builtins +import dataclasses +import pathlib + +# externals +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader + +# locals +from pysegcnn.core.dataset import SupportedDatasets, ImageDataset +from pysegcnn.core.transforms import Augment +from pysegcnn.core.utils import img2np, item_in_enum, accuracy_function +from pysegcnn.core.split import SupportedSplits +from pysegcnn.core.models import (SupportedModels, SupportedOptimizers, + SupportedLossFunctions) +from pysegcnn.core.layers import Conv2dSame +from pysegcnn.main.config import HERE + + +@dataclasses.dataclass +class BaseConfig: + + def __post_init__(self): + # check input types + for field in dataclasses.fields(self): + # the value of the current field + value = getattr(self, field.name) + + # check whether the value is of the correct type + if not isinstance(value, field.type): + # try to convert the value to the correct type + try: + setattr(self, field.name, field.type(value)) + except TypeError: + # raise an exception if the conversion fails + raise TypeError('Expected {} to be {}, got {}.' + .format(field.name, field.type, + type(value))) + + +@dataclasses.dataclass +class DatasetConfig(BaseConfig): + root_dir: pathlib.Path + bands: list + tile_size: int + gt_pattern: str + seed: int + sort: bool = False + transforms: list = dataclasses.field(default_factory=list) + pad: bool = False + cval: int = 99 + + def __post_init__(self): + # check input types + super().__post_init__() + + # check whether the root directory exists + if not self.root_dir.exists(): + raise FileNotFoundError('{} does not exist.'.format(self.root_dir)) + + # check whether the transformations inherit from the correct class + if not all([isinstance(t, Augment) for t in self.transforms if + self.transforms]): + raise TypeError('Each transformation is expected to be an instance' + ' of {}.'.format('.'.join([Augment.__module__, + Augment.__name__]))) + + # check whether the constant padding value is within the valid range + if not 0 < self.cval < 255: + raise ValueError('Expecting 0 <= cval <= 255, got cval={}.' 
+ .format(self.cval)) + + # the dataset name + self.dataset_name = self.root_dir.name + + # check whether the dataset is currently supported + self.dataset_class = item_in_enum(self.dataset_name, SupportedDatasets) + + def init_dataset(self): + + # instanciate the dataset + dataset = self.dataset_class( + root_dir=str(self.root_dir), + use_bands=self.bands, + tile_size=self.tile_size, + seed=self.seed, + sort=self.sort, + transforms=self.transforms, + pad=self.pad, + cval=self.cval, + gt_pattern=self.gt_pattern + ) + + return dataset + + +@dataclasses.dataclass +class SplitConfig(BaseConfig): + split_mode: str + ttratio: float + tvratio: float + date: str = 'yyyymmdd' + dateformat: str = '%Y%m%d' + drop: float = 0 + + def __post_init__(self): + # check input types + super().__post_init__() + + # check if the split mode is valid + self.split_class = item_in_enum(self.split_mode, SupportedSplits) + + # function to drop samples with a fraction of pixels equal to the constant + # padding value self.cval >= self.drop + def _drop_samples(self, ds, drop_threshold=1): + + # iterate over the scenes returned by self.compose_scenes() + dropped = [] + for pos, i in enumerate(ds.indices): + + # the current scene + s = ds.dataset.scenes[i] + + # the current tile in the ground truth + tile_gt = img2np(s['gt'], ds.dataset.tile_size, s['tile'], + ds.dataset.pad, ds.dataset.cval) + + # percent of pixels equal to the constant padding value + npixels = (tile_gt[tile_gt == ds.dataset.cval].size / tile_gt.size) + + # drop samples where npixels >= self.drop + if npixels >= drop_threshold: + print('Skipping scene {}, tile {}: {:.2f}% padded pixels ...' + .format(s['id'], s['tile'], npixels * 100)) + dropped.append(s) + _ = ds.indices.pop(pos) + + return dropped + + def train_val_test_split(self, ds): + + if not isinstance(ds, ImageDataset): + raise TypeError('Expected "ds" to be {}.' 
+                            .format('.'.join([ImageDataset.__module__,
+                                              ImageDataset.__name__])))
+
+        if self.split_mode == 'random' or self.split_mode == 'scene':
+            subset = self.split_class(ds,
+                                      self.ttratio,
+                                      self.tvratio,
+                                      ds.seed)
+
+        else:
+            subset = self.split_class(ds, self.date, self.dateformat)
+
+        # the training, validation and test dataset
+        train_ds, valid_ds, test_ds = subset.split()
+
+        # whether to drop training samples with a fraction of pixels equal to
+        # the constant padding value cval >= drop
+        if ds.pad and self.drop > 0:
+            self.dropped = self._drop_samples(train_ds, self.drop)
+
+        return train_ds, valid_ds, test_ds
+
+
+@dataclasses.dataclass
+class ModelConfig(BaseConfig):
+    model_name: str
+    filters: list
+    skip_connection: bool = True
+    kwargs: dict = dataclasses.field(
+        default_factory=lambda: {'kernel_size': 3, 'stride': 1, 'dilation': 1})
+    state_path: pathlib.Path = pathlib.Path(HERE).joinpath('_models/')
+    batch_size: int = 64
+    checkpoint: bool = False
+    pretrained: bool = False
+    pretrained_model: str = ''
+
+    def __post_init__(self):
+        # check input types
+        super().__post_init__()
+
+        # check whether the model is currently supported
+        self.model_class = item_in_enum(self.model_name, SupportedModels)
+
+    def init_state(self, ds):
+
+        # file to save model state to
+        # format: network_dataset_seed_tilesize_batchsize_bands.pt
+
+        # get the band numbers
+        bformat = ''.join(band[0] +
+                          str(ds.sensor.__members__[band].value) for
+                          band in ds.use_bands)
+
+        # model state filename
+        state_file = ('{}_{}_s{}_t{}_b{}_{}.pt'
+                      .format(self.model_class.__name__,
+                              ds.__class__.__name__,
+                              ds.seed,
+                              ds.tile_size,
+                              self.batch_size,
+                              bformat))
+
+        # check whether a pretrained model was used and change state filename
+        # accordingly
+        if self.pretrained:
+            # add the configuration of the pretrained model to the state name
+            state_file = (state_file.replace('.pt', '_') +
+                          'pretrained_' + self.pretrained_model)
+
+        # path to model state
+        state = self.state_path.joinpath(state_file)
+
+        # path to model loss/accuracy
+        loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt'))
+
+        return state, loss_state
+
+    def init_model(self, ds):
+
+        # case (1): build a new model
+        if not self.pretrained:
+
+            # instantiate the model
+            model = self.model_class(
+                in_channels=len(ds.use_bands),
+                nclasses=len(ds.labels),
+                filters=self.filters,
+                skip=self.skip_connection,
+                **self.kwargs)
+
+        # case (2): load a pretrained model
+        else:
+
+            # load pretrained model
+            model = self.load_pretrained(ds)
+
+        return model
+
+    def load_checkpoint(self, state_file, loss_state, model, optimizer):
+
+        # initial accuracy on the validation set
+        max_accuracy = 0
+
+        # set the model checkpoint to None, overwritten when resuming
+        # training from an existing model checkpoint
+        checkpoint_state = None
+
+        # whether to resume training from an existing model
+        if self.checkpoint:
+
+            # check if a model checkpoint exists
+            if not state_file.exists():
+                raise FileNotFoundError('Model checkpoint {} does not exist.'
+                                        .format(state_file))
+
+            # load the model state
+            state = model.load(state_file.name, optimizer, self.state_path)
+            print('Resuming training from {} ...'.format(state))
+            print('Model epoch: {:d}'.format(model.epoch))
+
+            # load the model loss and accuracy
+            checkpoint_state = torch.load(loss_state)
+
+            # get all non-zero elements, i.e.
get number of epochs trained + # before the early stop + checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) + for k, v in checkpoint_state.items()} + + # maximum accuracy on the validation set + max_accuracy = checkpoint_state['va'][:, -1].mean().item() + + return checkpoint_state, max_accuracy + + def load_pretrained(self, ds): + + # load the pretrained model + model_state = self.state_path.joinpath(self.pretrained_model) + if not model_state.exists(): + raise FileNotFoundError('Pretrained model {} does not exist.' + .format(model_state)) + + # load the model state + model_state = torch.load(model_state) + + # get the input bands of the pretrained model + bands = model_state['bands'] + + # get the number of convolutional filters + filters = model_state['params']['filters'] + + # check whether the current dataset uses the correct spectral bands + if ds.use_bands != bands: + raise ValueError('The bands of the pretrained network do not ' + 'match the specified bands: {}' + .format(bands)) + + # instanciate pretrained model architecture + model = self.model_class(**model_state['params'], + **model_state['kwargs']) + + # load pretrained model weights + model.load(self.pretrained_model, inpath=str(self.state_path)) + + # reset model epoch to 0, since the model is trained on a different + # dataset + model.epoch = 0 + + # adjust the number of classes in the model + model.nclasses = len(ds.labels) + + # adjust the classification layer to the number of classes of the + # current dataset + model.classifier = Conv2dSame(in_channels=filters[0], + out_channels=model.nclasses, + kernel_size=1) + + return model + + +@dataclasses.dataclass +class TrainingConfig(BaseConfig): + optim_name: str + loss_name: str + lr: float = 0.001 + early_stop: bool = False + mode: str = 'max' + delta: float = 0 + patience: int = 10 + epochs: int = 50 + nthreads: int = torch.get_num_threads() + + def __post_init__(self): + + # check whether the optimizer is currently supported + self.optim_class = item_in_enum(self.optim_name, SupportedOptimizers) + + # check whether the loss function is currently supported + self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions) + + def init_optimizer(self, model): + + # initialize the optimizer for the specified model + optimizer = self.optim_class(model.parameters(), self.lr) + + return optimizer + + def init_loss_function(self): + + loss_function = self.loss_class() + + return loss_function + + +@dataclasses.dataclass +class NetworkTrainer(BaseConfig): + dconfig: dict = dataclasses.field(default_factory=dict) + sconfig: dict = dataclasses.field(default_factory=dict) + mconfig: dict = dataclasses.field(default_factory=dict) + tconfig: dict = dataclasses.field(default_factory=dict) + + def __post_init__(self): + super().__post_init__() + + # whether to use the gpu + self.device = torch.device("cuda:0" if torch.cuda.is_available() else + "cpu") + + # instanciate the configurations + self.dc = DatasetConfig(**self.dconfig) + self.sc = SplitConfig(**self.sconfig) + self.mc = ModelConfig(**self.mconfig) + self.tc = TrainingConfig(**self.tconfig) + + # initialize the dataset to train the model on + self.dataset = self.dc.init_dataset() + + # inialize the training, validation and test dataset + (self.train_ds, self.valid_ds, + self.test_ds) = self.sc.train_val_test_split(self.dataset) + + # create the dataloaders + self._build_dataloaders() + + # initialize the model state files + self.state_file, self.loss_state = self.mc.init_state(self.dataset) + + # initialize 
the model + self.model = self.mc.init_model(self.dataset) + + # initialize the optimizer + self.optimizer = self.tc.init_optimizer(self.model) + + # initialize the loss function + self.loss_function = self.tc.init_loss_function() + + # whether to resume training from an existing model + self.checkpoint_state, self.max_accuracy = self.mc.load_checkpoint( + self.state_file, self.loss_state, self.model, self.optimizer) + + def train(self): + + print('------------------------- Training ---------------------------') + + # set the number of threads + torch.set_num_threads(self.tc.nthreads) + + # instanciate early stopping class + if self.tc.early_stop: + es = EarlyStopping(self.tc.mode, self.tc.delta, self.tc.patience) + print('Initializing early stopping ...') + print('mode = {}, delta = {}, patience = {} epochs ...' + .format(self.tc.mode, self.tc.delta, self.tc.patience)) + + # create dictionary of the observed losses and accuracies on the + # training and validation dataset + tshape = (len(self.train_dl), self.tc.epochs) + vshape = (len(self.valid_dl), self.tc.epochs) + training_state = {'tl': np.zeros(shape=tshape), + 'ta': np.zeros(shape=tshape), + 'vl': np.zeros(shape=vshape), + 'va': np.zeros(shape=vshape) + } + + # send the model to the gpu if available + self.model = self.model.to(self.device) + + # initialize the training: iterate over the entire training data set + for epoch in range(self.tc.epochs): + + # set the model to training mode + print('Setting model to training mode ...') + self.model.train() + + # iterate over the dataloader object + for batch, (inputs, labels) in enumerate(self.train_dl): + + # send the data to the gpu if available + inputs = inputs.to(self.device) + labels = labels.to(self.device) + + # reset the gradients + self.optimizer.zero_grad() + + # perform forward pass + outputs = self.model(inputs) + + # compute loss + loss = self.loss_function(outputs, labels.long()) + observed_loss = loss.detach().numpy().item() + training_state['tl'][batch, epoch] = observed_loss + + # compute the gradients of the loss function w.r.t. 
+ # the network weights + loss.backward() + + # update the weights + self.optimizer.step() + + # calculate predicted class labels + ypred = F.softmax(outputs, dim=1).argmax(dim=1) + + # calculate accuracy on current batch + observed_accuracy = accuracy_function(ypred, labels) + training_state['ta'][batch, epoch] = observed_accuracy + + # print progress + print('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, Loss: {:.2f}, ' + 'Accuracy: {:.2f}'.format(epoch + 1, + self.tc.epochs, + batch + 1, + len(self.train_dl), + observed_loss, + observed_accuracy)) + + # update the number of epochs trained + self.model.epoch += 1 + + # whether to evaluate model performance on the validation set and + # early stop the training process + if self.tc.early_stop: + + # model predictions on the validation set + vacc, vloss = self.predict() + + # append observed accuracy and loss to arrays + training_state['va'][:, epoch] = vacc.squeeze() + training_state['vl'][:, epoch] = vloss.squeeze() + + # metric to assess model performance on the validation set + epoch_acc = vacc.squeeze().mean() + + # whether the model improved with respect to the previous epoch + if es.increased(epoch_acc, self.max_accuracy, self.tc.delta): + self.max_accuracy = epoch_acc + # save model state if the model improved with + # respect to the previous epoch + _ = self.model.save(self.state_file, + self.optimizer, + self.dataset.use_bands, + self.mc.state_path) + + # save losses and accuracy + self._save_loss(training_state) + + # whether the early stopping criterion is met + if es.stop(epoch_acc): + break + + else: + # if no early stopping is required, the model state is saved + # after each epoch + _ = self.model.save(self.state_file, + self.optimizer, + self.dataset.use_bands, + self.mc.state_path) + + # save losses and accuracy after each epoch + self._save_loss(training_state) + + return training_state + + def predict(self): + + print('------------------------ Predicting --------------------------') + + # send the model to the gpu if available + self.model = self.model.to(self.device) + + # set the model to evaluation mode + print('Setting model to evaluation mode ...') + self.model.eval() + + # create arrays of the observed losses and accuracies + accuracies = np.zeros(shape=(len(self.valid_dl), 1)) + losses = np.zeros(shape=(len(self.valid_dl), 1)) + + # iterate over the validation/test set + print('Calculating accuracy on the validation set ...') + for batch, (inputs, labels) in enumerate(self.valid_dl): + + # send the data to the gpu if available + inputs = inputs.to(self.device) + labels = labels.to(self.device) + + # calculate network outputs + with torch.no_grad(): + outputs = self.model(inputs) + + # compute loss + loss = self.loss_function(outputs, labels.long()) + losses[batch, 0] = loss.detach().numpy().item() + + # calculate predicted class labels + pred = F.softmax(outputs, dim=1).argmax(dim=1) + + # calculate accuracy on current batch + acc = accuracy_function(pred, labels) + accuracies[batch, 0] = acc + + # print progress + print('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}' + .format(batch + 1, len(self.valid_dl), acc)) + + # calculate overall accuracy on the validation/test set + print('After training for {:d} epochs, we achieved an overall ' + 'accuracy of {:.2f}% on the validation set!' 
+ .format(self.model.epoch, accuracies.mean() * 100)) + + return accuracies, losses + + def _build_dataloaders(self): + + # the shape of a single tile + self.tile_shape = (len(self.dataset.use_bands), + self.dataset.tile_size, + self.dataset.tile_size) + + # the training dataloader + self.train_dl = None + if len(self.train_ds) > 0: + self.train_dl = DataLoader(self.train_ds, + self.mc.batch_size, + shuffle=True, + drop_last=False) + # the validation dataloader + self.valid_dl = None + if len(self.valid_ds) > 0: + self.valid_dl = DataLoader(self.valid_ds, + self.mc.batch_size, + shuffle=True, + drop_last=False) + + # the test dataloader + self.test_dl = None + if len(self.test_ds) > 0: + self.test_dl = DataLoader(self.test_ds, + self.mc.batch_size, + shuffle=True, + drop_last=False) + + def _save_loss(self, training_state): + + # save losses and accuracy + if self.mc.checkpoint and self.checkpoint_state is not None: + + # append values from checkpoint to current training + # state + torch.save({ + k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in + zip(self.checkpoint_state.items(), training_state.items()) + if k1 == k2}, + self.loss_state) + else: + torch.save(training_state, self.loss_state) + + + + def __repr__(self): + + # representation string to print + fs = self.__class__.__name__ + '(\n' + + # dataset + fs += ' (dataset):\n ' + fs += ''.join(repr(self.dataset)).replace('\n', '\n ') + + # batch size + fs += '\n (batch):\n ' + fs += '- batch size: {}\n '.format(self.mc.batch_size) + fs += '- tile shape (c, h, w): {}\n '.format(self.tile_shape) + fs += '- mini-batch shape (b, c, h, w): {}'.format( + (self.mc.batch_size,) + self.tile_shape) + + # dataset split + fs += '\n (split):' + fs += '\n ' + repr(self.train_ds) + fs += '\n ' + repr(self.valid_ds) + fs += '\n ' + repr(self.test_ds) + + # model + fs += '\n (model):\n ' + fs += ''.join(repr(self.model)).replace('\n', '\n ') + + # optimizer + fs += '\n (optimizer):\n ' + fs += ''.join(repr(self.optimizer)).replace('\n', '\n ') + fs += '\n)' + + return fs + + +class EarlyStopping(object): + + def __init__(self, mode='max', min_delta=0, patience=10): + + # check if mode is correctly specified + if mode not in ['min', 'max']: + raise ValueError('Mode "{}" not supported. ' + 'Mode is either "min" (check whether the metric ' + 'decreased, e.g. loss) or "max" (check whether ' + 'the metric increased, e.g. accuracy).' 
+                             .format(mode))
+
+        # mode to determine if metric improved
+        self.mode = mode
+
+        # whether to check for an increase or a decrease in a given metric
+        self.is_better = self.decreased if mode == 'min' else self.increased
+
+        # minimum change in metric to be classified as an improvement
+        self.min_delta = min_delta
+
+        # number of epochs to wait for improvement
+        self.patience = patience
+
+        # initialize best metric and the patience counter
+        self.best = None
+        self.counter = 0
+
+        # initialize early stopping flag
+        self.early_stop = False
+
+    def stop(self, metric):
+
+        if self.best is not None:
+
+            # if the metric improved, reset the epochs counter, else, advance
+            if self.is_better(metric, self.best, self.min_delta):
+                self.counter = 0
+                self.best = metric
+            else:
+                self.counter += 1
+                print('Early stopping counter: {}/{}'.format(self.counter,
+                                                             self.patience))
+
+                # if the metric did not improve over the last patience epochs,
+                # the early stopping criterion is met
+                if self.counter >= self.patience:
+                    print('Early stopping criterion met, exiting training ...')
+                    self.early_stop = True
+
+        else:
+            self.best = metric
+
+        return self.early_stop
+
+    def decreased(self, metric, best, min_delta):
+        return metric < best - min_delta
+
+    def increased(self, metric, best, min_delta):
+        return metric > best + min_delta
diff --git a/pysegcnn/main/train.py b/pysegcnn/main/train.py
index 6d9556c..af0099e 100644
--- a/pysegcnn/main/train.py
+++ b/pysegcnn/main/train.py
@@ -6,15 +6,19 @@ Created on Tue Jun 30 09:33:38 2020
 @author: Daniel
 """
 # locals
-from pysegcnn.core.trainer import NetworkTrainer
-from pysegcnn.main.config import config
+from pysegcnn.core.initconf import NetworkTrainer
+from pysegcnn.main.config import (dataset_config, split_config,
+                                  model_config, train_config)
 
 
 if __name__ == '__main__':
 
     # instanciate the NetworkTrainer class
-    trainer = NetworkTrainer(config)
-    print(trainer)
+    trainer = NetworkTrainer(dconfig=dataset_config,
+                             sconfig=split_config,
+                             mconfig=model_config,
+                             tconfig=train_config)
+    print(trainer)
 
     # train the network
     training_state = trainer.train()
--
GitLab
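
Note: the updated pysegcnn/main/train.py imports four dictionaries (dataset_config, split_config, model_config, train_config) from pysegcnn.main.config, which is not touched by this patch. The sketch below only illustrates what those dictionaries could look like: the keys mirror the dataclass fields defined in initconf.py, fields with defaults may be omitted, and all values (paths, band names, model/optimizer/loss names) are placeholders that depend on the concrete dataset and on the members registered in the Supported* enums.

# pysegcnn/main/config.py -- illustrative sketch only, all values are placeholders
import pathlib

# HERE is assumed to already be defined in config.py, since initconf.py imports it
HERE = pathlib.Path(__file__).parent

dataset_config = {
    'root_dir': HERE.joinpath('_datasets/MyDataset'),  # placeholder path; the
                                                       # folder name selects the
                                                       # SupportedDatasets member
    'bands': ['red', 'green', 'blue', 'nir'],          # placeholder band list
    'tile_size': 125,
    'gt_pattern': '(.*)mask\\.png',                    # placeholder glob/regex
    'seed': 0,
    'sort': False,
    'transforms': [],
    'pad': True,
    'cval': 99,
}

split_config = {
    'split_mode': 'random',      # 'random', 'scene', or a date-based split
    'ttratio': 1,
    'tvratio': 0.8,
    'date': 'yyyymmdd',
    'dateformat': '%Y%m%d',
    'drop': 0,
}

model_config = {
    'model_name': 'Unet',        # placeholder; must be a SupportedModels member
    'filters': [32, 64, 128, 256],
    'skip_connection': True,
    'kwargs': {'kernel_size': 3, 'stride': 1, 'dilation': 1},
    'batch_size': 64,
    'checkpoint': False,
    'pretrained': False,
    'pretrained_model': '',
}

train_config = {
    'optim_name': 'Adam',        # placeholder; must be a SupportedOptimizers member
    'loss_name': 'CrossEntropy', # placeholder; must be a SupportedLossFunctions member
    'lr': 0.001,
    'early_stop': True,
    'mode': 'max',
    'delta': 0,
    'patience': 10,
    'epochs': 50,
}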