diff --git a/climax/core/loss.py b/climax/core/loss.py index 7ea7f33685f7ce74410abbc49ff1e65934b26c04..ff9f29c3b15bc3a87fe78c58a0f33eed85aa089d 100644 --- a/climax/core/loss.py +++ b/climax/core/loss.py @@ -132,16 +132,16 @@ class BernoulliGenParetoLoss(NaNLoss): # clip probabilities to (0, 1) p_pred = torch.sigmoid(y_pred[:, 0, ...].squeeze()[mask]) - # clip scale to (0, +infinity) + # clip shape and scale to (0, +infinity) + gshape = torch.exp(y_pred[:, 1, ...].squeeze()[mask]) gscale = torch.exp(y_pred[:, 2, ...].squeeze()[mask]) - gshape = y_pred[:, 1, ...].squeeze()[mask] # negative log-likelihood function of Bernoulli-GenPareto distribution # Bernoulli contribution loss = - (1 - p_true) * torch.log(1 - p_pred + self.epsilon) - # Gamma contribution + # GenPareto contribution loss -= p_true * (torch.log(p_pred + self.epsilon) + torch.log( 1 - (1 + (gshape * y_true / gscale)) ** (- 1 / gshape) + self.epsilon))