Commit 74462d9a authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Exclude missing values for regression tasks when computing loss.

parent b94141b0
......@@ -1173,9 +1173,13 @@ class NetworkTrainer(BaseConfig):
# perform forward pass
outputs = self.model(inputs)
# compute loss
loss = self.loss_function(
outputs, labels.long() if self.classification else labels)
# compute loss: classification vs. regression
if self.classification:
loss = self.loss_function(outputs, labels.long())
else:
# exclude potentially missing values
mask = ~np.isnan(labels)
loss = self.loss_function(outputs[mask], labels[mask])
# compute the gradients of the loss function w.r.t.
# the network weights
......@@ -1323,9 +1327,13 @@ class NetworkTrainer(BaseConfig):
with torch.no_grad():
outputs = self.model(inputs)
# compute loss
val_loss = self.loss_function(
outputs, labels.long() if self.classification else labels)
# compute loss: classification vs. regression
if self.classification:
val_loss = self.loss_function(outputs, labels.long())
else:
# exclude potentially missing values
mask = ~np.isnan(labels)
val_loss = self.loss_function(outputs[mask], labels[mask])
loss.append(val_loss.item())
# calculate predicted class labels
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment