Commit 62c345fd authored by Frisinghelli Daniel's avatar Frisinghelli Daniel

Implemented learning rate decay.

parent 125090f0
@@ -150,8 +150,7 @@ TRAIN_CONFIG = {
     'patience': 5,
     'multi_gpu': True,
     'classification': False,
-    'clip_gradients': True,
-    # 'lr_scheduler': torch.optim.lr_scheduler.
+    'clip_gradients': True
 }
 # whether to overwrite existing models
......
@@ -162,7 +162,8 @@ if __name__ == '__main__':
     # initialize network trainer
-    trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
-                             valid_dl, loss_function=LOSS, **TRAIN_CONFIG)
+    trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
+                             valid_dl, loss_function=LOSS,
+                             lr_scheduler=LR_SCHEDULER, **TRAIN_CONFIG)
     # train model
     state = trainer.train()
......
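The diff passes an `LR_SCHEDULER` to the trainer but its definition is outside the shown hunks; the commented-out `torch.optim.lr_scheduler` hint in the old config suggests a standard PyTorch scheduler such as `ExponentialLR`. A minimal dependency-free sketch of that kind of exponential decay (initial rate and decay factor are illustrative values, not from the commit):

```python
# Sketch of exponential learning-rate decay, the schedule that
# torch.optim.lr_scheduler.ExponentialLR applies after each epoch.
# lr0 and gamma below are illustrative, not taken from the commit.

def exponential_decay(lr0, gamma, epoch):
    """Learning rate after `epoch` decay steps: lr0 * gamma ** epoch."""
    return lr0 * gamma ** epoch

# The rate shrinks by a constant factor gamma each epoch.
schedule = [exponential_decay(0.001, 0.9, e) for e in range(5)]
print(schedule)
```

In PyTorch the equivalent object would be constructed from the optimizer (e.g. `torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)`) and stepped once per epoch inside the training loop.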