Commit b526b330 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Implemented LR-scheduling.

parent 10ee61cf
......@@ -34,6 +34,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from sklearn.metrics import confusion_matrix, classification_report
# locals
......@@ -1071,6 +1072,7 @@ class NetworkTrainer(BaseConfig):
src_valid_dl: DataLoader
src_test_dl: DataLoader = DataLoader(None)
loss_function: nn.modules.loss._Loss = nn.CrossEntropyLoss()
lr_scheduler: (type(None), _LRScheduler) = None
epochs: int = 1
nthreads: int = torch.get_num_threads()
early_stop: bool = False
......@@ -1290,6 +1292,10 @@ class NetworkTrainer(BaseConfig):
if self.save:
self.save_state()
# decay learning rate, if scheduler is specified
if self.lr_scheduler is not None:
self.lr_scheduler.step()
return self.training_state
def predict(self, dataloader, return_pred=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment