From d8fedcf3e4db2fb5da5601655223f1ad74d7c213 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Thu, 7 Oct 2021 15:11:56 +0200 Subject: [PATCH] Stable implementation of Bernoulli-Weibull loss. --- climax/core/loss.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/climax/core/loss.py b/climax/core/loss.py index 05b49ed..7fdb74b 100644 --- a/climax/core/loss.py +++ b/climax/core/loss.py @@ -58,8 +58,8 @@ class BernoulliGammaLoss(NaNLoss): def forward(self, y_pred, y_true): - # convert to float32 - y_pred = y_pred.type(torch.float32) - y_true = y_true.type(torch.float32).squeeze() + # convert to float64 + y_pred = y_pred.type(torch.float64) + y_true = y_true.type(torch.float64).squeeze() # missing values mask = ~torch.isnan(y_true) @@ -109,8 +109,8 @@ class BernoulliGenParetoLoss(NaNLoss): def forward(self, y_pred, y_true): - # convert to float32 - y_pred = y_pred.type(torch.float32) - y_true = y_true.type(torch.float32).squeeze() + # convert to float64 + y_pred = y_pred.type(torch.float64) + y_true = y_true.type(torch.float64).squeeze() # missing values mask = ~torch.isnan(y_true) @@ -158,8 +158,8 @@ class BernoulliWeibullLoss(NaNLoss): def forward(self, y_pred, y_true): - # convert to float32 - y_pred = y_pred.type(torch.float32) - y_true = y_true.type(torch.float32).squeeze() + # convert to float64 + y_pred = y_pred.type(torch.float64) + y_true = y_true.type(torch.float64).squeeze() # missing values mask = ~torch.isnan(y_true) @@ -181,24 +181,30 @@ class BernoulliWeibullLoss(NaNLoss): # clip shape and scale to (0, 50) # NOTE: in general shape, scale in (0, +infinity), but clipping to # finite numbers is required for numerical stability - gshape = torch.clamp( - torch.exp(y_pred[:, 1, ...].squeeze()[mask]), max=50) - gscale = torch.clamp( - torch.exp(y_pred[:, 2, ...].squeeze()[mask]), max=50) + # gshape = torch.clamp( + # torch.exp(y_pred[:, 1, ...].squeeze()[mask]), max=1) + # gscale = torch.clamp( + # torch.exp(y_pred[:, 2, 
...].squeeze()[mask]), max=20) + gshape = torch.sigmoid(y_pred[:, 1, ...].squeeze()[mask]) + gscale = torch.exp(y_pred[:, 2, ...].squeeze()[mask]) - # negative log-likelihood function of Bernoulli-GenPareto distribution + # negative log-likelihood function of Bernoulli-Weibull distribution # Bernoulli contribution loss = - (1 - p_true) * torch.log(1 - p_pred + self.epsilon) + # replace values <= min_amount: ensures numerical stability in backprop + y_true[y_true <= self.min_amount] = 1 + # Weibull contribution loss -= p_true * (torch.log(p_pred + self.epsilon) + torch.log(gshape + self.epsilon) - gshape * torch.log(gscale + self.epsilon) + (gshape - 1) * torch.log(y_true + self.epsilon) - - (y_true / (gscale + self.epsilon)) ** gshape) + torch.pow(y_true / (gscale + self.epsilon), gshape) + ) # clip loss to finite values to ensure numerical stability - loss = torch.clamp(loss, max=100) + # loss = torch.clamp(loss, max=100) return self.reduce(loss) -- GitLab