Skip to content
Snippets Groups Projects
Commit fe57ae65 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Adjusting main executables to changes in core.trainer.py

parent f7112679
No related branches found
No related tags found
No related merge requests found
......@@ -30,6 +30,12 @@ DATASET_PATH = os.path.join(DRIVE_PATH, DATASET_NAME)
# DATASET_PATH = os.path.join(DRIVE_PATH, DATASET_NAME, 'Training')
# DATASET_PATH = os.path.join(DRIVE_PATH, 'ProSnow', DATASET_NAME)
# path to store the model states
MODEL_PATH = os.path.join(HERE, '_models/')
# path to store model logs
LOG_PATH = os.path.join(HERE, '_logs/')
# the dataset configuration dictionary
dataset_config = {
......@@ -189,9 +195,6 @@ model_config = {
'dilation': 1 # the field of view of the kernel
},
# path to save trained models
'state_path': os.path.join(HERE, '_models/'),
# Transfer learning -------------------------------------------------------
# Use pretrained=True only if you wish to fine-tune a pre-trained model
......@@ -203,7 +206,7 @@ model_config = {
# was trained on
# whether to use a pretrained model for transfer learning
'transfer': True,
'transfer': False,
# name of the pretrained model to apply to a different dataset
'pretrained_model': 'UNet_SparcsDataset_t125_b64_rgbn.pt',
......@@ -226,6 +229,8 @@ model_config = {
# -------------------------------------------------------------------------
# whether to save the model state to disk
# model states are saved in: pysegcnn/main/_models
# model log files are saved in: pysegcnn/main/_logs
'save': True,
# whether to early stop training if the accuracy on the validation set
......@@ -260,6 +265,10 @@ eval_config = {
# these options are only used for evaluating a trained model using
# pysegcnn.main.eval.py
# the model to evaluate
'state_file': os.path.join(MODEL_PATH,
'UNet_SparcsDataset_Adam_SceneSplit_s0_t005v05_t125_b64_r4g3b2n5.pt'),
# the dataset to evaluate the model on
# test=False, 0 means evaluating on the validation set
# test=True, 1 means evaluating on the test set
......@@ -267,6 +276,7 @@ eval_config = {
'test': False,
# whether to compute and plot the confusion matrix
# output path is: pysegcnn/main/_graphics/
'cm': True,
# whether to predict each sample or each scene individually
......@@ -299,9 +309,3 @@ eval_config = {
'alpha': 5
}
# the complete configuration
config = {**dataset_config,
**split_config,
**model_config,
**eval_config}
......@@ -5,52 +5,41 @@ Created on Wed Jul 29 15:57:01 2020
@author: Daniel
"""
# builtins
import os
from logging.config import dictConfig
# locals
from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig,
StateConfig, EvalConfig)
from pysegcnn.core.models import Network
from pysegcnn.core.trainer import EvalConfig, LogConfig
from pysegcnn.core.predict import predict_samples, predict_scenes
from pysegcnn.main.config import (dataset_config, split_config, model_config,
train_config, eval_config, HERE)
from pysegcnn.core.logging import log_conf
from pysegcnn.core.graphics import plot_confusion_matrix, plot_loss
from pysegcnn.main.config import eval_config
if __name__ == '__main__':
# (i) instanciate the configurations
dc = DatasetConfig(**dataset_config)
sc = SplitConfig(**split_config)
mc = ModelConfig(**model_config)
# instanciate the evaluation configuration
ec = EvalConfig(**eval_config)
# (ii) instanciate the dataset
ds = dc.init_dataset()
# initialize logging
log = LogConfig(ec.state_file)
dictConfig(log_conf(log.log_file))
# (iii) instanciate the training, validation and test datasets
train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
# (iv) instanciate the model state
state = StateConfig(ds, sc, mc)
state_file, loss_state = state.init_state()
# (vii) load pretrained model weights
model, _ = mc.load_pretrained(state_file)
model.state_file = state_file
# load the model state
model, _, model_state = Network.load(ec.state_file)
# plot loss and accuracy
plot_loss(loss_state, outpath=os.path.join(HERE, '_graphics/'))
plot_loss(ec.state_file, outpath=ec.models_path)
# check whether to evaluate the model on the training set, validation set
# or the test set
if ec.test is None:
ds = train_ds
ds = model_state['train_ds']
else:
ds = test_ds if ec.test else valid_ds
ds = model_state['test_ds'] if ec.test else model_state['valid_ds']
# keyword arguments for plotting
kwargs = {'bands': ec.plot_bands,
'outpath': os.path.join(HERE, '_scenes/'),
'alpha': ec.alpha,
'figsize': ec.figsize}
......@@ -58,16 +47,17 @@ if __name__ == '__main__':
if ec.predict_scene:
# reconstruct and predict the scenes in the validation/test set
scenes, cm = predict_scenes(ds, model, scene_id=None, cm=ec.cm,
plot=ec.plot_scenes, **kwargs)
plot=ec.plot_scenes,
outpath=ec.scenes_path, **kwargs)
else:
# predict the samples in the validation/test set
samples, cm = predict_samples(ds, model, cm=ec.cm,
plot=ec.plot_samples, **kwargs)
plot=ec.plot_samples,
outpath=ec.sample_path, **kwargs)
# whether to plot the confusion matrix
if ec.cm:
plot_confusion_matrix(cm, ds.dataset.labels,
state=state_file.name.replace('.pt', '.png'),
outpath=os.path.join(HERE, '_graphics/')
)
state=ec.state_file.name.replace('.pt', '.png'),
outpath=ec.models_path)
......@@ -6,13 +6,14 @@ Created on Tue Jun 30 09:33:38 2020
@author: Daniel
"""
# builtins
import logging
from logging.config import dictConfig
# locals
from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig,
StateConfig, NetworkTrainer)
StateConfig, LogConfig, NetworkTrainer)
from pysegcnn.core.logging import log_conf
from pysegcnn.main.config import (dataset_config, split_config, model_config)
from pysegcnn.main.config import (dataset_config, split_config, model_config,
LOG_PATH)
if __name__ == '__main__':
......@@ -26,49 +27,42 @@ if __name__ == '__main__':
# (ii) instanciate the dataset
ds = dc.init_dataset()
ds
# (iii) instanciate the model state
state = StateConfig(ds, sc, mc)
state_file, loss_state = state.init_state()
state_file = state.init_state()
# initialize logging
log_file = str(state_file).replace('.pt', '_train.log')
logging.config.dictConfig(log_conf(log_file))
# (iv) initialize logging
log = LogConfig(state_file)
dictConfig(log_conf(log.log_file))
# (iv) instanciate the training, validation and test datasets and
# (v) instanciate the training, validation and test datasets and
# dataloaders
train_ds, valid_ds, test_ds = sc.train_val_test_split(ds)
train_dl, valid_dl, test_dl = sc.dataloaders(
train_ds, valid_ds, test_ds, batch_size=mc.batch_size, shuffle=True,
drop_last=False)
# (iv) instanciate the model
model = mc.init_model(ds)
# (vi) instanciate the model
model, optimizer, checkpoint_state = mc.init_model(ds, state_file)
# (vi) instanciate the optimizer and the loss function
optimizer = mc.init_optimizer(model)
# (vii) instanciate the loss function
loss_function = mc.init_loss_function()
# (vii) resume training from an existing model checkpoint
(model, optimizer, checkpoint_state, max_accuracy) = mc.from_checkpoint(
model, optimizer, state_file, loss_state)
# (viii) initialize network trainer class for eays model training
# (viii) initialize network trainer class for easy model training
trainer = NetworkTrainer(model=model,
optimizer=optimizer,
loss_function=loss_function,
train_dl=train_dl,
valid_dl=valid_dl,
test_dl=test_dl,
state_file=state_file,
loss_state=loss_state,
epochs=mc.epochs,
nthreads=mc.nthreads,
early_stop=mc.early_stop,
mode=mc.mode,
delta=mc.delta,
patience=mc.patience,
max_accuracy=max_accuracy,
checkpoint_state=checkpoint_state,
save=mc.save
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment