Commit b94141b0 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Check label types for loss function.

parent 92476ba3
......@@ -1174,7 +1174,8 @@ class NetworkTrainer(BaseConfig):
outputs = self.model(inputs)
# compute loss
loss = self.loss_function(outputs, labels.long())
loss = self.loss_function(
outputs, labels.long() if self.classification else labels)
# compute the gradients of the loss function w.r.t.
# the network weights
......@@ -1323,7 +1324,8 @@ class NetworkTrainer(BaseConfig):
outputs = self.model(inputs)
# compute loss
val_loss = self.loss_function(outputs, labels.long())
val_loss = self.loss_function(
outputs, labels.long() if self.classification else labels)
loss.append(val_loss.item())
# calculate predicted class labels
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment