diff --git a/pysegcnn/core/split.py b/pysegcnn/core/split.py index 412e7cd24a6362876cfcb7887528d8a2051055e6..78f653a3b41c2d0f552c2485e026d8f604487060 100644 --- a/pysegcnn/core/split.py +++ b/pysegcnn/core/split.py @@ -11,6 +11,9 @@ import datetime import numpy as np from torch.utils.data.dataset import Subset +# the names of the subsets +SUBSET_NAMES = ['train', 'valid', 'test'] + # function calculating number of samples in a dataset given a ratio def _ds_len(ds, ratio): @@ -37,29 +40,26 @@ def random_tile_split(ds, tvratio, ttratio=1, seed=0): # length of the training dataset # number of samples: (ttratio * tvratio * len(ds)) train_len = _ds_len(trav_indices, tvratio) - train_indices = trav_indices[:train_len] + train_ind = trav_indices[:train_len] # length of the validation dataset # number of samples: (ttratio * (1- tvratio) * len(ds)) - valid_indices = trav_indices[train_len:] + valid_ind = trav_indices[train_len:] # length of the test dataset # number of samples: ((1 - ttratio) * len(ds)) - test_indices = indices[trav_len:] + test_ind = indices[trav_len:] # get the tiles of the scenes of each dataset - subsets = [] - for dataset in [train_indices, valid_indices, test_indices]: - - # build the subset: store the scenes - sbst = Subset(dataset=ds, indices=list(dataset)) - sbst.scenes = [ds.scenes[i] for i in dataset] + subsets = {} + for name, dataset in enumerate([train_ind, valid_ind, test_ind]): - # add to list of subsets - subsets.append(sbst) + # store the indices and corresponding tiles of the current subset to + # dictionary + subsets[SUBSET_NAMES[name]] = {k: ds.scenes[k] for k in dataset} # check if the splits are disjoint - assert pairwise_disjoint([s.indices for s in subsets]) + assert pairwise_disjoint([s.keys() for s in subsets.values()]) return subsets @@ -95,28 +95,16 @@ def random_scene_split(ds, tvratio, ttratio=1, seed=0): test_scenes = scene_ids[trav_len:] # get the tiles of the scenes of each dataset - subsets = [] - for dataset in 
[train_scenes, valid_scenes, test_scenes]: - # the indices of the scenes in the dataset - indices = [] - tiles = [] - - # iterate over the scenes of the whole dataset - for i, scene in enumerate(ds.scenes): - if scene['id'] in dataset: - indices.append(i) - tiles.append(scene) - - # build the subset: store scene ids - sbst = Subset(dataset=ds, indices=indices) - sbst.scenes = tiles - sbst.ids = dataset - - # add to list of subsets - subsets.append(sbst) + subsets = {} + for name, dataset in enumerate([train_scenes, valid_scenes, test_scenes]): + + # store the indices and corresponding tiles of the current subset to + # dictionary + subsets[SUBSET_NAMES[name]] = {k: v for k, v in enumerate(ds.scenes) + if v['id'] in dataset} # check if the splits are disjoint - assert pairwise_disjoint([s.indices for s in subsets]) + assert pairwise_disjoint([s.keys() for s in subsets.values()]) return subsets @@ -135,18 +123,16 @@ def date_scene_split(ds, date, dateformat='%Y%m%d'): test_scenes = {} # build the training and test datasets - subsets = [] - for scenes in [train_scenes, valid_scenes, test_scenes]: - # build the subset: store the scenes - sbst = Subset(dataset=ds, indices=list(scenes.keys())) - sbst.scenes = list(scenes.values()) - sbst.ids = np.unique([s['id'] for s in scenes.values()]) + subsets = {} + for name, scenes in enumerate([train_scenes, valid_scenes, test_scenes]): - # add to list of subsets - subsets.append(sbst) + # store the indices and corresponding tiles of the current subset to + # dictionary + subsets[SUBSET_NAMES[name]] = scenes + # sbst.ids = np.unique([s['id'] for s in scenes.values()]) # check if the splits are disjoint - assert pairwise_disjoint([s.indices for s in subsets]) + assert pairwise_disjoint([s.keys() for s in subsets.values()]) return subsets @@ -155,3 +141,103 @@ def pairwise_disjoint(sets): union = set().union(*sets) n = sum(len(u) for u in sets) return n == len(union) + + +class Split(object): + + # the valid modes + valid_modes = 
['random', 'scene', 'date'] + + def __init__(self, ds, mode, **kwargs): + + # check which mode is provided + if mode not in self.valid_modes: + raise ValueError('{} is not supported. Valid modes are {}, see ' + 'pysegcnn.main.config.py for a description of ' + 'each mode.'.format(mode, self.valid_modes)) + self.mode = mode + + # the dataset to split + self.ds = ds + + # the keyword arguments + self.kwargs = kwargs + + # initialize split + self._init_split() + + def _init_split(self): + + if self.mode == 'random': + self.subset = RandomSubset + self.split_function = random_tile_split + self.allowed_kwargs = ['tvratio', 'ttratio', 'seed'] + + if self.mode == 'scene': + self.subset = SceneSubset + self.split_function = random_scene_split + self.allowed_kwargs = ['tvratio', 'ttratio', 'seed'] + + if self.mode == 'date': + self.subset = SceneSubset + self.split_function = date_scene_split + self.allowed_kwargs = ['date', 'dateformat'] + + self._check_kwargs() + + def _check_kwargs(self): + + # check if the correct keyword arguments are provided + if not set(self.allowed_kwargs).issubset(self.kwargs.keys()): + raise TypeError('__init__() expecting keyword arguments: {}.' 
+ .format(', '.join(kwa for kwa in + self.allowed_kwargs))) + # select the correct keyword arguments + self.kwargs = {k: self.kwargs[k] for k in self.allowed_kwargs} + + # function apply the split + def split(self): + + # create the subsets + subsets = self.split_function(self.ds, **self.kwargs) + + # build the subsets + ds_split = [] + for name, sub in subsets.items(): + + # the scene identifiers of the current subset + ids = np.unique([s['id'] for s in sub.values()]) + + # build the subset + subset = self.subset(self.ds, list(sub.keys()), name, + list(sub.values()), ids) + ds_split.append(subset) + + return ds_split + + +class SceneSubset(Subset): + + def __init__(self, ds, indices, name, scenes, scene_ids): + super().__init__(dataset=ds, indices=indices) + + # the name of the subset + self.name = name + + # the scene in the subset + self.scenes = scenes + + # the names of the scenes + self.ids = scene_ids + + +class RandomSubset(Subset): + + def __init__(self, ds, indices, name, scenes): + super().__init__(dataset=ds, indices=indices) + + # the name of the subset + self.name = name + + # the scene in the subset + self.scenes = scenes diff --git a/pysegcnn/core/trainer.py b/pysegcnn/core/trainer.py index eba222cbe9541f2381bc56651a0528e3f032d847..b2535ff67626c601a3d133ced7c7700825d9f011 100644 --- a/pysegcnn/core/trainer.py +++ b/pysegcnn/core/trainer.py @@ -17,9 +17,8 @@ from torch.utils.data import DataLoader # locals from pysegcnn.core.dataset import SupportedDatasets from pysegcnn.core.layers import Conv2dSame -from pysegcnn.core.utils import img2np -from pysegcnn.core.split import (random_tile_split, random_scene_split, - date_scene_split) +from pysegcnn.core.utils import img2np, accuracy_function +from pysegcnn.core.split import Split class NetworkTrainer(object): @@ -244,7 +243,7 @@ class NetworkTrainer(object): return training_state - def predict(self, pretrained=False, confusion=False): + def predict(self): print('------------------------ Predicting 
--------------------------') @@ -341,18 +340,16 @@ class NetworkTrainer(object): '\n'.join(name for name, _ in SupportedDatasets.__members__.items())) - # the training, validation and dataset - if self.split_mode == 'random': - self.train_ds, self.valid_ds, self.test_ds = random_tile_split( - self.dataset, self.tvratio, self.ttratio, self.seed) - - if self.split_mode == 'scene': - self.train_ds, self.valid_ds, self.test_ds = random_scene_split( - self.dataset, self.tvratio, self.ttratio, self.seed) + # instantiate the Split class handling the dataset split + self.subset = Split(self.dataset, self.split_mode, + tvratio=self.tvratio, + ttratio=self.ttratio, + seed=self.seed, + date=self.date, + dateformat=self.dateformat) - if self.split_mode == 'date': - self.train_ds, self.valid_ds, self.test_ds = date_scene_split( - self.dataset, self.date) + # the training, validation and test datasets + self.train_ds, self.valid_ds, self.test_ds = self.subset.split() # whether to drop training samples with a fraction of pixels equal to # the constant padding value self.cval >= self.drop @@ -409,7 +406,7 @@ class NetworkTrainer(object): for pos, i in enumerate(ds.indices): # the current scene - s = self.dataset.scenes[i] + s = ds.dataset.scenes[i] # the current tile in the ground truth tile_gt = img2np(s['gt'], self.tile_size, s['tile'], @@ -537,8 +534,3 @@ class EarlyStopping(object): def increased(self, metric, best, min_delta): return metric > best + min_delta - - -# function calculating prediction accuracy -def accuracy_function(outputs, labels): - return (np.asarray(outputs) == np.asarray(labels)).mean()