Skip to content
Snippets Groups Projects
Commit f86b5825 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Implemented representations

parent 95912567
No related branches found
No related tags found
No related merge requests found
...@@ -18,7 +18,9 @@ from torch.utils.data import DataLoader ...@@ -18,7 +18,9 @@ from torch.utils.data import DataLoader
from pysegcnn.core.dataset import SupportedDatasets from pysegcnn.core.dataset import SupportedDatasets
from pysegcnn.core.layers import Conv2dSame from pysegcnn.core.layers import Conv2dSame
from pysegcnn.core.utils import img2np, accuracy_function from pysegcnn.core.utils import img2np, accuracy_function
from pysegcnn.core.split import Split from pysegcnn.core.split import (RandomTileSplit, RandomSceneSplit, DateSplit,
VALID_SPLIT_MODES)
class NetworkTrainer(object): class NetworkTrainer(object):
...@@ -340,13 +342,27 @@ class NetworkTrainer(object): ...@@ -340,13 +342,27 @@ class NetworkTrainer(object):
'\n'.join(name for name, _ in '\n'.join(name for name, _ in
SupportedDatasets.__members__.items())) SupportedDatasets.__members__.items()))
# instanciate the Split class handling the dataset split
self.subset = Split(self.dataset, self.split_mode, # the mode to split
tvratio=self.tvratio, if self.split_mode not in VALID_SPLIT_MODES:
ttratio=self.ttratio, raise ValueError('{} is not supported. Valid modes are {}, see '
seed=self.seed, 'pysegcnn.main.config.py for a description of '
date=self.date, 'each mode.'.format(self.split_mode,
dateformat=self.dateformat) VALID_SPLIT_MODES))
if self.split_mode == 'random':
self.subset = RandomTileSplit(self.dataset,
self.ttratio,
self.tvratio,
self.seed)
if self.split_mode == 'scene':
self.subset = RandomSceneSplit(self.dataset,
self.ttratio,
self.tvratio,
self.seed)
if self.split_mode == 'date':
self.subset = DateSplit(self.dataset,
self.date,
self.dateformat)
# the training, validation and dataset # the training, validation and dataset
self.train_ds, self.valid_ds, self.test_ds = self.subset.split() self.train_ds, self.valid_ds, self.test_ds = self.subset.split()
...@@ -442,16 +458,10 @@ class NetworkTrainer(object): ...@@ -442,16 +458,10 @@ class NetworkTrainer(object):
# representation string to print # representation string to print
fs = self.__class__.__name__ + '(\n' fs = self.__class__.__name__ + '(\n'
fs += ' (bands):\n '
# bands used for the segmentation
fs += '\n '.join('- Band {}: {}'.format(i, b) for i, b in
enumerate(self.dataset.use_bands))
# classes of interest # dataset
fs += '\n (classes):\n ' fs += ' (dataset):\n '
fs += '\n '.join('- Class {}: {}'.format(k, v['label']) for fs += ''.join(self.dataset.__repr__()).replace('\n', '\n ')
k, v in self.dataset.labels.items())
# batch size # batch size
fs += '\n (batch):\n ' fs += '\n (batch):\n '
...@@ -459,19 +469,18 @@ class NetworkTrainer(object): ...@@ -459,19 +469,18 @@ class NetworkTrainer(object):
fs += '- batch shape (b, h, w): {}'.format(self.batch_shape) fs += '- batch shape (b, h, w): {}'.format(self.batch_shape)
# dataset split # dataset split
fs += '\n (dataset):\n ' fs += '\n (split):\n '
fs += '\n '.join( fs += ''.join(self.subset.__repr__()).replace('\n', '\n ')
'- {}: {:d} batches ({:.2f}%)'
.format(k, v[0], v[1] * 100) for k, v in
{'Training': (len(self.train_ds), self.ttratio * self.tvratio),
'Validation': (len(self.valid_ds),
self.ttratio * (1 - self.tvratio)),
'Test': (len(self.test_ds), 1 - self.ttratio)}.items())
# model # model
fs += '\n (model):\n ' fs += '\n (model):\n '
fs += ''.join(self.model.__repr__()).replace('\n', '\n ') fs += ''.join(self.model.__repr__()).replace('\n', '\n ')
# optimizer
fs += '\n (optimizer):\n '
fs += ''.join(self.optimizer.__repr__()).replace('\n', '\n ')
fs += '\n)' fs += '\n)'
return fs return fs
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment