Commit ebc684ba authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Remove NaN masking.

parent e3580c71
......@@ -1334,9 +1334,7 @@ class NetworkTrainer(BaseConfig):
if self.classification:
val_loss = self.loss_function(outputs, labels.long())
# exclude potentially missing values
mask = ~torch.isnan(labels)
val_loss = self.loss_function(outputs[mask], labels[mask])
val_loss = self.loss_function(outputs, labels)
# calculate predicted class labels
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment