diff --git a/main/config.py b/main/config.py index 89e8771f9b9a8435cd50cb4c7516a493c32fed43..277c11fd2f8bb1106057f1eba9e377387615eada 100755 --- a/main/config.py +++ b/main/config.py @@ -28,12 +28,12 @@ from pytorch.models import UNet wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/' # define which dataset to train on -dataset_name = 'Sparcs' -# dataset_name = 'Cloud95' +# dataset_name = 'Sparcs' +dataset_name = 'Cloud95' # path to the dataset -dataset_path = os.path.join(wd, '_Datasets/Sparcs') -# dataset_path = os.path.join(wd, '_Datasets/Cloud95/Training') +# dataset_path = os.path.join(wd, '_Datasets/Sparcs') +dataset_path = os.path.join(wd, '_Datasets/Cloud95/Training') # the csv file containing the names of the informative patches of the # Cloud95 dataset @@ -46,7 +46,7 @@ bands = ['red', 'green', 'blue', 'nir'] # define the size of the network input # if None, the size will default to the size of a scene -tile_size = 125 +tile_size = 192 # ----------------------------------------------------------------------------- # ----------------------------------------------------------------------------- @@ -83,10 +83,10 @@ kwargs = {'kernel_size': 3, # the size of the convolving kernel state_path = os.path.join(wd, 'git/deep-learning/main/_models/') # whether to use a pretrained model -pretrained = False +pretrained = True # name of the pretrained model -pretrained_model = 'UNet_SparcsDataset_t125_b64_rgbn.pt' +pretrained_model = 'UNet_SparcsDataset_t125_b128_rgbn.pt' # Dataset split --------------------------------------------------------------- @@ -100,12 +100,12 @@ ttratio = 1 # (ttratio * tvratio) * 100 % will be used as the training dataset # (1 - ttratio * tvratio) * 100 % will be used as the validation dataset -tvratio = 0.8 +tvratio = 0.05 # define the batch size # determines how many samples of the dataset are processed until the weights # of the network are updated -batch_size = 128 +batch_size = 64 # Training configuration ------------------------------------------------------ @@ -114,14 +114,14 @@ checkpoint = False # whether to early stop training if the accuracy (loss) on the validation set # does not increase (decrease) more than delta over patience epochs -early_stop = True +early_stop = False mode = 'max' delta = 0 patience = 10 # define the number of epochs: the number of maximum iterations over the whole # training dataset -epochs = 200 +epochs = 10 # define the number of threads nthreads = os.cpu_count() @@ -149,7 +149,7 @@ plot_cm = False # whether to save plots of (input, ground truth, prediction) of the validation # dataset to disk # output path is: current_working_directory/_samples/ -plot_samples = True +plot_samples = False # number of samples to plot # if nsamples = -1, all samples are plotted