diff --git a/climax/main/config.py b/climax/main/config.py index 068f82a4d946ae577594e090c986eff3aa4d9fb7..f246cb544d03311e01224c5e2f9e9015cc6e5530 100644 --- a/climax/main/config.py +++ b/climax/main/config.py @@ -150,8 +150,7 @@ TRAIN_CONFIG = { 'patience': 5, 'multi_gpu': True, 'classification': False, - 'clip_gradients': True, - # 'lr_scheduler': torch.optim.lr_scheduler. + 'clip_gradients': True } # whether to overwrite existing models diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py index 314ae378e8b7367cf1aa4627961e4b68a9e42a4a..74390e55bdd7787d77a799c03abd68a111fd2731 100644 --- a/climax/main/downscale_train.py +++ b/climax/main/downscale_train.py @@ -162,7 +162,8 @@ if __name__ == '__main__': # initialize network trainer trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl, - valid_dl, loss_function=LOSS, **TRAIN_CONFIG) + valid_dl, loss_function=LOSS, + lr_scheduler=LR_SCHEDULER, **TRAIN_CONFIG) # train model state = trainer.train()