Skip to content
Snippets Groups Projects
Commit f86b5825 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Implemented representations

parent 95912567
No related branches found
No related tags found
No related merge requests found
......@@ -18,7 +18,9 @@ from torch.utils.data import DataLoader
from pysegcnn.core.dataset import SupportedDatasets
from pysegcnn.core.layers import Conv2dSame
from pysegcnn.core.utils import img2np, accuracy_function
from pysegcnn.core.split import Split
from pysegcnn.core.split import (RandomTileSplit, RandomSceneSplit, DateSplit,
VALID_SPLIT_MODES)
class NetworkTrainer(object):
......@@ -340,13 +342,27 @@ class NetworkTrainer(object):
'\n'.join(name for name, _ in
SupportedDatasets.__members__.items()))
# instanciate the Split class handling the dataset split
self.subset = Split(self.dataset, self.split_mode,
tvratio=self.tvratio,
ttratio=self.ttratio,
seed=self.seed,
date=self.date,
dateformat=self.dateformat)
# the mode to split
if self.split_mode not in VALID_SPLIT_MODES:
raise ValueError('{} is not supported. Valid modes are {}, see '
'pysegcnn.main.config.py for a description of '
'each mode.'.format(self.split_mode,
VALID_SPLIT_MODES))
if self.split_mode == 'random':
self.subset = RandomTileSplit(self.dataset,
self.ttratio,
self.tvratio,
self.seed)
if self.split_mode == 'scene':
self.subset = RandomSceneSplit(self.dataset,
self.ttratio,
self.tvratio,
self.seed)
if self.split_mode == 'date':
self.subset = DateSplit(self.dataset,
self.date,
self.dateformat)
# the training, validation and dataset
self.train_ds, self.valid_ds, self.test_ds = self.subset.split()
......@@ -442,16 +458,10 @@ class NetworkTrainer(object):
# representation string to print
fs = self.__class__.__name__ + '(\n'
fs += ' (bands):\n '
# bands used for the segmentation
fs += '\n '.join('- Band {}: {}'.format(i, b) for i, b in
enumerate(self.dataset.use_bands))
# classes of interest
fs += '\n (classes):\n '
fs += '\n '.join('- Class {}: {}'.format(k, v['label']) for
k, v in self.dataset.labels.items())
# dataset
fs += ' (dataset):\n '
fs += ''.join(self.dataset.__repr__()).replace('\n', '\n ')
# batch size
fs += '\n (batch):\n '
......@@ -459,19 +469,18 @@ class NetworkTrainer(object):
fs += '- batch shape (b, h, w): {}'.format(self.batch_shape)
# dataset split
fs += '\n (dataset):\n '
fs += '\n '.join(
'- {}: {:d} batches ({:.2f}%)'
.format(k, v[0], v[1] * 100) for k, v in
{'Training': (len(self.train_ds), self.ttratio * self.tvratio),
'Validation': (len(self.valid_ds),
self.ttratio * (1 - self.tvratio)),
'Test': (len(self.test_ds), 1 - self.ttratio)}.items())
fs += '\n (split):\n '
fs += ''.join(self.subset.__repr__()).replace('\n', '\n ')
# model
fs += '\n (model):\n '
fs += ''.join(self.model.__repr__()).replace('\n', '\n ')
# optimizer
fs += '\n (optimizer):\n '
fs += ''.join(self.optimizer.__repr__()).replace('\n', '\n ')
fs += '\n)'
return fs
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment