Skip to content
Snippets Groups Projects
Commit d459b2cf authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Implemented NLL for Bernoulli-GenPareto distribution.

parent 4f230889
No related branches found
No related tags found
No related merge requests found
......@@ -95,3 +95,55 @@ class BernoulliGammaLoss(NaNLoss):
torch.lgamma(gshape + self.epsilon))
return self.reduce(loss)
class BernoulliGenParetoLoss(NaNLoss):
def __init__(self, size_average=None, reduce=None, reduction='mean',
min_amount=0):
super().__init__(size_average, reduce, reduction)
# minimum amount of precipitation to be classified as precipitation
self.min_amount = min_amount
# small number ensuring numerical stability
self.epsilon = 1e-7
def forward(self, y_pred, y_true):
# convert to float32
y_pred = y_pred.type(torch.float32)
y_true = y_true.type(torch.float32).squeeze()
# missing values
mask = ~torch.isnan(y_true)
y_true = y_true[mask]
# mask values less than 0
y_true[y_true < 0] = 0
# calculate true probability of precipitation:
# 1 if y_true > min_amount else 0
p_true = (y_true > self.min_amount).type(torch.float32)
# estimates of precipitation probability and gamma shape and scale
# parameters: ensure numerical stability
# clip probabilities to (0, 1)
p_pred = torch.sigmoid(y_pred[:, 0, ...].squeeze()[mask])
# clip scale to (0, +infinity)
gscale = torch.exp(y_pred[:, 2, ...].squeeze()[mask])
gshape = y_pred[:, 1, ...].squeeze()[mask]
# negative log-likelihood function of Bernoulli-GenPareto distribution
# Bernoulli contribution
loss = - (1 - p_true) * torch.log(1 - p_pred + self.epsilon)
# Gamma contribution
loss -= p_true * (torch.log(p_pred + self.epsilon) + torch.log(
1 - (1 + (gshape * y_true / gscale)) ** (- 1 / gshape) +
self.epsilon))
return self.reduce(loss)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment