Skip to content
Snippets Groups Projects
Commit d8fedcf3 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Stable implementation of Bernoulli-Weibull loss.

parent 135f805f
No related branches found
No related tags found
No related merge requests found
......@@ -58,8 +58,8 @@ class BernoulliGammaLoss(NaNLoss):
def forward(self, y_pred, y_true):
# convert to float64 (double precision improves numerical stability of the log-likelihood)
y_pred = y_pred.type(torch.float32)
y_true = y_true.type(torch.float32).squeeze()
y_pred = y_pred.type(torch.float64)
y_true = y_true.type(torch.float64).squeeze()
# missing values
mask = ~torch.isnan(y_true)
......@@ -109,8 +109,8 @@ class BernoulliGenParetoLoss(NaNLoss):
def forward(self, y_pred, y_true):
# convert to float64 (double precision improves numerical stability of the log-likelihood)
y_pred = y_pred.type(torch.float32)
y_true = y_true.type(torch.float32).squeeze()
y_pred = y_pred.type(torch.float64)
y_true = y_true.type(torch.float64).squeeze()
# missing values
mask = ~torch.isnan(y_true)
......@@ -158,8 +158,8 @@ class BernoulliWeibullLoss(NaNLoss):
def forward(self, y_pred, y_true):
# convert to float64 (double precision improves numerical stability of the log-likelihood)
y_pred = y_pred.type(torch.float32)
y_true = y_true.type(torch.float32).squeeze()
y_pred = y_pred.type(torch.float64)
y_true = y_true.type(torch.float64).squeeze()
# missing values
mask = ~torch.isnan(y_true)
......@@ -181,24 +181,30 @@ class BernoulliWeibullLoss(NaNLoss):
# constrain shape to (0, 1) via sigmoid; scale is obtained via exp
# NOTE: in general shape, scale in (0, +infinity), but bounding the shape
# parameter is required for numerical stability
gshape = torch.clamp(
torch.exp(y_pred[:, 1, ...].squeeze()[mask]), max=50)
gscale = torch.clamp(
torch.exp(y_pred[:, 2, ...].squeeze()[mask]), max=50)
# gshape = torch.clamp(
# torch.exp(y_pred[:, 1, ...].squeeze()[mask]), max=1)
# gscale = torch.clamp(
# torch.exp(y_pred[:, 2, ...].squeeze()[mask]), max=20)
gshape = torch.sigmoid(y_pred[:, 1, ...].squeeze()[mask])
gscale = torch.exp(y_pred[:, 2, ...].squeeze()[mask])
# negative log-likelihood function of Bernoulli-GenPareto distribution
# negative log-likelihood function of Bernoulli-Weibull distribution
# Bernoulli contribution
loss = - (1 - p_true) * torch.log(1 - p_pred + self.epsilon)
# replace values <= min_amount: ensures numerical stability in backprop
y_true[y_true <= self.min_amount] = 1
# Weibull contribution
loss -= p_true * (torch.log(p_pred + self.epsilon) +
torch.log(gshape + self.epsilon) -
gshape * torch.log(gscale + self.epsilon) +
(gshape - 1) * torch.log(y_true + self.epsilon) -
(y_true / (gscale + self.epsilon)) ** gshape)
torch.pow(y_true / (gscale + self.epsilon), gshape)
)
# clip loss to finite values to ensure numerical stability
loss = torch.clamp(loss, max=100)
# loss = torch.clamp(loss, max=100)
return self.reduce(loss)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment