diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index fa8715ed8882478792d8095bc07ad55ba0ce1546..73d3a4399f3bfa6ca3ce4c011e48cf71a3ad143c 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -1,53 +1,338 @@ -# !/usr/bin/env python # -*- coding: utf-8 -*- """ -Created on Fri Jun 26 16:31:36 2020 +Created on Wed Aug 12 10:24:34 2020 @author: Daniel """ # builtins -import os +import dataclasses +import pathlib # externals import numpy as np import torch +import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset +from torch.optim import Optimizer + # locals -from pysegcnn.core.dataset import SupportedDatasets +from pysegcnn.core.dataset import SupportedDatasets, ImageDataset +from pysegcnn.core.transforms import Augment +from pysegcnn.core.utils import img2np, item_in_enum, accuracy_function +from pysegcnn.core.split import SupportedSplits +from pysegcnn.core.models import (SupportedModels, SupportedOptimizers, + SupportedLossFunctions, Network) from pysegcnn.core.layers import Conv2dSame -from pysegcnn.core.utils import img2np, accuracy_function -from pysegcnn.core.split import (RandomTileSplit, RandomSceneSplit, DateSplit, - VALID_SPLIT_MODES) +from pysegcnn.main.config import HERE + + +@dataclasses.dataclass +class BaseConfig: + + def __post_init__(self): + # check input types + for field in dataclasses.fields(self): + # the value of the current field + value = getattr(self, field.name) + + # check whether the value is of the correct type + if not isinstance(value, field.type): + # try to convert the value to the correct type + try: + setattr(self, field.name, field.type(value)) + except TypeError: + # raise an exception if the conversion fails + raise TypeError('Expected {} to be {}, got {}.' + .format(field.name, field.type, + type(value))) + + +@dataclasses.dataclass +class DatasetConfig(BaseConfig): + dataset_name: str + root_dir: pathlib.Path + bands: list + tile_size: int + gt_pattern: str + seed: int + sort: bool = False + transforms: list = dataclasses.field(default_factory=list) + pad: bool = False + cval: int = 99 + + def __post_init__(self): + # check input types + super().__post_init__() + # check whether the dataset is currently supported + self.dataset_class = item_in_enum(self.dataset_name, SupportedDatasets) -class NetworkTrainer(object): + # check whether the root directory exists + if not self.root_dir.exists(): + raise FileNotFoundError('{} does not exist.'.format(self.root_dir)) - def __init__(self, config): + # check whether the transformations inherit from the correct class + if not all([isinstance(t, Augment) for t in self.transforms if + self.transforms]): + raise TypeError('Each transformation is expected to be an instance' + ' of {}.'.format('.'.join([Augment.__module__, + Augment.__name__]))) - # the configuration file as defined in pysegcnn.main.config.py - for k, v in config.items(): - setattr(self, k, v) + # check whether the constant padding value is within the valid range + if not 0 < self.cval < 255: + raise ValueError('Expecting 0 <= cval <= 255, got cval={}.' + .format(self.cval)) - # whether to use the gpu - self.device = torch.device("cuda:0" if torch.cuda.is_available() else - "cpu") + def init_dataset(self): + + # instanciate the dataset + dataset = self.dataset_class( + root_dir=str(self.root_dir), + use_bands=self.bands, + tile_size=self.tile_size, + seed=self.seed, + sort=self.sort, + transforms=self.transforms, + pad=self.pad, + cval=self.cval, + gt_pattern=self.gt_pattern + ) + + return dataset + + +@dataclasses.dataclass +class SplitConfig(BaseConfig): + split_mode: str + ttratio: float + tvratio: float + date: str = 'yyyymmdd' + dateformat: str = '%Y%m%d' + drop: float = 0 + + def __post_init__(self): + # check input types + super().__post_init__() + + # check if the split mode is valid + self.split_class = item_in_enum(self.split_mode, SupportedSplits) + + # function to drop samples with a fraction of pixels equal to the constant + # padding value self.cval >= self.drop + def _drop_samples(self, ds, drop_threshold=1): + + # iterate over the scenes returned by self.compose_scenes() + dropped = [] + for pos, i in enumerate(ds.indices): + + # the current scene + s = ds.dataset.scenes[i] + + # the current tile in the ground truth + tile_gt = img2np(s['gt'], ds.dataset.tile_size, s['tile'], + ds.dataset.pad, ds.dataset.cval) - # initialize the dataset to train the model on - self._init_dataset() + # percent of pixels equal to the constant padding value + npixels = (tile_gt[tile_gt == ds.dataset.cval].size / tile_gt.size) + + # drop samples where npixels >= self.drop + if npixels >= drop_threshold: + print('Skipping scene {}, tile {}: {:.2f}% padded pixels ...' + .format(s['id'], s['tile'], npixels * 100)) + dropped.append(s) + _ = ds.indices.pop(pos) + + return dropped + + def train_val_test_split(self, ds): + + if not isinstance(ds, ImageDataset): + raise TypeError('Expected "ds" to be {}.' + .format('.'.join([ImageDataset.__module__, + ImageDataset.__name__]))) + + if self.split_mode == 'random' or self.split_mode == 'scene': + subset = self.split_class(ds, + self.ttratio, + self.tvratio, + ds.seed) + + else: + subset = self.split_class(ds, self.date, self.dateformat) + + # the training, validation and test dataset + train_ds, valid_ds, test_ds = subset.split() + + # whether to drop training samples with a fraction of pixels equal to + # the constant padding value cval >= drop + if ds.pad and self.drop > 0: + self.dropped = self._drop_samples(train_ds, self.drop) + + return train_ds, valid_ds, test_ds + + @staticmethod + def dataloaders(*args, **kwargs): + # check whether each dataset in args has the correct type + loaders = [] + for ds in args: + if not isinstance(ds, Dataset): + raise TypeError('Expected {}, got {}.' + .format(repr(Dataset), type(ds))) + + # check if the dataset is not empty + if len(ds) > 0: + # build the dataloader + loader = DataLoader(ds, **kwargs) + else: + loader = None + loaders.append(loader) - # initialize the model state files - self._init_state() + return loaders - # initialize the model - self._init_model() - def from_pretrained(self): +@dataclasses.dataclass +class ModelConfig(BaseConfig): + model_name: str + filters: list + torch_seed: int + skip_connection: bool = True + kwargs: dict = dataclasses.field( + default_factory=lambda: {'kernel_size': 3, 'stride': 1, 'dilation': 1}) + state_path: pathlib.Path = pathlib.Path(HERE).joinpath('_models/') + batch_size: int = 64 + checkpoint: bool = False + pretrained: bool = False + pretrained_model: str = '' + + def __post_init__(self): + # check input types + super().__post_init__() + + # check whether the model is currently supported + self.model_class = item_in_enum(self.model_name, SupportedModels) + + def init_state(self, ds, sc, tc): + + # file to save model state to: + # network_dataset_optim_split_splitparams_tilesize_batchsize_bands.pt + + # model state filename + state_file = '{}_{}_{}_{}Split_{}_t{}_b{}_{}.pt' + + # get the band numbers + bformat = ''.join(band[0] + + str(ds.sensor.__members__[band].value) for + band in ds.use_bands) + + # check which split mode was used + if sc.split_mode == 'date': + # store the date that was used to split the dataset + state_file = state_file.format(self.model_class.__name__, + ds.__class__.__name__, + tc.optim_name, + sc.split_mode.capitalize(), + sc.date, + ds.tile_size, + self.batch_size, + bformat) + else: + # store the random split parameters + split_params = 's{}_t{}v{}'.format( + ds.seed, str(sc.ttratio).replace('.', ''), + str(sc.tvratio).replace('.', '')) + + # model state filename + state_file = state_file.format(self.model_class.__name__, + ds.__class__.__name__, + tc.optim_name, + sc.split_mode.capitalize(), + split_params, + ds.tile_size, + self.batch_size, + bformat) + + # check whether a pretrained model was used and change state filename + # accordingly + if self.pretrained: + # add the configuration of the pretrained model to the state name + state_file = (state_file.replace('.pt', '_') + + 'pretrained_' + self.pretrained_model) + + # path to model state + state = self.state_path.joinpath(state_file) + + # path to model loss/accuracy + loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt')) + + return state, loss_state + + def init_model(self, ds): + + # case (1): build a new model + if not self.pretrained: + + # set the random seed for reproducibility + torch.manual_seed(self.torch_seed) + + # instanciate the model + model = self.model_class( + in_channels=len(ds.use_bands), + nclasses=len(ds.labels), + filters=self.filters, + skip=self.skip_connection, + **self.kwargs) + + # case (2): load a pretrained model + else: + + # load pretrained model + model = self.load_pretrained() + + return model + + def load_checkpoint(self, state_file, loss_state, model, optimizer): + + # initial accuracy on the validation set + max_accuracy = 0 + + # set the model checkpoint to None, overwritten when resuming + # training from an existing model checkpoint + checkpoint_state = {} + + # whether to resume training from an existing model + if self.checkpoint: + + # check if a model checkpoint exists + if not state_file.exists(): + raise FileNotFoundError('Model checkpoint {} does not exist.' + .format(state_file)) + + # load the model state + state = model.load(state_file.name, optimizer, self.state_path) + print('Found checkpoint: {}'.format(state)) + print('Resuming training from checkpoint ...'.format(state)) + print('Model epoch: {:d}'.format(model.epoch)) + + # load the model loss and accuracy + checkpoint_state = torch.load(loss_state) + + # get all non-zero elements, i.e. get number of epochs trained + # before the early stop + checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) + for k, v in checkpoint_state.items()} + + # maximum accuracy on the validation set + max_accuracy = checkpoint_state['va'][:, -1].mean().item() + + return checkpoint_state, max_accuracy + + def load_pretrained(self, ds): # load the pretrained model - model_state = os.path.join(self.state_path, self.pretrained_model) - if not os.path.exists(model_state): + model_state = self.state_path.joinpath(self.pretrained_model) + if not model_state.exists(): raise FileNotFoundError('Pretrained model {} does not exist.' .format(model_state)) @@ -61,23 +346,24 @@ class NetworkTrainer(object): filters = model_state['params']['filters'] # check whether the current dataset uses the correct spectral bands - if self.bands != bands: + if ds.use_bands != bands: raise ValueError('The bands of the pretrained network do not ' 'match the specified bands: {}' - .format(self.bands)) + .format(bands)) # instanciate pretrained model architecture - model = self.model(**model_state['params'], **model_state['kwargs']) + model = self.model_class(**model_state['params'], + **model_state['kwargs']) # load pretrained model weights - model.load(self.pretrained_model, inpath=self.state_path) + model.load(self.pretrained_model, inpath=str(self.state_path)) # reset model epoch to 0, since the model is trained on a different # dataset model.epoch = 0 # adjust the number of classes in the model - model.nclasses = len(self.dataset.labels) + model.nclasses = len(ds.labels) # adjust the classification layer to the number of classes of the # current dataset @@ -85,49 +371,103 @@ class NetworkTrainer(object): out_channels=model.nclasses, kernel_size=1) - return model - def from_checkpoint(self): - - # whether to resume training from an existing model - if not os.path.exists(self.state): - raise FileNotFoundError('Model checkpoint {} does not exist.' - .format(self.state)) - - # load the model state - state = self.model.load(self.state_file, self.optimizer, - self.state_path) - print('Resuming training from {} ...'.format(state)) - print('Model epoch: {:d}'.format(self.model.epoch)) - - # load the model loss and accuracy - checkpoint_state = torch.load(self.loss_state) - # get all non-zero elements, i.e. get number of epochs trained - # before the early stop - checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1) for - k, v in checkpoint_state.items()} +@dataclasses.dataclass +class TrainConfig(BaseConfig): + optim_name: str + loss_name: str + lr: float = 0.001 + early_stop: bool = False + mode: str = 'max' + delta: float = 0 + patience: int = 10 + epochs: int = 50 + nthreads: int = torch.get_num_threads() + save: bool = True + + def __post_init__(self): + super().__post_init__() + + # check whether the optimizer is currently supported + self.optim_class = item_in_enum(self.optim_name, SupportedOptimizers) + + # check whether the loss function is currently supported + self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions) + + def init_optimizer(self, model): + + # initialize the optimizer for the specified model + optimizer = self.optim_class(model.parameters(), self.lr) + + return optimizer + + def init_loss_function(self): + + loss_function = self.loss_class() + + return loss_function + + +@dataclasses.dataclass +class EvalConfig(BaseConfig): + test: object + predict_scene: bool = False + plot_samples: bool = False + plot_scenes: bool = False + plot_bands: list = dataclasses.field( + default_factory=lambda: ['nir', 'red', 'green']) + cm: bool = True + + def __post_init__(self): + super().__post_init__() + + # check whether the test input parameter is correctly specified + if self.test not in [None, False, True]: + raise TypeError('Expected "test" to be None, True or False, got ' + '{}.'.format(self.test)) + +@dataclasses.dataclass +class NetworkTrainer(BaseConfig): + model: Network + optimizer: Optimizer + loss_function: nn.Module + train_dl: DataLoader + valid_dl: DataLoader + state_file: pathlib.Path + loss_state: pathlib.Path + epochs: int = 1 + nthreads: int = torch.get_num_threads() + early_stop: bool = False + mode: str = 'max' + delta: float = 0 + patience: int = 10 + max_accuracy: float = 0 + checkpoint_state: dict = dataclasses.field(default_factory=dict) + save: bool = True + + def __post_init__(self): + super().__post_init__() - # maximum accuracy on the validation set - max_accuracy = checkpoint_state['va'][:, -1].mean().item() + # whether to use the gpu + self.device = torch.device("cuda:0" if torch.cuda.is_available() + else "cpu") - return checkpoint_state, max_accuracy + # whether to use early stopping + self.es = None + if self.early_stop: + self.es = EarlyStopping(self.mode, self.delta, self.patience) def train(self): print('------------------------- Training ---------------------------') # set the number of threads + print('Device: {}'.format(self.device)) + print('Number of cpu threads: {}'.format(self.nthreads)) torch.set_num_threads(self.nthreads) - # instanciate early stopping class - if self.early_stop: - es = EarlyStopping(self.mode, self.delta, self.patience) - print('Initializing early stopping ...') - print('mode = {}, delta = {}, patience = {} epochs ...' - .format(self.mode, self.delta, self.patience)) - # create dictionary of the observed losses and accuracies on the # training and validation dataset tshape = (len(self.train_dl), self.epochs) @@ -207,36 +547,22 @@ class NetworkTrainer(object): epoch_acc = vacc.squeeze().mean() # whether the model improved with respect to the previous epoch - if es.increased(epoch_acc, self.max_accuracy, self.delta): + if self.es.increased(epoch_acc, self.max_accuracy, self.delta): self.max_accuracy = epoch_acc + # save model state if the model improved with # respect to the previous epoch - _ = self.model.save(self.state_file, - self.optimizer, - self.bands, - self.state_path) - - # save losses and accuracy - self._save_loss(training_state, - self.checkpoint, - self.checkpoint_state) + self.save_state(training_state) # whether the early stopping criterion is met - if es.stop(epoch_acc): + if self.es.stop(epoch_acc): break else: - # if no early stopping is required, the model state is saved - # after each epoch - _ = self.model.save(self.state_file, - self.optimizer, - self.bands, - self.state_path) + # if no early stopping is required, the model state is + # saved after each epoch + self.save_state(training_state) - # save losses and accuracy after each epoch - self._save_loss(training_state, - self.checkpoint, - self.checkpoint_state) return training_state @@ -283,217 +609,37 @@ class NetworkTrainer(object): .format(batch + 1, len(self.valid_dl), acc)) # calculate overall accuracy on the validation/test set - print('After training for {:d} epochs, we achieved an overall ' - 'accuracy of {:.2f}% on the validation set!' + print('Epoch {:d}, Overall accuracy: {:.2f}%.' .format(self.model.epoch, accuracies.mean() * 100)) return accuracies, losses - def _init_state(self): - - # file to save model state to - # format: network_dataset_seed_tilesize_batchsize_bands.pt - - # get the band numbers - bformat = ''.join(band[0] + - str(self.dataset.sensor.__members__[band].value) for - band in self.bands) - - # model state filename - self.state_file = ('{}_{}_s{}_t{}_b{}_{}.pt' - .format(self.model.__name__, - self.dataset.__class__.__name__, - self.seed, - self.tile_size, - self.batch_size, - bformat)) - - # check whether a pretrained model was used and change state filename - # accordingly - if self.pretrained: - # add the configuration of the pretrained model to the state name - self.state_file = (self.state_file.replace('.pt', '_') + - 'pretrained_' + self.pretrained_model) - - # path to model state - self.state = os.path.join(self.state_path, self.state_file) - - # path to model loss/accuracy - self.loss_state = self.state.replace('.pt', '_loss.pt') - - def _init_dataset(self): - - # the dataset name - self.dataset_name = os.path.basename(self.root_dir) - - # check whether the dataset is currently supported - if self.dataset_name not in SupportedDatasets.__members__: - raise ValueError('{} is not a valid dataset. ' - .format(self.dataset_name) + - 'Available datasets are: \n' + - '\n'.join(name for name, _ in - SupportedDatasets.__members__.items())) - else: - self.dataset_class = SupportedDatasets.__members__[ - self.dataset_name].value - - # instanciate the dataset - self.dataset = self.dataset_class( - self.root_dir, - use_bands=self.bands, - tile_size=self.tile_size, - sort=self.sort, - transforms=self.transforms, - pad=self.pad, - cval=self.cval, - gt_pattern=self.gt_pattern - ) - - # the mode to split - if self.split_mode not in VALID_SPLIT_MODES: - raise ValueError('{} is not supported. Valid modes are {}, see ' - 'pysegcnn.main.config.py for a description of ' - 'each mode.'.format(self.split_mode, - VALID_SPLIT_MODES)) - if self.split_mode == 'random': - self.subset = RandomTileSplit(self.dataset, - self.ttratio, - self.tvratio, - self.seed) - if self.split_mode == 'scene': - self.subset = RandomSceneSplit(self.dataset, - self.ttratio, - self.tvratio, - self.seed) - if self.split_mode == 'date': - self.subset = DateSplit(self.dataset, - self.date, - self.dateformat) - - # the training, validation and test dataset - self.train_ds, self.valid_ds, self.test_ds = self.subset.split() - - # whether to drop training samples with a fraction of pixels equal to - # the constant padding value self.cval >= self.drop - if self.pad and self.drop: - self._drop(self.train_ds) - - # the shape of a single batch - self.batch_shape = (len(self.bands), self.tile_size, self.tile_size) - - # the training dataloader - self.train_dl = None - if len(self.train_ds) > 0: - self.train_dl = DataLoader(self.train_ds, - self.batch_size, - shuffle=True, - drop_last=False) - # the validation dataloader - self.valid_dl = None - if len(self.valid_ds) > 0: - self.valid_dl = DataLoader(self.valid_ds, - self.batch_size, - shuffle=True, - drop_last=False) - - # the test dataloader - self.test_dl = None - if len(self.test_ds) > 0: - self.test_dl = DataLoader(self.test_ds, - self.batch_size, - shuffle=True, - drop_last=False) - - def _init_model(self): - - # initial accuracy on the validation set - self.max_accuracy = 0 - - # set the model checkpoint to None, overwritten when resuming - # training from an existing model checkpoint - self.checkpoint_state = None - - # case (1): build a model for the specified dataset - if not self.pretrained and not self.checkpoint: - - # instanciate the model - self.model = self.model(in_channels=len(self.dataset.use_bands), - nclasses=len(self.dataset.labels), - filters=self.filters, - skip=self.skip_connection, - **self.kwargs) - - # the optimizer used to update the model weights - self.optimizer = self.optimizer(self.model.parameters(), self.lr) - - # case (2): using a pretrained model withouth existing checkpoint on - # a new dataset, i.e. transfer learning - if self.pretrained and not self.checkpoint: - # load pretrained model - self.model = self.from_pretrained() - - # the optimizer used to update the model weights - self.optimizer = self.optimizer(self.model.parameters(), self.lr) - - # case (3): using a pretrained model with existing checkpoint on the - # same dataset the pretrained model was trained on - elif self.checkpoint: - - # instanciate the model - self.model = self.model(in_channels=len(self.dataset.use_bands), - nclasses=len(self.dataset.labels), - filters=self.filters, - skip=self.skip_connection, - **self.kwargs) - - # the optimizer used to update the model weights - self.optimizer = self.optimizer(self.model.parameters(), self.lr) - - # whether to resume training from an existing model checkpoint - if self.checkpoint: - (self.checkpoint_state, - self.max_accuracy) = self.from_checkpoint() - - # function to drop samples with a fraction of pixels equal to the constant - # padding value self.cval >= self.drop - def _drop(self, ds): + def save_state(self, training_state): - # iterate over the scenes returned by self.compose_scenes() - self.dropped = [] - for pos, i in enumerate(ds.indices): + # whether to save the model state + if self.save: + # save model state + state = self.model.save(self.state_file.name, + self.optimizer, + self.train_dl.dataset.dataset.use_bands, + self.state_file.parent) - # the current scene - s = ds.dataset.scenes[i] + # save losses and accuracy + self._save_loss(training_state) - # the current tile in the ground truth - tile_gt = img2np(s['gt'], self.tile_size, s['tile'], - self.pad, self.cval) + def _save_loss(self, training_state): - # percent of pixels equal to the constant padding value - npixels = (tile_gt[tile_gt == self.cval].size / tile_gt.size) - - # drop samples where npixels >= self.drop - if npixels >= self.drop: - print('Skipping scene {}, tile {}: {:.2f}% padded pixels ...' - .format(s['id'], s['tile'], npixels * 100)) - self.dropped.append(s) - _ = ds.indices.pop(pos) + # save losses and accuracy + state = training_state + if self.checkpoint_state: - def _save_loss(self, training_state, checkpoint=False, - checkpoint_state=None): + # append values from checkpoint to current training state + state = {k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in + zip(self.checkpoint_state.items(), training_state.items()) + if k1 == k2} - # save losses and accuracy - if checkpoint and checkpoint_state is not None: - - # append values from checkpoint to current training - # state - torch.save({ - k1: np.hstack([v1, v2]) for (k1, v1), (k2, v2) in - zip(checkpoint_state.items(), training_state.items()) - if k1 == k2}, - self.loss_state) - else: - torch.save(training_state, self.loss_state) + # save the model loss and accuracies to file + torch.save(state, self.loss_state) def __repr__(self): @@ -502,26 +648,37 @@ class NetworkTrainer(object): # dataset fs += ' (dataset):\n ' - fs += ''.join(self.dataset.__repr__()).replace('\n', '\n ') + fs += ''.join( + repr(self.train_dl.dataset.dataset)).replace('\n','\n ') # batch size fs += '\n (batch):\n ' - fs += '- batch size: {}\n '.format(self.batch_size) - fs += '- batch shape (b, h, w): {}'.format(self.batch_shape) + fs += '- batch size: {}\n '.format(self.train_dl.batch_size) + fs += '- mini-batch shape (b, c, h, w): {}'.format( + (self.train_dl.batch_size, + len(self.train_dl.dataset.dataset.use_bands), + self.train_dl.dataset.dataset.tile_size, + self.train_dl.dataset.dataset.tile_size) + ) # dataset split - fs += '\n (split):\n ' - fs += ''.join(self.subset.__repr__()).replace('\n', '\n ') + fs += '\n (split):' + fs += '\n ' + repr(self.train_dl.dataset) + fs += '\n ' + repr(self.valid_dl.dataset) # model fs += '\n (model):\n ' - fs += ''.join(self.model.__repr__()).replace('\n', '\n ') + fs += ''.join(repr(self.model)).replace('\n', '\n ') # optimizer fs += '\n (optimizer):\n ' - fs += ''.join(self.optimizer.__repr__()).replace('\n', '\n ') - fs += '\n)' + fs += ''.join(repr(self.optimizer)).replace('\n', '\n ') + + # early stopping + fs += '\n (early stop):\n ' + fs += ''.join(repr(self.es)).replace('\n', '\n ') + fs += '\n)' return fs @@ -555,6 +712,9 @@ class EarlyStopping(object): # initialize early stopping flag self.early_stop = False + # initialize the early stop counter + self.counter = 0 + def stop(self, metric): if self.best is not None: @@ -584,3 +744,8 @@ class EarlyStopping(object): def increased(self, metric, best, min_delta): return metric > best + min_delta + + def __repr__(self): + fs = (self.__class__.__name__ + '(mode={}, delta={}, patience={})' + .format(self.mode, self.min_delta, self.patience)) + return fs