Skip to content
Snippets Groups Projects
Commit 5a99371d authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Created distinct class to split dataset into training, validation and test set

parent 6faa33d1
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,9 @@ import datetime
import numpy as np
from torch.utils.data.dataset import Subset
# the names of the subsets
SUBSET_NAMES = ['train', 'valid', 'test']
# function calculating number of samples in a dataset given a ratio
def _ds_len(ds, ratio):
......@@ -37,29 +40,26 @@ def random_tile_split(ds, tvratio, ttratio=1, seed=0):
# length of the training dataset
# number of samples: (ttratio * tvratio * len(ds))
train_len = _ds_len(trav_indices, tvratio)
train_indices = trav_indices[:train_len]
train_ind = trav_indices[:train_len]
# length of the validation dataset
# number of samples: (ttratio * (1- tvratio) * len(ds))
valid_indices = trav_indices[train_len:]
valid_ind = trav_indices[train_len:]
# length of the test dataset
# number of samples: ((1 - ttratio) * len(ds))
test_indices = indices[trav_len:]
test_ind = indices[trav_len:]
# get the tiles of the scenes of each dataset
subsets = []
for dataset in [train_indices, valid_indices, test_indices]:
# build the subset: store the scenes
sbst = Subset(dataset=ds, indices=list(dataset))
sbst.scenes = [ds.scenes[i] for i in dataset]
subsets = {}
for name, dataset in enumerate([train_ind, valid_ind, test_ind]):
# add to list of subsets
subsets.append(sbst)
# store the indices and corresponding tiles of the current subset to
# dictionary
subsets[SUBSET_NAMES[name]] = {k: ds.scenes[k] for k in dataset}
# check if the splits are disjoint
assert pairwise_disjoint([s.indices for s in subsets])
assert pairwise_disjoint([s.keys() for s in subsets.values()])
return subsets
......@@ -95,28 +95,16 @@ def random_scene_split(ds, tvratio, ttratio=1, seed=0):
test_scenes = scene_ids[trav_len:]
# get the tiles of the scenes of each dataset
subsets = []
for dataset in [train_scenes, valid_scenes, test_scenes]:
# the indices of the scenes in the dataset
indices = []
tiles = []
# iterate over the scenes of the whole dataset
for i, scene in enumerate(ds.scenes):
if scene['id'] in dataset:
indices.append(i)
tiles.append(scene)
# build the subset: store scene ids
sbst = Subset(dataset=ds, indices=indices)
sbst.scenes = tiles
sbst.ids = dataset
# add to list of subsets
subsets.append(sbst)
subsets = {}
for name, dataset in enumerate([train_scenes, valid_scenes, test_scenes]):
# store the indices and corresponding tiles of the current subset to
# dictionary
subsets[SUBSET_NAMES[name]] = {k: v for k, v in enumerate(ds.scenes)
if v['id'] in dataset}
# check if the splits are disjoint
assert pairwise_disjoint([s.indices for s in subsets])
assert pairwise_disjoint([s.keys() for s in subsets.values()])
return subsets
......@@ -135,18 +123,16 @@ def date_scene_split(ds, date, dateformat='%Y%m%d'):
test_scenes = {}
# build the training and test datasets
subsets = []
for scenes in [train_scenes, valid_scenes, test_scenes]:
# build the subset: store the scenes
sbst = Subset(dataset=ds, indices=list(scenes.keys()))
sbst.scenes = list(scenes.values())
sbst.ids = np.unique([s['id'] for s in scenes.values()])
subsets = {}
for name, scenes in enumerate([train_scenes, valid_scenes, test_scenes]):
# add to list of subsets
subsets.append(sbst)
# store the indices and corresponding tiles of the current subset to
# dictionary
subsets[SUBSET_NAMES[name]] = scenes
# sbst.ids = np.unique([s['id'] for s in scenes.values()])
# check if the splits are disjoint
assert pairwise_disjoint([s.indices for s in subsets])
assert pairwise_disjoint([s.keys() for s in subsets.values()])
return subsets
......@@ -155,3 +141,103 @@ def pairwise_disjoint(sets):
union = set().union(*sets)
n = sum(len(u) for u in sets)
return n == len(union)
class Split(object):
# the valid modes
valid_modes = ['random', 'scene', 'date']
def __init__(self, ds, mode, **kwargs):
# check which mode is provided
if mode not in self.valid_modes:
raise ValueError('{} is not supported. Valid modes are {}, see '
'pysegcnn.main.config.py for a description of '
'each mode.'.format(mode, self.valid_modes))
self.mode = mode
# the dataset to split
self.ds = ds
# the keyword arguments
self.kwargs = kwargs
# initialize split
self._init_split()
def _init_split(self):
if self.mode == 'random':
self.subset = RandomSubset
self.split_function = random_tile_split
self.allowed_kwargs = ['tvratio', 'ttratio', 'seed']
if self.mode == 'scene':
self.subset = SceneSubset
self.split_function = random_scene_split
self.allowed_kwargs = ['tvratio', 'ttratio', 'seed']
if self.mode == 'date':
self.subset = SceneSubset
self.split_function = date_scene_split
self.allowed_kwargs = ['date', 'dateformat']
self._check_kwargs()
def _check_kwargs(self):
# check if the correct keyword arguments are provided
if not set(self.allowed_kwargs).issubset(self.kwargs.keys()):
raise TypeError('__init__() expecting keyword arguments: {}.'
.format(', '.join(kwa for kwa in
self.allowed_kwargs)))
# select the correct keyword arguments
self.kwargs = {k: self.kwargs[k] for k in self.allowed_kwargs}
# function apply the split
def split(self):
# create the subsets
subsets = self.split_function(self.ds, **self.kwargs)
# build the subsets
ds_split = []
for name, sub in subsets.items():
# the scene identifiers of the current subset
ids = np.unique([s['id'] for s in sub.values()])
# build the subset
subset = self.subset(self.ds, list(sub.keys()), name,
list(sub.values()), ids)
ds_split.append(subset)
return ds_split
class SceneSubset(Subset):
def __init__(self, ds, indices, name, scenes, scene_ids):
super().__init__(dataset=ds, indices=indices)
# the name of the subset
self.name = name
# the scene in the subset
self.scenes = scenes
# the names of the scenes
self.ids = scene_ids
class RandomSubset(Subset):
def __init__(self, ds, indices, name, scenes):
super().__init__(dataset=ds, indices=indices)
# the name of the subset
self.name = name
# the scene in the subset
self.scenes = scenes
......@@ -17,9 +17,8 @@ from torch.utils.data import DataLoader
# locals
from pysegcnn.core.dataset import SupportedDatasets
from pysegcnn.core.layers import Conv2dSame
from pysegcnn.core.utils import img2np
from pysegcnn.core.split import (random_tile_split, random_scene_split,
date_scene_split)
from pysegcnn.core.utils import img2np, accuracy_function
from pysegcnn.core.split import Split
class NetworkTrainer(object):
......@@ -244,7 +243,7 @@ class NetworkTrainer(object):
return training_state
def predict(self, pretrained=False, confusion=False):
def predict(self):
print('------------------------ Predicting --------------------------')
......@@ -341,18 +340,16 @@ class NetworkTrainer(object):
'\n'.join(name for name, _ in
SupportedDatasets.__members__.items()))
# the training, validation and dataset
if self.split_mode == 'random':
self.train_ds, self.valid_ds, self.test_ds = random_tile_split(
self.dataset, self.tvratio, self.ttratio, self.seed)
if self.split_mode == 'scene':
self.train_ds, self.valid_ds, self.test_ds = random_scene_split(
self.dataset, self.tvratio, self.ttratio, self.seed)
# instanciate the Split class handling the dataset split
self.subset = Split(self.dataset, self.split_mode,
tvratio=self.tvratio,
ttratio=self.ttratio,
seed=self.seed,
date=self.date,
dateformat=self.dateformat)
if self.split_mode == 'date':
self.train_ds, self.valid_ds, self.test_ds = date_scene_split(
self.dataset, self.date)
# the training, validation and dataset
self.train_ds, self.valid_ds, self.test_ds = self.subset.split()
# whether to drop training samples with a fraction of pixels equal to
# the constant padding value self.cval >= self.drop
......@@ -409,7 +406,7 @@ class NetworkTrainer(object):
for pos, i in enumerate(ds.indices):
# the current scene
s = self.dataset.scenes[i]
s = ds.dataset.scenes[i]
# the current tile in the ground truth
tile_gt = img2np(s['gt'], self.tile_size, s['tile'],
......@@ -537,8 +534,3 @@ class EarlyStopping(object):
def increased(self, metric, best, min_delta):
return metric > best + min_delta
# function calculating prediction accuracy
def accuracy_function(outputs, labels):
return (np.asarray(outputs) == np.asarray(labels)).mean()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment