diff --git a/pysegcnn/main/config.py b/pysegcnn/main/config.py index 3a8707004e917bdc62bf51fd62fa714d7d5810f5..d2e8c4de7d90a8f974364ee7fde1d2f0db578cec 100644 --- a/pysegcnn/main/config.py +++ b/pysegcnn/main/config.py @@ -11,13 +11,7 @@ Modify the variable values to your needs, but DO NOT modify the variable names. # builtins import os -# externals -import torch.nn as nn -import torch.optim as optim - -# locals -from pysegcnn.core.models import UNet -from pysegcnn.core.transforms import Augment, FlipLr, FlipUd, Noise +# from pysegcnn.core.transforms import Augment, FlipLr, FlipUd, Noise # path to this file HERE = os.path.abspath(os.path.dirname(__file__)) @@ -59,6 +53,11 @@ dataset_config = { # tiles of size (tile_size, tile_size) 'pad': True, + # set random seed for reproducibility of the training, validation + # and test data split + # used if split_mode='random' and split_mode='scene' + 'seed': 0, + # the constant value to pad around the ground truth mask if pad=True 'cval': 99, @@ -132,11 +131,6 @@ split_config = { # the date build the validation set, the test set is empty 'split_mode': 'scene', - # set random seed for reproducibility of the training, validation - # and test data split - # used if split_mode='random' and split_mode='scene' - 'seed': 0, - # (ttratio * 100) % of the dataset will be used for training and # validation # used if split_mode='random' and split_mode='scene' @@ -155,6 +149,15 @@ split_config = { 'date': 'yyyymmdd', 'dateformat': '%Y%m%d', + # whether to drop samples (during training only) with a fraction of + # pixels equal to the constant padding value cval >= drop + # drop=1 means, do not use a sample if all pixels = cval + # drop=0.8 means, do not use a sample if 80% or more of the pixels are + # equal to cval + # drop=0.2 means, ... + # drop=0 means, do not drop any samples + 'drop': 0, + } # the model configuration dictionary @@ -165,7 +168,7 @@ model_config = { # ------------------------------------------------------------------------- # define the model - 'model': UNet, + 'model_name': 'Unet', # define the number of filters for each convolutional layer # the number of filters should increase with depth @@ -181,16 +184,6 @@ model_config = { 'dilation': 1 # the field of view of the kernel }, -} - - -# the training configuration dictionary -training_config = { - - # ----------------------------- Training --------------------------------- - - # ------------------------------------------------------------------------- - # path to save trained models 'state_path': os.path.join(HERE, '_models/'), @@ -213,7 +206,22 @@ training_config = { # Training ---------------------------------------------------------------- # whether to resume training from an existing model checkpoint - 'checkpoint': True, + 'checkpoint': False, + + # define the batch size + # determines how many samples of the dataset are processed until the + # weights of the network are updated (via mini-batch gradient descent) + 'batch_size': 64 + +} + + +# the training configuration dictionary +train_config = { + + # ----------------------------- Training --------------------------------- + + # ------------------------------------------------------------------------- # whether to early stop training if the accuracy on the validation set # does not increase more than delta over patience epochs @@ -222,31 +230,15 @@ training_config = { 'delta': 0, 'patience': 10, - # whether to drop samples (during training only) with a fraction of - # pixels equal to the constant padding value cval >= drop - # drop=1 means, do not use a sample if all pixels = cval - # drop=0.8 means, do not use a sample if 80% or more of the pixels are - # equal to cval - # drop=0.2 means, ... - 'drop': 1, - - # define the batch size - # determines how many samples of the dataset are processed until the - # weights of the network are updated (via mini-batch gradient descent) - 'batch_size': 64, - # define the number of epochs: the number of maximum iterations over # the whole training dataset 'epochs': 200, - # define the number of threads - 'nthreads': os.cpu_count(), - # define a loss function to calculate the network error - 'loss_function': nn.CrossEntropyLoss(), + 'loss_name': 'CrossEntropy', # define an optimizer to update the network weights - 'optimizer': optim.Adam, + 'optim_name': 'Adam', # define the learning rate 'lr': 0.001, @@ -301,5 +293,5 @@ evaluation_config = { config = {**dataset_config, **split_config, **model_config, - **training_config, + **train_config, **evaluation_config}