Skip to content
Snippets Groups Projects
Commit 4338ddfb authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Adapted main modules to changes in pysegcnn.core.trainer.py

parent 5128033c
No related branches found
No related tags found
No related merge requests found
......@@ -16,14 +16,19 @@ import os
# path to this file
HERE = os.path.abspath(os.path.dirname(__file__))
# path to the datasets
DATASET_PATH = 'C:/Eurac/2020/_Datasets/'
# DATASET_PATH = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/_Datasets/'
# path to the datasets on the current machine
DRIVE_PATH ='C:/Eurac/2020/_Datasets/'
# DRIVE_PATH = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/_Datasets/'
# name of the datasets
DATASET_NAME = 'Sparcs'
# DATASET_NAME = 'Cloud95/Training/'
# DATASET_NAME =' ProSnow/Garmisch/
# DATASET_NAME = 'Cloud95'
# DATASET_NAME = 'Garmisch'
# path to the dataset
DATASET_PATH = os.path.join(DRIVE_PATH, DATASET_NAME)
# DATASET_PATH = os.path.join(DRIVE_PATH, DATASET_NAME, 'Training')
# DATASET_PATH = os.path.join(DRIVE_PATH, 'ProSnow', DATASET_NAME)
# the dataset configuration dictionary
dataset_config = {
......@@ -32,8 +37,11 @@ dataset_config = {
# -------------------------------------------------------------------------
# name of the dataset
'dataset_name': DATASET_NAME,
# path to the dataset
'root_dir': os.path.join(DATASET_PATH, DATASET_NAME),
'root_dir': DATASET_PATH,
# a pattern to match the ground truth file naming convention
'gt_pattern': '*mask.png',
......@@ -53,8 +61,8 @@ dataset_config = {
# tiles of size (tile_size, tile_size)
'pad': True,
# set random seed for reproducibility of the training, validation
# and test data split
# the random seed for the numpy random number generator
# ensures reproducibility of the training, validation and test data split
# used if split_mode='random' and split_mode='scene'
'seed': 0,
......@@ -134,12 +142,12 @@ split_config = {
# (ttratio * 100) % of the dataset will be used for training and
# validation
# used if split_mode='random' and split_mode='scene'
'ttratio': 1,
'ttratio': 0.05,
# (ttratio * tvratio) * 100 % will be used as for training
# (1 - ttratio * tvratio) * 100 % will be used for validation
# used if split_mode='random' and split_mode='scene'
'tvratio': 0.8,
'tvratio': 0.5,
# the date to split the scenes
# format: 'yyyymmdd'
......@@ -211,7 +219,10 @@ model_config = {
# define the batch size
# determines how many samples of the dataset are processed until the
# weights of the network are updated (via mini-batch gradient descent)
'batch_size': 64
'batch_size': 64,
# the seed for the random number generator intializing the network weights
'torch_seed': 0
}
......@@ -223,6 +234,9 @@ train_config = {
# -------------------------------------------------------------------------
# whether to save the model state to disk
'save': True,
# whether to early stop training if the accuracy on the validation set
# does not increase more than delta over patience epochs
'early_stop': True,
......@@ -246,7 +260,7 @@ train_config = {
}
# the evaluation configuration file
evaluation_config = {
eval_config = {
# ----------------------------- Evaluation --------------------------------
......@@ -256,8 +270,8 @@ evaluation_config = {
# pysegcnn.main.eval.py
# the dataset to evaluate the model on
# test=False means evaluating on the validation set
# test=True means evaluating on the test set
# test=False, 0 means evaluating on the validation set
# test=True, 1 means evaluating on the test set
# test=None means evaluating on the training set
'test': False,
......@@ -294,4 +308,4 @@ config = {**dataset_config,
**split_config,
**model_config,
**train_config,
**evaluation_config}
**eval_config}
......@@ -8,63 +8,83 @@ Created on Wed Jul 29 15:57:01 2020
import os
# locals
from pysegcnn.core.trainer import NetworkTrainer
from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig,
TrainConfig, EvalConfig)
from pysegcnn.core.predict import predict_samples, predict_scenes
from pysegcnn.main.config import config, HERE
from pysegcnn.main.config import (dataset_config, split_config, model_config,
train_config, eval_config, HERE)
from pysegcnn.core.graphics import plot_confusion_matrix, plot_loss
if __name__ == '__main__':
# instanciate the NetworkTrainer class
trainer = NetworkTrainer(config)
trainer
# (i) instanciate the configurations
dc = DatasetConfig(**dataset_config)
sc = SplitConfig(**split_config)
mc = ModelConfig(**model_config)
tc = TrainConfig(**train_config)
ec = EvalConfig(**eval_config)
# (ii) instanciate the dataset
ds = dc.init_dataset()
ds
# (iii) instanciate the training, validation and test datasets
train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
# (iv) instanciate the model state files
state_file, loss_state = mc.init_state(ds, sc, tc)
# (v) instanciate the model
model = mc.init_model(ds)
# (vi) instanciate the optimizer
optimizer = tc.init_optimizer(model)
# plot loss and accuracy
plot_loss(trainer.loss_state, outpath=os.path.join(HERE, '_graphics/'))
plot_loss(loss_state, outpath=os.path.join(HERE, '_graphics/'))
# check whether to evaluate the model on the training set, validation set
# or the test set
if trainer.test is None:
ds = trainer.train_ds
if ec.test is None:
ds = train_ds
else:
ds = trainer.test_ds if trainer.test else trainer.valid_ds
ds = test_ds if ec.test else valid_ds
# keyword arguments for plotting
kwargs = {'bands': ec.plot_bands,
'outpath': os.path.join(HERE, '_scenes/'),
'stretch': True,
'alpha': 5}
# whether to predict each sample or each scene individually
if trainer.predict_scene:
if ec.predict_scene:
# reconstruct and predict the scenes in the validation/test set
scenes, cm = predict_scenes(ds,
trainer.model,
trainer.optimizer,
trainer.state_path,
trainer.state_file,
None,
trainer.cm,
trainer.plot_scenes,
bands=trainer.plot_bands,
outpath=os.path.join(HERE, '_scenes/'),
stretch=True,
alpha=5)
model,
optimizer,
state_file,
scene_id=None,
cm=ec.cm,
plot_scenes=ec.plot_scenes,
**kwargs
)
else:
# predict the samples in the validation/test set
samples, cm = predict_samples(ds,
trainer.model,
trainer.optimizer,
trainer.state_path,
trainer.state_file,
trainer.cm,
trainer.plot_samples,
bands=trainer.plot_bands,
outpath=os.path.join(HERE, '_samples/'),
stretch=True,
alpha=5)
model,
optimizer,
state_file,
cm=ec.cm,
plot_scenes=ec.plot_scenes,
**kwargs)
# whether to plot the confusion matrix
if trainer.cm:
if ec.cm:
plot_confusion_matrix(cm,
ds.dataset.labels,
normalize=True,
state=trainer.state_file,
state=state_file.name.replace('.pt', '.png'),
outpath=os.path.join(HERE, '_graphics/')
)
......@@ -6,19 +6,68 @@ Created on Tue Jun 30 09:33:38 2020
@author: Daniel
"""
# locals
from pysegcnn.core.initconf import NetworkTrainer
from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig,
TrainConfig, NetworkTrainer)
from pysegcnn.main.config import (dataset_config, split_config,
model_config, train_config)
if __name__ == '__main__':
# instanciate the NetworkTrainer class
trainer = NetworkTrainer(dconfig=dataset_config,
sconfig=split_config,
mconfig=model_config,
tconfig=train_config)
trainer
# write code that checks for list of seeds, band combinations etc. here.
# train the network
# (i) instanciate the configurations
dc = DatasetConfig(**dataset_config)
sc = SplitConfig(**split_config)
mc = ModelConfig(**model_config)
tc = TrainConfig(**train_config)
# (ii) instanciate the dataset
ds = dc.init_dataset()
ds
# (iii) instanciate the training, validation and test datasets and
# dataloaders
train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
train_dl, valid_dl, test_dl = sc.dataloaders(train_ds,
valid_ds,
test_ds,
batch_size=mc.batch_size,
shuffle=True,
drop_last=False)
# (iv) instanciate the model state files
state_file, loss_state = mc.init_state(ds, sc, tc)
# (v) instanciate the model
model = mc.init_model(ds)
# (vi) instanciate the optimizer and the loss function
optimizer = tc.init_optimizer(model)
loss_function = tc.init_loss_function()
# (vii) resume training from an existing model checkpoint
checkpoint_state, max_accuracy = mc.load_checkpoint(state_file, loss_state,
model, optimizer)
# (viii) initialize network trainer class for eays model training
trainer = NetworkTrainer(model=model,
optimizer=optimizer,
loss_function=loss_function,
train_dl=train_dl,
valid_dl=valid_dl,
state_file=state_file,
loss_state=loss_state,
epochs=tc.epochs,
nthreads=tc.nthreads,
early_stop=tc.early_stop,
mode=tc.mode,
delta=tc.delta,
patience=tc.patience,
max_accuracy=max_accuracy,
checkpoint_state=checkpoint_state,
save=tc.save
)
# (ix) train model
training_state = trainer.train()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment