diff --git a/pysegcnn/main/train_transfer.py b/pysegcnn/main/train_transfer.py index dfb9f36d812214a311deff4f8732577a8f798709..610d7ac5a57890ad376958dab3124d2bf606f9c1 100644 --- a/pysegcnn/main/train_transfer.py +++ b/pysegcnn/main/train_transfer.py @@ -110,11 +110,14 @@ if __name__ == '__main__': net, optimizer, checkpoint = net_mc.init_model( len(src_ds.use_bands), len(src_ds.labels), state_file) + # set the model state file + net.state_file = state_file + # (xv) instanciate the network trainer class trainer = DomainAdaptationTrainer( model=net, optimizer=optimizer, - state_file=net.state_file, + state_file=state_file, src_train_dl=src_tra_dl, src_valid_dl=src_val_dl, src_test_dl=src_tes_dl,