Commit 6c886386 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Merge branch 'optim' into ai4ebv-public

parents 95309804 aff1839e
......@@ -1224,6 +1224,12 @@ class NetworkTrainer(BaseConfig):
progress = ', '.join([progress,
'Accuracy: {:.2f}'.format(acc)])
# decay learning rate after each batch, if scheduler is specified
if self.update_lr_after_batch:
self.lr_scheduler.step()
progress = ', '.join([progress, 'Learning rate: {:.5f}'.format(
self.lr_scheduler.get_last_lr().pop())])
# print progress
LOGGER.info(progress)
......@@ -1297,8 +1303,9 @@ class NetworkTrainer(BaseConfig):
if self.save:
self.save_state()
# decay learning rate, if scheduler is specified
if self.lr_scheduler is not None:
# decay learning rate after each epoch, if scheduler is specified
if (self.update_lr_after_batch is not None and not
self.update_lr_after_batch):
self.lr_scheduler.step()
LOGGER.info('Epoch: {:d}, Learning rate: {:.5f}'.format(
epoch + 1, self.lr_scheduler.get_last_lr().pop()))
......@@ -1402,6 +1409,18 @@ class NetworkTrainer(BaseConfig):
self.optimizer,
state=self.training_state,
**self.params_to_save)
@property
def update_lr_after_batch(self):
# check if a learning rate scheduler is specified
if self.lr_scheduler is None:
return
# check which kind of learning rate scheduler is specified
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.CyclicLR):
return True
else:
return False
@property
def training_state(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment