From 62c345fdf48d5a7bd954a2da96598406372d04a5 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 14 Oct 2021 15:35:03 +0200
Subject: [PATCH] Implemented learning rate decay.

---
 climax/main/config.py          | 3 +--
 climax/main/downscale_train.py | 3 ++-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/climax/main/config.py b/climax/main/config.py
index 068f82a..f246cb5 100644
--- a/climax/main/config.py
+++ b/climax/main/config.py
@@ -150,8 +150,7 @@ TRAIN_CONFIG = {
     'patience': 5,
     'multi_gpu': True,
     'classification': False,
-    'clip_gradients': True,
-    # 'lr_scheduler': torch.optim.lr_scheduler.
+    'clip_gradients': True
 }
 
 # whether to overwrite existing models

diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index 314ae37..74390e5 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -162,7 +162,8 @@ if __name__ == '__main__':
 
     # initialize network trainer
     trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,
-                             valid_dl, loss_function=LOSS, **TRAIN_CONFIG)
+                             valid_dl, loss_function=LOSS,
+                             lr_scheduler=LR_SCHEDULER, **TRAIN_CONFIG)
 
     # train model
     state = trainer.train()
-- 
GitLab
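
Note: the patch passes an LR_SCHEDULER constant into NetworkTrainer, but its definition is not part of this diff. Below is a minimal sketch of how such a scheduler might be defined and used, assuming (as the removed commented-out config line suggests) that LR_SCHEDULER is a class from torch.optim.lr_scheduler passed to the trainer and instantiated with the optimizer; the model, optimizer, and gamma value here are illustrative placeholders, not the repository's actual code.

import torch
from torch import nn

# Hypothetical stand-in for the LR_SCHEDULER constant referenced by the
# patch; the actual class and hyperparameters are not shown in this diff.
LR_SCHEDULER = torch.optim.lr_scheduler.ExponentialLR

# Toy model and optimizer so the sketch runs on its own.
net = nn.Linear(10, 1)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# A trainer receiving the scheduler class would instantiate it with the
# optimizer, then step it once per epoch to decay the learning rate.
scheduler = LR_SCHEDULER(optimizer, gamma=0.9)

for epoch in range(5):
    # ... per-batch forward/backward passes and optimizer.step() ...
    scheduler.step()  # decay: lr <- lr * gamma after each epoch
    print(epoch, optimizer.param_groups[0]['lr'])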