Commit 1ee52260 authored by Frisinghelli Daniel

Implemented gradient clipping.

parent ebc684ba
@@ -1016,6 +1016,9 @@ class NetworkTrainer(BaseConfig):
A model checkpoint for ``model``. If specified, ``checkpoint_state``
should be a dictionary with keys describing the training metric.
The default is `{}`.
clip_gradients : `bool`
Whether to apply gradient clipping. Useful for problems with exploding
gradients. The default is `False`.
save : `bool`
Whether to save the model state to ``state_file``. The default is
`True`.
@@ -1075,6 +1078,7 @@ class NetworkTrainer(BaseConfig):
delta: float = 0
patience: int = 10
checkpoint_state: dict = dataclasses.field(default_factory=dict)
clip_gradients: bool = False
save: bool = True
save_loaders: bool = True
multi_gpu: bool = False
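
For illustration, a minimal sketch of how the new field would be switched on when configuring the trainer. This is an assumption: the dataclass presumably accepts the fields above as keyword arguments, and any required arguments not visible in this diff are omitted here.

# Hypothetical usage; only fields visible in this diff are shown.
trainer = NetworkTrainer(
    patience=10,
    clip_gradients=True,  # enable element-wise gradient clipping
    save=True,
)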
@@ -1188,6 +1192,11 @@ class NetworkTrainer(BaseConfig):
# the network weights
loss.backward()
# clip gradients
if self.clip_gradients:
nn.utils.clip_grad_value_(self.model.parameters(),
clip_value=1.0)
# update the weights
self.optimizer.step()
@@ -1680,6 +1689,11 @@ class DomainAdaptationTrainer(NetworkTrainer):
# the network weights
tot_loss.backward()
# clip gradients
if self.clip_gradients:
nn.utils.clip_grad_value_(self.model.parameters(),
clip_value=1.0)
# update the weights
self.optimizer.step()
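
For context, a self-contained sketch of what the added lines do, assuming the module imports ``torch.nn as nn``; the model, data, and loss below are toy placeholders, not part of this repository. ``clip_grad_value_`` clamps each gradient component to ``[-clip_value, clip_value]`` in place, after ``backward()`` and before ``optimizer.step()``:

import torch
import torch.nn as nn

# Toy setup; names are placeholders for illustration only.
model = nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
inputs, targets = torch.randn(8, 16), torch.randn(8, 1)

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()

# Element-wise clipping, as in this commit: each gradient entry is
# clamped to the interval [-1.0, 1.0] before the weights are updated.
nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
optimizer.step()

A norm-based alternative, ``nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)``, rescales the whole gradient vector instead of clamping individual entries; this commit uses the element-wise variant with a fixed ``clip_value`` of 1.0.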