Commit 64e94ac1 authored by Frisinghelli Daniel

Merge branch 'optim' into ai4ebv-public

parents 6fe75577 b526b330
@@ -34,6 +34,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 from sklearn.metrics import confusion_matrix, classification_report

 # locals
@@ -1071,6 +1072,7 @@ class NetworkTrainer(BaseConfig):
     src_valid_dl: DataLoader
     src_test_dl: DataLoader = DataLoader(None)
     loss_function: nn.modules.loss._Loss = nn.CrossEntropyLoss()
+    lr_scheduler: (type(None), _LRScheduler) = None
     epochs: int = 1
     nthreads: int = torch.get_num_threads()
     early_stop: bool = False
@@ -1290,6 +1292,10 @@ class NetworkTrainer(BaseConfig):
             if self.save:
                 self.save_state()

+            # decay learning rate, if scheduler is specified
+            if self.lr_scheduler is not None:
+                self.lr_scheduler.step()
+
         return self.training_state

     def predict(self, dataloader, return_pred=False):
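For context, here is a minimal sketch of the per-epoch scheduler pattern this merge introduces. The model, optimizer, and training loop below are placeholders rather than this repository's actual classes; only the once-per-epoch `scheduler.step()` call mirrors the diff. Note that `lr_scheduler` defaults to `None`, so scheduling stays opt-in, and the tuple annotation `(type(None), _LRScheduler)` presumably lists the types `BaseConfig` accepts for the attribute.

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ExponentialLR

# placeholder model and optimizer, standing in for the trainer's own
model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# any torch.optim.lr_scheduler._LRScheduler subclass would do here;
# ExponentialLR multiplies the learning rate by gamma on every step()
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(5):
    # ... one training epoch and optional checkpointing would run here ...

    # decay the learning rate once per epoch, as the diff does when
    # self.lr_scheduler is not None
    scheduler.step()
    print(f'epoch {epoch}: lr = {optimizer.param_groups[0]["lr"]:.2e}')
```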