diff --git a/main/config.py b/main/config.py index ccb1807b3c2d76e1a00f7b89fcd314e57de4e871..89e8771f9b9a8435cd50cb4c7516a493c32fed43 100755 --- a/main/config.py +++ b/main/config.py @@ -28,12 +28,12 @@ from pytorch.models import UNet wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/' # define which dataset to train on -# dataset_name = 'Sparcs' -dataset_name = 'Cloud95' +dataset_name = 'Sparcs' +# dataset_name = 'Cloud95' # path to the dataset -# dataset_path = os.path.join(wd, '_Datasets/Sparcs') -dataset_path = os.path.join(wd, '_Datasets/Cloud95/Training') +dataset_path = os.path.join(wd, '_Datasets/Sparcs') +# dataset_path = os.path.join(wd, '_Datasets/Cloud95/Training') # the csv file containing the names of the informative patches of the # Cloud95 dataset @@ -46,7 +46,7 @@ bands = ['red', 'green', 'blue', 'nir'] # define the size of the network input # if None, the size will default to the size of a scene -tile_size = 192 +tile_size = 125 # ----------------------------------------------------------------------------- # ----------------------------------------------------------------------------- @@ -83,7 +83,7 @@ kwargs = {'kernel_size': 3, # the size of the convolving kernel state_path = os.path.join(wd, 'git/deep-learning/main/_models/') # whether to use a pretrained model -pretrained = True +pretrained = False # name of the pretrained model pretrained_model = 'UNet_SparcsDataset_t125_b64_rgbn.pt' @@ -100,12 +100,12 @@ ttratio = 1 # (ttratio * tvratio) * 100 % will be used as the training dataset # (1 - ttratio * tvratio) * 100 % will be used as the validation dataset -tvratio = 0.05 +tvratio = 0.8 # define the batch size # determines how many samples of the dataset are processed until the weights # of the network are updated -batch_size = 64 +batch_size = 128 # Training configuration ------------------------------------------------------ @@ -114,14 +114,14 @@ checkpoint = False # whether to early stop training if the accuracy (loss) on the validation set # does not increase (decrease) more than delta over patience epochs -early_stop = False +early_stop = True mode = 'max' delta = 0 patience = 10 # define the number of epochs: the number of maximum iterations over the whole # training dataset -epochs = 5 +epochs = 200 # define the number of threads nthreads = os.cpu_count()