From fe57ae654862b0d6164274f2fc0abf33e30a009a Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Mon, 17 Aug 2020 17:25:08 +0200 Subject: [PATCH] Adjusting main executables to changes in core.trainer.py --- pysegcnn/main/config.py | 24 +++++++++++--------- pysegcnn/main/eval.py | 50 +++++++++++++++++------------------------ pysegcnn/main/train.py | 34 ++++++++++++---------------- 3 files changed, 48 insertions(+), 60 deletions(-) diff --git a/pysegcnn/main/config.py b/pysegcnn/main/config.py index 8780589..c307f1d 100644 --- a/pysegcnn/main/config.py +++ b/pysegcnn/main/config.py @@ -30,6 +30,12 @@ DATASET_PATH = os.path.join(DRIVE_PATH, DATASET_NAME) # DATASET_PATH = os.path.join(DRIVE_PATH, DATASET_NAME, 'Training') # DATASET_PATH = os.path.join(DRIVE_PATH, 'ProSnow', DATASET_NAME) +# path to store the model states +MODEL_PATH = os.path.join(HERE, '_models/') + +# path to store model logs +LOG_PATH = os.path.join(HERE, '_logs/') + # the dataset configuration dictionary dataset_config = { @@ -189,9 +195,6 @@ model_config = { 'dilation': 1 # the field of view of the kernel }, - # path to save trained models - 'state_path': os.path.join(HERE, '_models/'), - # Transfer learning ------------------------------------------------------- # Use pretrained=True only if you wish to fine-tune a pre-trained model @@ -203,7 +206,7 @@ model_config = { # was trained on # whether to use a pretrained model for transfer learning - 'transfer': True, + 'transfer': False, # name of the pretrained model to apply to a different dataset 'pretrained_model': 'UNet_SparcsDataset_t125_b64_rgbn.pt', @@ -226,6 +229,8 @@ model_config = { # ------------------------------------------------------------------------- # whether to save the model state to disk + # model states are saved in: pysegcnn/main/_models + # model log files are saved in: pysegcnn/main/_logs 'save': True, # whether to early stop training if the accuracy on the validation set @@ -260,6 +265,10 @@ eval_config = { # these options are only used for evaluating a trained model using # pysegcnn.main.eval.py + # the model to evaluate + 'state_file': os.path.join(MODEL_PATH, + 'UNet_SparcsDataset_Adam_SceneSplit_s0_t005v05_t125_b64_r4g3b2n5.pt'), + # the dataset to evaluate the model on # test=False, 0 means evaluating on the validation set # test=True, 1 means evaluating on the test set @@ -267,6 +276,7 @@ eval_config = { 'test': False, # whether to compute and plot the confusion matrix + # output path is: pysegcnn/main/_graphics/ 'cm': True, # whether to predict each sample or each scene individually @@ -299,9 +309,3 @@ eval_config = { 'alpha': 5 } - -# the complete configuration -config = {**dataset_config, - **split_config, - **model_config, - **eval_config} diff --git a/pysegcnn/main/eval.py b/pysegcnn/main/eval.py index 82fe4e3..687c5bb 100644 --- a/pysegcnn/main/eval.py +++ b/pysegcnn/main/eval.py @@ -5,52 +5,41 @@ Created on Wed Jul 29 15:57:01 2020 @author: Daniel """ # builtins -import os +from logging.config import dictConfig # locals -from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig, - StateConfig, EvalConfig) +from pysegcnn.core.models import Network +from pysegcnn.core.trainer import EvalConfig, LogConfig from pysegcnn.core.predict import predict_samples, predict_scenes -from pysegcnn.main.config import (dataset_config, split_config, model_config, - train_config, eval_config, HERE) +from pysegcnn.core.logging import log_conf from pysegcnn.core.graphics import plot_confusion_matrix, plot_loss +from pysegcnn.main.config import eval_config if __name__ == '__main__': - # (i) instanciate the configurations - dc = DatasetConfig(**dataset_config) - sc = SplitConfig(**split_config) - mc = ModelConfig(**model_config) + # instanciate the evaluation configuration ec = EvalConfig(**eval_config) - # (ii) instanciate the dataset - ds = dc.init_dataset() + # initialize logging + log = LogConfig(ec.state_file) + dictConfig(log_conf(log.log_file)) - # (iii) instanciate the training, validation and test datasets - train_ds, valid_ds, test_ds = sc.train_val_test_split(ds) - - # (iv) instanciate the model state - state = StateConfig(ds, sc, mc) - state_file, loss_state = state.init_state() - - # (vii) load pretrained model weights - model, _ = mc.load_pretrained(state_file) - model.state_file = state_file + # load the model state + model, _, model_state = Network.load(ec.state_file) # plot loss and accuracy - plot_loss(loss_state, outpath=os.path.join(HERE, '_graphics/')) + plot_loss(ec.state_file, outpath=ec.models_path) # check whether to evaluate the model on the training set, validation set # or the test set if ec.test is None: - ds = train_ds + ds = model_state['train_ds'] else: - ds = test_ds if ec.test else valid_ds + ds = model_state['test_ds'] if ec.test else model_state['valid_ds'] # keyword arguments for plotting kwargs = {'bands': ec.plot_bands, - 'outpath': os.path.join(HERE, '_scenes/'), 'alpha': ec.alpha, 'figsize': ec.figsize} @@ -58,16 +47,17 @@ if __name__ == '__main__': if ec.predict_scene: # reconstruct and predict the scenes in the validation/test set scenes, cm = predict_scenes(ds, model, scene_id=None, cm=ec.cm, - plot=ec.plot_scenes, **kwargs) + plot=ec.plot_scenes, + outpath=ec.scenes_path, **kwargs) else: # predict the samples in the validation/test set samples, cm = predict_samples(ds, model, cm=ec.cm, - plot=ec.plot_samples, **kwargs) + plot=ec.plot_samples, + outpath=ec.sample_path, **kwargs) # whether to plot the confusion matrix if ec.cm: plot_confusion_matrix(cm, ds.dataset.labels, - state=state_file.name.replace('.pt', '.png'), - outpath=os.path.join(HERE, '_graphics/') - ) + state=ec.state_file.name.replace('.pt', '.png'), + outpath=ec.models_path) diff --git a/pysegcnn/main/train.py b/pysegcnn/main/train.py index bcffb54..e28f55f 100644 --- a/pysegcnn/main/train.py +++ b/pysegcnn/main/train.py @@ -6,13 +6,14 @@ Created on Tue Jun 30 09:33:38 2020 @author: Daniel """ # builtins -import logging +from logging.config import dictConfig # locals from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig, - StateConfig, NetworkTrainer) + StateConfig, LogConfig, NetworkTrainer) from pysegcnn.core.logging import log_conf -from pysegcnn.main.config import (dataset_config, split_config, model_config) +from pysegcnn.main.config import (dataset_config, split_config, model_config, + LOG_PATH) if __name__ == '__main__': @@ -26,49 +27,42 @@ if __name__ == '__main__': # (ii) instanciate the dataset ds = dc.init_dataset() - ds # (iii) instanciate the model state state = StateConfig(ds, sc, mc) - state_file, loss_state = state.init_state() + state_file = state.init_state() - # initialize logging - log_file = str(state_file).replace('.pt', '_train.log') - logging.config.dictConfig(log_conf(log_file)) + # (iv) initialize logging + log = LogConfig(state_file) + dictConfig(log_conf(log.log_file)) - # (iv) instanciate the training, validation and test datasets and + # (v) instanciate the training, validation and test datasets and # dataloaders train_ds, valid_ds, test_ds = sc.train_val_test_split(ds) train_dl, valid_dl, test_dl = sc.dataloaders( train_ds, valid_ds, test_ds, batch_size=mc.batch_size, shuffle=True, drop_last=False) - # (iv) instanciate the model - model = mc.init_model(ds) + # (vi) instanciate the model + model, optimizer, checkpoint_state = mc.init_model(ds, state_file) - # (vi) instanciate the optimizer and the loss function - optimizer = mc.init_optimizer(model) + # (vii) instanciate the loss function loss_function = mc.init_loss_function() - # (vii) resume training from an existing model checkpoint - (model, optimizer, checkpoint_state, max_accuracy) = mc.from_checkpoint( - model, optimizer, state_file, loss_state) - - # (viii) initialize network trainer class for eays model training + # (viii) initialize network trainer class for easy model training trainer = NetworkTrainer(model=model, optimizer=optimizer, loss_function=loss_function, train_dl=train_dl, valid_dl=valid_dl, + test_dl=test_dl, state_file=state_file, - loss_state=loss_state, epochs=mc.epochs, nthreads=mc.nthreads, early_stop=mc.early_stop, mode=mc.mode, delta=mc.delta, patience=mc.patience, - max_accuracy=max_accuracy, checkpoint_state=checkpoint_state, save=mc.save ) -- GitLab