Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
earth_observation_public
PySegCNN
Commits
b526b330
Commit
b526b330
authored
Oct 14, 2021
by
Frisinghelli Daniel
Browse files
Implemented LR-scheduling.
parent
10ee61cf
Changes
1
Hide whitespace changes
Inline
Side-by-side
pysegcnn/core/trainer.py
View file @
b526b330
...
...
@@ -34,6 +34,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
from
torch.utils.data
import
DataLoader
from
torch.optim
import
Optimizer
from
torch.optim.lr_scheduler
import
_LRScheduler
from
sklearn.metrics
import
confusion_matrix
,
classification_report
# locals
...
...
@@ -1071,6 +1072,7 @@ class NetworkTrainer(BaseConfig):
src_valid_dl
:
DataLoader
src_test_dl
:
DataLoader
=
DataLoader
(
None
)
loss_function
:
nn
.
modules
.
loss
.
_Loss
=
nn
.
CrossEntropyLoss
()
lr_scheduler
:
(
type
(
None
),
_LRScheduler
)
=
None
epochs
:
int
=
1
nthreads
:
int
=
torch
.
get_num_threads
()
early_stop
:
bool
=
False
...
...
@@ -1290,6 +1292,10 @@ class NetworkTrainer(BaseConfig):
if
self
.
save
:
self
.
save_state
()
# decay learning rate, if scheduler is specified
if
self
.
lr_scheduler
is
not
None
:
self
.
lr_scheduler
.
step
()
return
self
.
training_state
def
predict
(
self
,
dataloader
,
return_pred
=
False
):
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment