Commit e3580c71 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Remove NaN masking.

parent 804234d7
......@@ -1182,9 +1182,7 @@ class NetworkTrainer(BaseConfig):
if self.classification:
loss = self.loss_function(outputs, labels.long())
# exclude potentially missing values
mask = ~torch.isnan(labels)
loss = self.loss_function(outputs[mask], labels[mask])
loss = self.loss_function(outputs, labels)
# compute the gradients of the loss function w.r.t.
# the network weights
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment