Skip to content
Snippets Groups Projects
Commit 54793cce authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Added training-test ratio to NetworkTrainer call

parent 431b1d2c
No related branches found
No related tags found
No related merge requests found
...@@ -16,8 +16,8 @@ import torch.optim as optim ...@@ -16,8 +16,8 @@ import torch.optim as optim
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# define path to working directory # define path to working directory
wd = 'C:/Eurac/2020/' # wd = 'C:/Eurac/2020/'
# wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/' wd = '/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/'
# path to the downloaded sparcs archive # path to the downloaded sparcs archive
sparcs_archive = os.path.join(wd, '_Datasets/Archives/l8cloudmasks.zip') sparcs_archive = os.path.join(wd, '_Datasets/Archives/l8cloudmasks.zip')
...@@ -40,7 +40,7 @@ tile_size = 125 ...@@ -40,7 +40,7 @@ tile_size = 125
seed = 0 seed = 0
# (ttratio * 100) % of the dataset will be used for training and validation # (ttratio * 100) % of the dataset will be used for training and validation
ttratio = 0.8 ttratio = 1
# (ttratio * tvratio) * 100 % will be used as the training dataset # (ttratio * tvratio) * 100 % will be used as the training dataset
# (1 - ttratio * tvratio) * 100 % will be used as the validation dataset # (1 - ttratio * tvratio) * 100 % will be used as the validation dataset
...@@ -49,7 +49,7 @@ tvratio = 0.8 ...@@ -49,7 +49,7 @@ tvratio = 0.8
# define the batch size # define the batch size
# determines how many samples of the dataset are processed until the weights # determines how many samples of the dataset are processed until the weights
# of the network are updated # of the network are updated
batch_size = 16 batch_size = 64
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
......
...@@ -20,7 +20,7 @@ from pytorch.train import NetworkTrainer ...@@ -20,7 +20,7 @@ from pytorch.train import NetworkTrainer
from pytorch.models import SegNet from pytorch.models import SegNet
from sparcs.sparcs_00_config import (sparcs_path, bands, tile_size, tvratio, from sparcs.sparcs_00_config import (sparcs_path, bands, tile_size, tvratio,
filters, skip_connection, kwargs, filters, skip_connection, kwargs,
loss_function, optimizer, lr, loss_function, optimizer, lr, ttratio,
batch_size, seed, state_file) batch_size, seed, state_file)
...@@ -58,5 +58,6 @@ state_file = net.__class__.__name__ + state_file ...@@ -58,5 +58,6 @@ state_file = net.__class__.__name__ + state_file
# instanciate NetworkTrainer class # instanciate NetworkTrainer class
print('------------------------ Dataset split ---------------------------') print('------------------------ Dataset split ---------------------------')
trainer = NetworkTrainer(net, dataset, loss_function, optimizer, trainer = NetworkTrainer(net, dataset, loss_function, optimizer,
batch_size=batch_size, tvratio=tvratio, seed=seed) batch_size=batch_size, tvratio=tvratio,
ttratio=ttratio, seed=seed)
print('------------------------------------------------------------------') print('------------------------------------------------------------------')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment