Skip to content
Snippets Groups Projects
Commit 45aef51c authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Refactoring transfer learning code.

parent fdbfcf3b
No related branches found
No related tags found
No related merge requests found
......@@ -101,7 +101,7 @@ if __name__ == '__main__':
# transfer learning
net, optimizer, checkpoint = trn_sf.transfer_model(
trn_sf.pretrained_path,
nclasses=len(src_ds).labels,
nclasses=len(src_ds.labels),
optim_kwargs=net_mc.optim_kwargs,
freeze=trn_sf.freeze)
else:
......
......@@ -29,10 +29,14 @@ HERE = pathlib.Path(__file__).resolve().parent
# path to the datasets on the current machine
DRIVE_PATH = pathlib.Path('C:/Eurac/Projects/CCISNOW/Datasets/')
# DRIVE_PATH = pathlib.Path('/mnt/CEPH_PROJECTS/cci_snow/dfrisinghelli/Datasets/') # nopep8
# DRIVE_PATH = pathlib.Path('/home/dfrisinghelli/Datasets/')
# DRIVE_PATH = pathlib.Path('/home/clusterusers/dfrisinghelli_eurac/Datasets/')
# DRIVE_PATH = pathlib.Path('/scratch/dfrisinghelli_eurac/Datasets/')
# DRIVE_PATH = pathlib.Path('/localscratch/dfrisinghelli_eurac/Datasets/')
# name and paths to the datasets
DATASETS = {'Sparcs': DRIVE_PATH.joinpath('Sparcs'),
'Alcd': DRIVE_PATH.joinpath('Alcd/60m')
'Alcd': DRIVE_PATH.joinpath('Alcd')
}
# name of the source domain dataset
......@@ -45,10 +49,10 @@ TRG_DS = 'Alcd'
BANDS = ['red', 'green', 'blue', 'nir', 'swir1', 'swir2']
# tile size of a single sample
TILE_SIZE = 128
TILE_SIZE = 64
# number of folds for cross validation
K_FOLDS = 2
K_FOLDS = 1
# the source dataset configuration dictionary
src_ds_config = {
......@@ -206,7 +210,7 @@ src_split_config = {
# (ttratio * tvratio) * 100 % will be used for training
# (1 - ttratio * tvratio) * 100 % will be used for validation
# used if 'kfolds=1'
'tvratio': 0.8,
'tvratio': 0.05,
}
......@@ -219,7 +223,7 @@ trg_split_config = {
'seed': 0,
'shuffle': True,
'ttratio': 1,
'tvratio': 0.8,
'tvratio': 0.05,
}
......@@ -281,7 +285,7 @@ model_config = {
# define the number of epochs: the number of maximum iterations over
# the whole training dataset
'epochs': 100,
'epochs': 10,
}
......@@ -294,8 +298,8 @@ tlda_config = {
# whether to apply any sort of transfer learning
# if transfer=False, the model is only trained on the source dataset
# 'transfer': True,
'transfer': False,
'transfer': True,
# 'transfer': False,
# Supervised vs. Unsupervised ---------------------------------------------
# -------------------------------------------------------------------------
......@@ -313,13 +317,13 @@ tlda_config = {
# scratch ('uda_from_pretrained=False') or the pretrained
# model in 'pretrained_model' is loaded
# ('uda_from_pretrained=True')
# 'supervised': True,
'supervised': False,
'supervised': True,
# 'supervised': False,
# name of the pretrained model to apply for transfer learning
# required if transfer=True and supervised=True
# optional if transfer=True and supervised=False
'pretrained_model': '', # nopep8
'pretrained_model': 'Segnet_Adam_b128_AlcdDataset_m2_Scene_s0t10v08_t64_b2g3r4.pt', # nopep8
# loss function for unsupervised domain adaptation
# currently supported methods:
......@@ -356,6 +360,7 @@ tlda_config = {
# 'uda_pos': 'cla',
# whether to freeze the pretrained model weights
'freeze': True,
# 'freeze': True
'freeze': False
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment