Commit a2814daf authored by Frisinghelli Daniel

Improved training initialization workflow

parent 2ac1b506
@@ -7,6 +7,7 @@ Created on Wed Aug 12 10:24:34 2020
 # builtins
 import dataclasses
 import pathlib
+import logging

 # externals
 import numpy as np
@@ -16,7 +17,6 @@ import torch.nn.functional as F
 from torch.utils.data import DataLoader, Dataset
 from torch.optim import Optimizer

 # locals
 from pysegcnn.core.dataset import SupportedDatasets, ImageDataset
 from pysegcnn.core.transforms import Augment
@@ -27,6 +27,9 @@ from pysegcnn.core.models import (SupportedModels, SupportedOptimizers,
 from pysegcnn.core.layers import Conv2dSame
 from pysegcnn.main.config import HERE

+# module level logger
+LOGGER = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class BaseConfig:
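A note on the module-level logger added above: records emitted through LOGGER only become visible once logging is configured by the calling script, which this commit does via log_conf and dictConfig in the training script further below. A minimal, generic sketch using logging.basicConfig instead of the package's own configuration:

    import logging

    LOGGER = logging.getLogger(__name__)

    # configure the root logger once, e.g. in the entry-point script
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(name)s %(levelname)s: %(message)s')

    # module-level loggers now propagate their records to the root handler
    LOGGER.info('Skipping scene %s, tile %s ...', 'scene_id', 42)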
@@ -60,7 +63,6 @@ class DatasetConfig(BaseConfig):
     sort: bool = False
     transforms: list = dataclasses.field(default_factory=list)
     pad: bool = False
-    cval: int = 99

     def __post_init__(self):
         # check input types
@@ -80,11 +82,6 @@ class DatasetConfig(BaseConfig):
                                 ' of {}.'.format('.'.join([Augment.__module__,
                                                            Augment.__name__])))

-        # check whether the constant padding value is within the valid range
-        if not 0 < self.cval < 255:
-            raise ValueError('Expecting 0 <= cval <= 255, got cval={}.'
-                             .format(self.cval))
-
     def init_dataset(self):

         # instanciate the dataset
@@ -96,7 +93,6 @@ class DatasetConfig(BaseConfig):
             sort=self.sort,
             transforms=self.transforms,
             pad=self.pad,
-            cval=self.cval,
             gt_pattern=self.gt_pattern
             )
@@ -121,7 +117,8 @@ class SplitConfig(BaseConfig):

     # function to drop samples with a fraction of pixels equal to the constant
     # padding value self.cval >= self.drop
-    def _drop_samples(self, ds, drop_threshold=1):
+    @staticmethod
+    def _drop_samples(ds, drop_threshold=1):

         # iterate over the scenes returned by self.compose_scenes()
         dropped = []
@@ -139,8 +136,8 @@ class SplitConfig(BaseConfig):

             # drop samples where npixels >= self.drop
             if npixels >= drop_threshold:
-                print('Skipping scene {}, tile {}: {:.2f}% padded pixels ...'
-                      .format(s['id'], s['tile'], npixels * 100))
+                LOGGER.info('Skipping scene {}, tile {}: {:.2f}% padded pixels'
+                            ' ...'.format(s['id'], s['tile'], npixels * 100))
                 dropped.append(s)
                 _ = ds.indices.pop(pos)
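For reference, the quantity compared against drop_threshold in _drop_samples is the fraction of pixels equal to the constant padding value. A small illustrative computation; cval and the tile values below are hypothetical:

    import numpy as np

    cval = 99                              # constant padding value (hypothetical)
    tile = np.full((4, 4), cval)           # a mostly padded tile
    tile[:2, :2] = 1                       # a few valid pixels

    npixels = (tile == cval).mean()        # fraction of padded pixels
    print('{:.2f}% padded pixels'.format(npixels * 100))  # 75.00% padded pixels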
@@ -197,14 +194,24 @@ class ModelConfig(BaseConfig):
     model_name: str
     filters: list
     torch_seed: int
+    optim_name: str
+    loss_name: str
     skip_connection: bool = True
     kwargs: dict = dataclasses.field(
         default_factory=lambda: {'kernel_size': 3, 'stride': 1, 'dilation': 1})
     state_path: pathlib.Path = pathlib.Path(HERE).joinpath('_models/')
     batch_size: int = 64
     checkpoint: bool = False
-    pretrained: bool = False
+    transfer: bool = False
     pretrained_model: str = ''
+    lr: float = 0.001
+    early_stop: bool = False
+    mode: str = 'max'
+    delta: float = 0
+    patience: int = 10
+    epochs: int = 50
+    nthreads: int = torch.get_num_threads()
+    save: bool = True

     def __post_init__(self):
         # check input types
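With the optimizer, loss function and training settings folded into ModelConfig (the separate TrainConfig dataclass is dropped further below), a dictionary passed as ModelConfig(**model_config) would now look roughly as follows. All names and values are purely illustrative and must match the SupportedModels, SupportedOptimizers and SupportedLossFunctions enums:

    model_config = {
        'model_name': 'UNet',              # hypothetical SupportedModels member
        'filters': [32, 64, 128, 256],
        'torch_seed': 0,
        'optim_name': 'Adam',              # hypothetical SupportedOptimizers member
        'loss_name': 'CrossEntropyLoss',   # hypothetical SupportedLossFunctions member
        'skip_connection': True,
        'batch_size': 64,
        'checkpoint': False,
        'transfer': False,                 # replaces the former 'pretrained' flag
        'pretrained_model': '',
        'lr': 0.001,
        'early_stop': True,
        'mode': 'max',
        'delta': 0,
        'patience': 10,
        'epochs': 50,
        'save': True,
    }
    # mc = ModelConfig(**model_config)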
@@ -213,65 +220,32 @@ class ModelConfig(BaseConfig):

         # check whether the model is currently supported
         self.model_class = item_in_enum(self.model_name, SupportedModels)

-    def init_state(self, ds, sc, tc):
-
-        # file to save model state to:
-        # network_dataset_optim_split_splitparams_tilesize_batchsize_bands.pt
-
-        # model state filename
-        state_file = '{}_{}_{}_{}Split_{}_t{}_b{}_{}.pt'
-
-        # get the band numbers
-        bformat = ''.join(band[0] +
-                          str(ds.sensor.__members__[band].value) for
-                          band in ds.use_bands)
-
-        # check which split mode was used
-        if sc.split_mode == 'date':
-            # store the date that was used to split the dataset
-            state_file = state_file.format(self.model_class.__name__,
-                                           ds.__class__.__name__,
-                                           tc.optim_name,
-                                           sc.split_mode.capitalize(),
-                                           sc.date,
-                                           ds.tile_size,
-                                           self.batch_size,
-                                           bformat)
-        else:
-            # store the random split parameters
-            split_params = 's{}_t{}v{}'.format(
-                ds.seed, str(sc.ttratio).replace('.', ''),
-                str(sc.tvratio).replace('.', ''))
-
-            # model state filename
-            state_file = state_file.format(self.model_class.__name__,
-                                           ds.__class__.__name__,
-                                           tc.optim_name,
-                                           sc.split_mode.capitalize(),
-                                           split_params,
-                                           ds.tile_size,
-                                           self.batch_size,
-                                           bformat)
-
-        # check whether a pretrained model was used and change state filename
-        # accordingly
-        if self.pretrained:
-            # add the configuration of the pretrained model to the state name
-            state_file = (state_file.replace('.pt', '_') +
-                          'pretrained_' + self.pretrained_model)
-
-        # path to model state
-        state = self.state_path.joinpath(state_file)
-
-        # path to model loss/accuracy
-        loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt'))
-
-        return state, loss_state
+        # check whether the optimizer is currently supported
+        self.optim_class = item_in_enum(self.optim_name, SupportedOptimizers)
+
+        # check whether the loss function is currently supported
+        self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions)
+
+        # path to pretrained model
+        self.pretrained_path = self.state_path.joinpath(self.pretrained_model)
+
+    def init_optimizer(self, model):
+
+        # initialize the optimizer for the specified model
+        optimizer = self.optim_class(model.parameters(), self.lr)
+
+        return optimizer
+
+    def init_loss_function(self):
+
+        loss_function = self.loss_class()
+
+        return loss_function

     def init_model(self, ds):

         # case (1): build a new model
-        if not self.pretrained:
+        if not self.transfer:

             # set the random seed for reproducibility
             torch.manual_seed(self.torch_seed)
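What the two new factory methods resolve to, assuming 'Adam' and 'CrossEntropyLoss' are among the supported optimizer and loss names; the model below is a stand-in module rather than the package's Network class:

    import torch
    import torch.nn as nn

    model = nn.Conv2d(3, 2, kernel_size=1)      # stand-in for a Network instance

    # mc.init_optimizer(model) -> self.optim_class(model.parameters(), self.lr)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # mc.init_loss_function() -> self.loss_class()
    loss_function = nn.CrossEntropyLoss()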
@@ -284,130 +258,172 @@ class ModelConfig(BaseConfig):
                                      skip=self.skip_connection,
                                      **self.kwargs)

-        # case (2): load a pretrained model
+        # case (2): load a pretrained model for transfer learning
         else:
             # load pretrained model
-            model = self.load_pretrained()
+            model, _ = self.load_pretrained(self.pretrained_path, new_ds=ds)

         return model

-    def load_checkpoint(self, state_file, loss_state, model, optimizer):
-
-        # initial accuracy on the validation set
-        max_accuracy = 0
-
-        # set the model checkpoint to None, overwritten when resuming
-        # training from an existing model checkpoint
-        checkpoint_state = {}
-
-        # whether to resume training from an existing model
-        if self.checkpoint:
-
-            # check if a model checkpoint exists
-            if not state_file.exists():
-                raise FileNotFoundError('Model checkpoint {} does not exist.'
-                                        .format(state_file))
-
-            # load the model state
-            state = model.load(state_file.name, optimizer, self.state_path)
-            print('Found checkpoint: {}'.format(state))
-            print('Resuming training from checkpoint ...'.format(state))
-            print('Model epoch: {:d}'.format(model.epoch))
-
-            # load the model loss and accuracy
-            checkpoint_state = torch.load(loss_state)
-
-            # get all non-zero elements, i.e. get number of epochs trained
-            # before the early stop
-            checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
-                                for k, v in checkpoint_state.items()}
-
-            # maximum accuracy on the validation set
-            max_accuracy = checkpoint_state['va'][:, -1].mean().item()
-
-        return checkpoint_state, max_accuracy
-
-    def load_pretrained(self, ds):
-
-        # load the pretrained model
-        model_state = self.state_path.joinpath(self.pretrained_model)
-        if not model_state.exists():
-            raise FileNotFoundError('Pretrained model {} does not exist.'
-                                    .format(model_state))
-
-        # load the model state
-        model_state = torch.load(model_state)
-
-        # get the input bands of the pretrained model
-        bands = model_state['bands']
-
-        # get the number of convolutional filters
-        filters = model_state['params']['filters']
-
-        # check whether the current dataset uses the correct spectral bands
-        if ds.use_bands != bands:
-            raise ValueError('The bands of the pretrained network do not '
-                             'match the specified bands: {}'
-                             .format(bands))
-
-        # instanciate pretrained model architecture
-        model = self.model_class(**model_state['params'],
-                                 **model_state['kwargs'])
-
-        # load pretrained model weights
-        model.load(self.pretrained_model, inpath=str(self.state_path))
-
-        # reset model epoch to 0, since the model is trained on a different
-        # dataset
-        model.epoch = 0
-
-        # adjust the number of classes in the model
-        model.nclasses = len(ds.labels)
-
-        # adjust the classification layer to the number of classes of the
-        # current dataset
-        model.classifier = Conv2dSame(in_channels=filters[0],
-                                      out_channels=model.nclasses,
-                                      kernel_size=1)
-
-        return model
+    def from_checkpoint(self, model, optimizer, state_file, loss_state):
+
+        # whether to resume training from an existing model checkpoint
+        checkpoint_state = {}
+        max_accuracy = 0
+
+        if self.checkpoint:
+
+            # check whether the checkpoint exists
+            if state_file.exists() and loss_state.exists():
+                # load model checkpoint
+                model, optimizer = self.load_pretrained(state_file, optimizer,
+                                                        new_ds=None)
+                (checkpoint_state, max_accuracy) = self.load_checkpoint(
+                    loss_state)
+            else:
+                LOGGER.info('Checkpoint for model {} does not exist. '
+                            'Initializing new model.'.format(state_file.name))
+
+        return model, optimizer, checkpoint_state, max_accuracy
+
+    @staticmethod
+    def load_pretrained(state_file, optimizer=None, new_ds=None):
+
+        # load the pretrained model
+        if not state_file.exists():
+            raise FileNotFoundError('Pretrained model {} does not exist.'
+                                    .format(state_file))
+
+        LOGGER.info('Loading pretrained model: {}'.format(state_file.name))
+
+        # load the model state
+        model_state = torch.load(state_file)
+
+        # the model class
+        model_class = model_state['cls']
+
+        # instanciate pretrained model architecture
+        model = model_class(**model_state['params'], **model_state['kwargs'])
+
+        # load pretrained model weights
+        _ = model.load(state_file.name, optimizer=optimizer,
+                       inpath=str(state_file.parent))
+        LOGGER.info('Model epoch: {:d}'.format(model.epoch))
+
+        # check whether to apply pretrained model on a new dataset
+        if new_ds is not None:
+            LOGGER.info('Configuring model for new dataset: {}.'
+                        .format(new_ds.__class__.__name__))
+
+            # the bands the model was trained with
+            bands = model_state['bands']
+
+            # check whether the current dataset uses the correct spectral bands
+            if new_ds.use_bands != bands:
+                raise ValueError('The pretrained network was trained with the '
+                                 'bands {}, not with: {}'
+                                 .format(bands, new_ds.use_bands))
+
+            # get the number of convolutional filters
+            filters = model_state['params']['filters']
+
+            # reset model epoch to 0, since the model is trained on a different
+            # dataset
+            model.epoch = 0
+
+            # adjust the number of classes in the model
+            model.nclasses = len(new_ds.labels)
+            LOGGER.info('Replacing classification layer to classes: {}.'
+                        .format(', '.join('({}, {})'.format(k, v['label'])
+                                          for k, v in new_ds.labels.items())))
+
+            # adjust the classification layer to the number of classes of the
+            # current dataset
+            model.classifier = Conv2dSame(in_channels=filters[0],
+                                          out_channels=model.nclasses,
+                                          kernel_size=1)
+
+        return model, optimizer
+
+    @staticmethod
+    def load_checkpoint(loss_state):
+
+        # load the model loss and accuracy
+        checkpoint_state = torch.load(loss_state)
+
+        # get all non-zero elements, i.e. get number of epochs trained
+        # before the early stop
+        checkpoint_state = {k: v[np.nonzero(v)].reshape(v.shape[0], -1)
+                            for k, v in checkpoint_state.items()}
+
+        # maximum accuracy on the validation set
+        max_accuracy = checkpoint_state['va'][:, -1].mean().item()
+
+        return checkpoint_state, max_accuracy
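The transfer-learning branch of load_pretrained() above boils down to resetting the epoch counter, updating the number of classes and swapping the classification layer. A self-contained sketch with a stand-in model and a plain 1x1 convolution in place of the repository's Conv2dSame layer; the class count and filter sizes are hypothetical:

    import torch.nn as nn

    class TinySegNet(nn.Module):
        # stand-in for a pysegcnn Network, illustrative only
        def __init__(self, filters=(32, 64), nclasses=4):
            super().__init__()
            self.epoch = 25
            self.nclasses = nclasses
            self.classifier = nn.Conv2d(filters[0], nclasses, kernel_size=1)

    model = TinySegNet()
    filters = (32, 64)                    # model_state['params']['filters']

    # what load_pretrained() does when new_ds is given:
    model.epoch = 0                       # restart epoch counting on the new dataset
    model.nclasses = 6                    # len(new_ds.labels), hypothetical value
    model.classifier = nn.Conv2d(in_channels=filters[0],
                                 out_channels=model.nclasses,
                                 kernel_size=1)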
 @dataclasses.dataclass
-class TrainConfig(BaseConfig):
-    optim_name: str
-    loss_name: str
-    lr: float = 0.001
-    early_stop: bool = False
-    mode: str = 'max'
-    delta: float = 0
-    patience: int = 10
-    epochs: int = 50
-    nthreads: int = torch.get_num_threads()
-    save: bool = True
+class StateConfig(BaseConfig):
+    ds: ImageDataset
+    sc: SplitConfig
+    mc: ModelConfig

     def __post_init__(self):
         super().__post_init__()

-        # check whether the optimizer is currently supported
-        self.optim_class = item_in_enum(self.optim_name, SupportedOptimizers)
-
-        # check whether the loss function is currently supported
-        self.loss_class = item_in_enum(self.loss_name, SupportedLossFunctions)
-
-    def init_optimizer(self, model):
-
-        # initialize the optimizer for the specified model
-        optimizer = self.optim_class(model.parameters(), self.lr)
-
-        return optimizer
-
-    def init_loss_function(self):
-
-        loss_function = self.loss_class()
-
-        return loss_function
+    def init_state(self):
+
+        # file to save model state to:
+        # network_dataset_optim_split_splitparams_tilesize_batchsize_bands.pt
+
+        # model state filename
+        state_file = '{}_{}_{}_{}Split_{}_t{}_b{}_{}.pt'
+
+        # get the band numbers
+        bformat = ''.join(band[0] +
+                          str(self.ds.sensor.__members__[band].value) for
+                          band in self.ds.use_bands)
+
+        # check which split mode was used
+        if self.sc.split_mode == 'date':
+            # store the date that was used to split the dataset
+            state_file = state_file.format(self.mc.model_name,
+                                           self.ds.__class__.__name__,
+                                           self.mc.optim_name,
+                                           self.sc.split_mode.capitalize(),
+                                           self.sc.date,
+                                           self.ds.tile_size,
+                                           self.mc.batch_size,
+                                           bformat)
+        else:
+            # store the random split parameters
+            split_params = 's{}_t{}v{}'.format(
+                self.ds.seed, str(self.sc.ttratio).replace('.', ''),
+                str(self.sc.tvratio).replace('.', ''))
+
+            # model state filename
+            state_file = state_file.format(self.mc.model_name,
+                                           self.ds.__class__.__name__,
+                                           self.mc.optim_name,
+                                           self.sc.split_mode.capitalize(),
+                                           split_params,
+                                           self.ds.tile_size,
+                                           self.mc.batch_size,
+                                           bformat)
+
+        # check whether a pretrained model was used and change state filename
+        # accordingly
+        if self.mc.transfer:
+            # add the configuration of the pretrained model to the state name
+            state_file = (state_file.replace('.pt', '_') +
+                          'pretrained_' + self.mc.pretrained_model)
+
+        # path to model state
+        state = self.mc.state_path.joinpath(state_file)
+
+        # path to model loss/accuracy
+        loss_state = pathlib.Path(str(state).replace('.pt', '_loss.pt'))
+
+        return state, loss_state


 @dataclasses.dataclass
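For orientation, the naming scheme that StateConfig.init_state() encodes can be previewed by filling the format string with purely hypothetical values; model, dataset, optimizer, split parameters, tile size, batch size and band string all depend on the actual configuration:

    # illustrative only: hypothetical values plugged into the state file template
    state_file = '{}_{}_{}_{}Split_{}_t{}_b{}_{}.pt'.format(
        'UNet', 'SparcsDataset', 'Adam', 'Random', 's0_t08v005', 125, 64, 'B2B3B4')
    print(state_file)
    # UNet_SparcsDataset_Adam_RandomSplit_s0_t08v005_t125_b64_B2B3B4.pt
    # the accompanying loss/accuracy file replaces '.pt' with '_loss.pt'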
@@ -428,6 +444,7 @@ class EvalConfig(BaseConfig):
             raise TypeError('Expected "test" to be None, True or False, got '
                             '{}.'.format(self.test))


 @dataclasses.dataclass
 class NetworkTrainer(BaseConfig):
     model: Network
@@ -457,15 +474,16 @@ class NetworkTrainer(BaseConfig):
         # whether to use early stopping
         self.es = None
         if self.early_stop:
-            self.es = EarlyStopping(self.mode, self.delta, self.patience)
+            self.es = EarlyStopping(self.mode, self.max_accuracy, self.delta,
+                                    self.patience)

     def train(self):

-        print('------------------------- Training ---------------------------')
+        LOGGER.info(30 * '-' + ' Training ' + 30 * '-')

         # set the number of threads
-        print('Device: {}'.format(self.device))
-        print('Number of cpu threads: {}'.format(self.nthreads))
+        LOGGER.info('Device: {}'.format(self.device))
+        LOGGER.info('Number of cpu threads: {}'.format(self.nthreads))
         torch.set_num_threads(self.nthreads)

         # create dictionary of the observed losses and accuracies on the
@@ -485,7 +503,7 @@ class NetworkTrainer(BaseConfig):
         for epoch in range(self.epochs):

             # set the model to training mode
-            print('Setting model to training mode ...')
+            LOGGER.info('Setting model to training mode ...')
             self.model.train()

             # iterate over the dataloader object
@@ -521,13 +539,14 @@ class NetworkTrainer(BaseConfig):
                 training_state['ta'][batch, epoch] = observed_accuracy

                 # print progress
-                print('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, Loss: {:.2f}, '
-                      'Accuracy: {:.2f}'.format(epoch + 1,
-                                                self.epochs,
-                                                batch + 1,
-                                                len(self.train_dl),
-                                                observed_loss,
-                                                observed_accuracy))
+                LOGGER.info('Epoch: {:d}/{:d}, Mini-batch: {:d}/{:d}, '
+                            'Loss: {:.2f}, Accuracy: {:.2f}'.format(
+                                epoch + 1,
+                                self.epochs,
+                                batch + 1,
+                                len(self.train_dl),
+                                observed_loss,
+                                observed_accuracy))

             # update the number of epochs trained
             self.model.epoch += 1
@@ -568,13 +587,13 @@ class NetworkTrainer(BaseConfig):

     def predict(self):

-        print('------------------------ Predicting --------------------------')
+        LOGGER.info(30 * '-' + ' Predicting ' + 30 * '-')

         # send the model to the gpu if available
         self.model = self.model.to(self.device)

         # set the model to evaluation mode
-        print('Setting model to evaluation mode ...')
+        LOGGER.info('Setting model to evaluation mode ...')
         self.model.eval()

         # create arrays of the observed losses and accuracies
@@ -582,7 +601,7 @@ class NetworkTrainer(BaseConfig):
         losses = np.zeros(shape=(len(self.valid_dl), 1))

         # iterate over the validation/test set
-        print('Calculating accuracy on the validation set ...')
+        LOGGER.info('Calculating accuracy on the validation set ...')
         for batch, (inputs, labels) in enumerate(self.valid_dl):

             # send the data to the gpu if available
@@ -605,12 +624,12 @@ class NetworkTrainer(BaseConfig):
             accuracies[batch, 0] = acc

             # print progress
-            print('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}'
-                  .format(batch + 1, len(self.valid_dl), acc))
+            LOGGER.info('Mini-batch: {:d}/{:d}, Accuracy: {:.2f}'
+                        .format(batch + 1, len(self.valid_dl), acc))

         # calculate overall accuracy on the validation/test set
-        print('Epoch {:d}, Overall accuracy: {:.2f}%.'
-              .format(self.model.epoch, accuracies.mean() * 100))
+        LOGGER.info('Epoch {:d}, Overall accuracy: {:.2f}%.'
+                    .format(self.model.epoch, accuracies.mean() * 100))

         return accuracies, losses
@@ -649,7 +668,7 @@ class NetworkTrainer(BaseConfig):
         # dataset
         fs += ' (dataset):\n '
         fs += ''.join(
-            repr(self.train_dl.dataset.dataset)).replace('\n','\n ')
+            repr(self.train_dl.dataset.dataset)).replace('\n', '\n ')

         # batch size
         fs += '\n (batch):\n '
@@ -684,7 +703,7 @@ class NetworkTrainer(BaseConfig):

 class EarlyStopping(object):

-    def __init__(self, mode='max', min_delta=0, patience=10):
+    def __init__(self, mode='max', best=0, min_delta=0, patience=10):

         # check if mode is correctly specified
         if mode not in ['min', 'max']:
@@ -707,7 +726,7 @@ class EarlyStopping(object):
         self.patience = patience

         # initialize best metric
-        self.best = None
+        self.best = best

         # initialize early stopping flag
         self.early_stop = False
@@ -717,25 +736,20 @@ class EarlyStopping(object):

     def stop(self, metric):

-        if self.best is not None:
-
-            # if the metric improved, reset the epochs counter, else, advance
-            if self.is_better(metric, self.best, self.min_delta):
-                self.counter = 0
-                self.best = metric
-            else:
-                self.counter += 1
-                print('Early stopping counter: {}/{}'.format(self.counter,
-                                                             self.patience))
-
-                # if the metric did not improve over the last patience epochs,
-                # the early stopping criterion is met
-                if self.counter >= self.patience:
-                    print('Early stopping criterion met, exiting training ...')
-                    self.early_stop = True
-        else:
-            self.best = metric
+        # if the metric improved, reset the epochs counter, else, advance
+        if self.is_better(metric, self.best, self.min_delta):
+            self.counter = 0
+            self.best = metric
+        else:
+            self.counter += 1
+            LOGGER.info('Early stopping counter: {}/{}'.format(
+                self.counter, self.patience))
+
+            # if the metric did not improve over the last patience epochs,
+            # the early stopping criterion is met
+            if self.counter >= self.patience:
+                LOGGER.info('Early stopping criterion met, stopping training.')
+                self.early_stop = True

         return self.early_stop
@@ -746,6 +760,7 @@ class EarlyStopping(object):
         return metric > best + min_delta

     def __repr__(self):
-        fs = (self.__class__.__name__ + '(mode={}, delta={}, patience={})'
-              .format(self.mode, self.min_delta, self.patience))
+        fs = self.__class__.__name__
+        fs += '(mode={}, best={}, delta={}, patience={})'.format(
+            self.mode, self.best, self.min_delta, self.patience)
         return fs
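A minimal usage sketch of the updated EarlyStopping class: the best metric can now be seeded, for instance with the max_accuracy recovered from a checkpoint, so that resumed training does not reset the stopping criterion (the accuracy values below are made up). The remaining hunks of this commit belong to the training script that drives the classes above.

    from pysegcnn.core.trainer import EarlyStopping

    # seed the monitor with a previously achieved validation accuracy
    es = EarlyStopping(mode='max', best=0.85, min_delta=0, patience=2)
    for accuracy in [0.84, 0.86, 0.85, 0.83]:   # hypothetical validation accuracies
        if es.stop(accuracy):                   # True once no improvement for 2 epochs
            break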
@@ -5,11 +5,14 @@ Created on Tue Jun 30 09:33:38 2020
 @author: Daniel
 """
+# builtins
+import logging
+
 # locals
 from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig,
-                                   TrainConfig, NetworkTrainer)
-from pysegcnn.main.config import (dataset_config, split_config,
-                                  model_config, train_config)
+                                   StateConfig, NetworkTrainer)
+from pysegcnn.core.logging import log_conf
+from pysegcnn.main.config import (dataset_config, split_config, model_config)


 if __name__ == '__main__':
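The script below configures logging from the model state file name via log_conf and logging.config.dictConfig. The dictionary returned by log_conf is not shown in this diff; the following sketch is only an assumption of what a dictConfig-style setup with a console and a file handler might look like, not the actual pysegcnn.core.logging.log_conf:

    import logging.config

    def log_conf(log_file):
        # hypothetical stand-in for pysegcnn.core.logging.log_conf
        return {
            'version': 1,
            'disable_existing_loggers': False,
            'formatters': {
                'brief': {'format': '%(name)s: %(message)s'}},
            'handlers': {
                'console': {'class': 'logging.StreamHandler',
                            'formatter': 'brief'},
                'file': {'class': 'logging.FileHandler',
                         'formatter': 'brief',
                         'filename': log_file}},
            'root': {'level': 'INFO', 'handlers': ['console', 'file']},
        }

    logging.config.dictConfig(log_conf('train.log'))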
@@ -20,35 +23,36 @@ if __name__ == '__main__':
     dc = DatasetConfig(**dataset_config)
     sc = SplitConfig(**split_config)
     mc = ModelConfig(**model_config)
-    tc = TrainConfig(**train_config)

     # (ii) instanciate the dataset
     ds = dc.init_dataset()
     ds

-    # (iii) instanciate the training, validation and test datasets and
+    # (iii) instanciate the model state
+    state = StateConfig(ds, sc, mc)
+    state_file, loss_state = state.init_state()
+
+    # initialize logging
+    log_file = str(state_file).replace('.pt', '_train.log')
+    logging.config.dictConfig(log_conf(log_file))
+
+    # (iv) instanciate the training, validation and test datasets and
     # dataloaders
     train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
-    train_dl, valid_dl, test_dl = sc.dataloaders(train_ds,
-                                                 valid_ds,
-                                                 test_ds,
-                                                 batch_size=mc.batch_size,
-                                                 shuffle=True,
-                                                 drop_last=False)
-
-    # (iv) instanciate the model state files
-    state_file, loss_state = mc.init_state(ds, sc, tc)
+    train_dl, valid_dl, test_dl = sc.dataloaders(
+        train_ds, valid_ds, test_ds, batch_size=mc.batch_size, shuffle=True,
+        drop_last=False)

-    # (v) instanciate the model
+    # (iv) instanciate the model
     model = mc.init_model(ds)

     # (vi) instanciate the optimizer and the loss function
-    optimizer = tc.init_optimizer(model)
-    loss_function = tc.init_loss_function()
+    optimizer = mc.init_optimizer(model)
+    loss_function = mc.init_loss_function()

     # (vii) resume training from an existing model checkpoint
-    checkpoint_state, max_accuracy = mc.load_checkpoint(state_file, loss_state,
-                                                        model, optimizer)
+    (model, optimizer, checkpoint_state, max_accuracy) = mc.from_checkpoint(
+        model, optimizer, state_file, loss_state)

     # (viii) initialize network trainer class for eays model training
     trainer = NetworkTrainer(model=model,
@@ -58,15 +62,15 @@ if __name__ == '__main__':
                              valid_dl=valid_dl,
                              state_file=state_file,
                              loss_state=loss_state,
-                             epochs=tc.epochs,
-                             nthreads=tc.nthreads,
-                             early_stop=tc.early_stop,
-                             mode=tc.mode,
-                             delta=tc.delta,
-                             patience=tc.patience,
+                             epochs=mc.epochs,
+                             nthreads=mc.nthreads,
+                             early_stop=mc.early_stop,
+                             mode=mc.mode,
+                             delta=mc.delta,
+                             patience=mc.patience,
                              max_accuracy=max_accuracy,
                              checkpoint_state=checkpoint_state,
-                             save=tc.save
+                             save=mc.save
                              )

     # (ix) train model