Commit 87e29166 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Exclude missing values for regression tasks when computing loss.

parent 74462d9a
......@@ -1178,7 +1178,7 @@ class NetworkTrainer(BaseConfig):
loss = self.loss_function(outputs, labels.long())
else:
# exclude potentially missing values
mask = ~np.isnan(labels)
mask = ~torch.isnan(labels)
loss = self.loss_function(outputs[mask], labels[mask])
# compute the gradients of the loss function w.r.t.
......@@ -1332,7 +1332,7 @@ class NetworkTrainer(BaseConfig):
val_loss = self.loss_function(outputs, labels.long())
else:
# exclude potentially missing values
mask = ~np.isnan(labels)
mask = ~torch.isnan(labels)
val_loss = self.loss_function(outputs[mask], labels[mask])
loss.append(val_loss.item())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment