diff --git a/pysegcnn/main/config.py b/pysegcnn/main/config.py index d2e8c4de7d90a8f974364ee7fde1d2f0db578cec..5325097f7d8a06d47ef5b1cbe69e855e0af2120b 100644 --- a/pysegcnn/main/config.py +++ b/pysegcnn/main/config.py @@ -16,14 +16,19 @@ import os # path to this file HERE = os.path.abspath(os.path.dirname(__file__)) -# path to the datasets -DATASET_PATH = 'C:/Eurac/2020/_Datasets/' -# DATASET_PATH = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/_Datasets/' +# path to the datasets on the current machine +DRIVE_PATH ='C:/Eurac/2020/_Datasets/' +# DRIVE_PATH = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/_Datasets/' # name of the datasets DATASET_NAME = 'Sparcs' -# DATASET_NAME = 'Cloud95/Training/' -# DATASET_NAME =' ProSnow/Garmisch/ +# DATASET_NAME = 'Cloud95' +# DATASET_NAME = 'Garmisch' + +# path to the dataset +DATASET_PATH = os.path.join(DRIVE_PATH, DATASET_NAME) +# DATASET_PATH = os.path.join(DRIVE_PATH, DATASET_NAME, 'Training') +# DATASET_PATH = os.path.join(DRIVE_PATH, 'ProSnow', DATASET_NAME) # the dataset configuration dictionary dataset_config = { @@ -32,8 +37,11 @@ dataset_config = { # ------------------------------------------------------------------------- + # name of the dataset + 'dataset_name': DATASET_NAME, + # path to the dataset - 'root_dir': os.path.join(DATASET_PATH, DATASET_NAME), + 'root_dir': DATASET_PATH, # a pattern to match the ground truth file naming convention 'gt_pattern': '*mask.png', @@ -53,8 +61,8 @@ dataset_config = { # tiles of size (tile_size, tile_size) 'pad': True, - # set random seed for reproducibility of the training, validation - # and test data split + # the random seed for the numpy random number generator + # ensures reproducibility of the training, validation and test data split # used if split_mode='random' and split_mode='scene' 'seed': 0, @@ -134,12 +142,12 @@ split_config = { # (ttratio * 100) % of the dataset will be used for training and # validation # used if split_mode='random' and split_mode='scene' - 'ttratio': 
1, + 'ttratio': 0.05, # (ttratio * tvratio) * 100 % will be used as for training # (1 - ttratio * tvratio) * 100 % will be used for validation # used if split_mode='random' and split_mode='scene' - 'tvratio': 0.8, + 'tvratio': 0.5, # the date to split the scenes # format: 'yyyymmdd' @@ -211,7 +219,10 @@ model_config = { # define the batch size # determines how many samples of the dataset are processed until the # weights of the network are updated (via mini-batch gradient descent) - 'batch_size': 64 + 'batch_size': 64, + + # the seed for the random number generator initializing the network weights + 'torch_seed': 0 } @@ -223,6 +234,9 @@ train_config = { # ------------------------------------------------------------------------- + # whether to save the model state to disk + 'save': True, + # whether to early stop training if the accuracy on the validation set # does not increase more than delta over patience epochs 'early_stop': True, @@ -246,7 +260,7 @@ train_config = { } # the evaluation configuration file -evaluation_config = { +eval_config = { # ----------------------------- Evaluation -------------------------------- @@ -256,8 +270,8 @@ evaluation_config = { # pysegcnn.main.eval.py # the dataset to evaluate the model on - # test=False means evaluating on the validation set - # test=True means evaluating on the test set + # test=False, 0 means evaluating on the validation set + # test=True, 1 means evaluating on the test set # test=None means evaluating on the training set 'test': False, @@ -294,4 +308,4 @@ config = {**dataset_config, **split_config, **model_config, **train_config, - **evaluation_config} + **eval_config} diff --git a/pysegcnn/main/eval.py b/pysegcnn/main/eval.py index 73faa10c10220ca859fa0ee8b6f7c8c83c287839..2479340c895fdb5695584f812dd09358d610f4e7 100644 --- a/pysegcnn/main/eval.py +++ b/pysegcnn/main/eval.py @@ -8,63 +8,83 @@ Created on Wed Jul 29 15:57:01 2020 import os # locals -from pysegcnn.core.trainer import NetworkTrainer +from 
pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig, + TrainConfig, EvalConfig) from pysegcnn.core.predict import predict_samples, predict_scenes -from pysegcnn.main.config import config, HERE +from pysegcnn.main.config import (dataset_config, split_config, model_config, + train_config, eval_config, HERE) from pysegcnn.core.graphics import plot_confusion_matrix, plot_loss if __name__ == '__main__': - # instanciate the NetworkTrainer class - trainer = NetworkTrainer(config) - trainer + # (i) instantiate the configurations + dc = DatasetConfig(**dataset_config) + sc = SplitConfig(**split_config) + mc = ModelConfig(**model_config) + tc = TrainConfig(**train_config) + ec = EvalConfig(**eval_config) + + # (ii) instantiate the dataset + ds = dc.init_dataset() + ds + + # (iii) instantiate the training, validation and test datasets + train_ds, valid_ds, test_ds = sc.train_val_test_split(ds) + + # (iv) instantiate the model state files + state_file, loss_state = mc.init_state(ds, sc, tc) + + # (v) instantiate the model + model = mc.init_model(ds) + + # (vi) instantiate the optimizer + optimizer = tc.init_optimizer(model) # plot loss and accuracy - plot_loss(trainer.loss_state, outpath=os.path.join(HERE, '_graphics/')) + plot_loss(loss_state, outpath=os.path.join(HERE, '_graphics/')) # check whether to evaluate the model on the training set, validation set # or the test set - if trainer.test is None: - ds = trainer.train_ds + if ec.test is None: + ds = train_ds else: - ds = trainer.test_ds if trainer.test else trainer.valid_ds + ds = test_ds if ec.test else valid_ds + + # keyword arguments for plotting + kwargs = {'bands': ec.plot_bands, + 'outpath': os.path.join(HERE, '_scenes/'), + 'stretch': True, + 'alpha': 5} # whether to predict each sample or each scene individually - if trainer.predict_scene: + if ec.predict_scene: # reconstruct and predict the scenes in the validation/test set scenes, cm = predict_scenes(ds, - trainer.model, - trainer.optimizer, - 
trainer.state_path, - trainer.state_file, - None, - trainer.cm, - trainer.plot_scenes, - bands=trainer.plot_bands, - outpath=os.path.join(HERE, '_scenes/'), - stretch=True, - alpha=5) + model, + optimizer, + state_file, + scene_id=None, + cm=ec.cm, + plot_scenes=ec.plot_scenes, + **kwargs + ) else: # predict the samples in the validation/test set samples, cm = predict_samples(ds, - trainer.model, - trainer.optimizer, - trainer.state_path, - trainer.state_file, - trainer.cm, - trainer.plot_samples, - bands=trainer.plot_bands, - outpath=os.path.join(HERE, '_samples/'), - stretch=True, - alpha=5) + model, + optimizer, + state_file, + cm=ec.cm, + plot_scenes=ec.plot_scenes, + **kwargs) # whether to plot the confusion matrix - if trainer.cm: + if ec.cm: plot_confusion_matrix(cm, ds.dataset.labels, normalize=True, - state=trainer.state_file, + state=state_file.name.replace('.pt', '.png'), outpath=os.path.join(HERE, '_graphics/') ) diff --git a/pysegcnn/main/train.py b/pysegcnn/main/train.py index af0099e6c358b20f8e099ef77ecb4c78e0853b9f..6c036ea6e10ae589595af6c98ff4246eae19afd0 100644 --- a/pysegcnn/main/train.py +++ b/pysegcnn/main/train.py @@ -6,19 +6,68 @@ Created on Tue Jun 30 09:33:38 2020 @author: Daniel """ # locals -from pysegcnn.core.initconf import NetworkTrainer +from pysegcnn.core.trainer import (DatasetConfig, SplitConfig, ModelConfig, + TrainConfig, NetworkTrainer) from pysegcnn.main.config import (dataset_config, split_config, model_config, train_config) if __name__ == '__main__': - # instanciate the NetworkTrainer class - trainer = NetworkTrainer(dconfig=dataset_config, - sconfig=split_config, - mconfig=model_config, - tconfig=train_config) - trainer + # write code that checks for list of seeds, band combinations etc. here. 
- # train the network + # (i) instantiate the configurations + dc = DatasetConfig(**dataset_config) + sc = SplitConfig(**split_config) + mc = ModelConfig(**model_config) + tc = TrainConfig(**train_config) + + # (ii) instantiate the dataset + ds = dc.init_dataset() + ds + + # (iii) instantiate the training, validation and test datasets and + # dataloaders + train_ds, valid_ds, test_ds = sc.train_val_test_split(ds) + train_dl, valid_dl, test_dl = sc.dataloaders(train_ds, + valid_ds, + test_ds, + batch_size=mc.batch_size, + shuffle=True, + drop_last=False) + + # (iv) instantiate the model state files + state_file, loss_state = mc.init_state(ds, sc, tc) + + # (v) instantiate the model + model = mc.init_model(ds) + + # (vi) instantiate the optimizer and the loss function + optimizer = tc.init_optimizer(model) + loss_function = tc.init_loss_function() + + # (vii) resume training from an existing model checkpoint + checkpoint_state, max_accuracy = mc.load_checkpoint(state_file, loss_state, + model, optimizer) + + # (viii) initialize network trainer class for easy model training + trainer = NetworkTrainer(model=model, + optimizer=optimizer, + loss_function=loss_function, + train_dl=train_dl, + valid_dl=valid_dl, + state_file=state_file, + loss_state=loss_state, + epochs=tc.epochs, + nthreads=tc.nthreads, + early_stop=tc.early_stop, + mode=tc.mode, + delta=tc.delta, + patience=tc.patience, + max_accuracy=max_accuracy, + checkpoint_state=checkpoint_state, + save=tc.save + ) + + # (ix) train model training_state = trainer.train()