From d8fedcf3e4db2fb5da5601655223f1ad74d7c213 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 7 Oct 2021 15:11:56 +0200
Subject: [PATCH] Stable implementation of Bernoulli-Weibull loss.

---
 climax/core/loss.py | 32 +++++++++++++++++++-------------
 1 file changed, 19 insertions(+), 13 deletions(-)

diff --git a/climax/core/loss.py b/climax/core/loss.py
index 05b49ed..7fdb74b 100644
--- a/climax/core/loss.py
+++ b/climax/core/loss.py
@@ -58,8 +58,8 @@ class BernoulliGammaLoss(NaNLoss):
     def forward(self, y_pred, y_true):
 
         # convert to float32
-        y_pred = y_pred.type(torch.float32)
-        y_true = y_true.type(torch.float32).squeeze()
+        y_pred = y_pred.type(torch.float64)
+        y_true = y_true.type(torch.float64).squeeze()
 
         # missing values
         mask = ~torch.isnan(y_true)
@@ -109,8 +109,8 @@ class BernoulliGenParetoLoss(NaNLoss):
     def forward(self, y_pred, y_true):
 
         # convert to float32
-        y_pred = y_pred.type(torch.float32)
-        y_true = y_true.type(torch.float32).squeeze()
+        y_pred = y_pred.type(torch.float64)
+        y_true = y_true.type(torch.float64).squeeze()
 
         # missing values
         mask = ~torch.isnan(y_true)
@@ -158,8 +158,8 @@ class BernoulliWeibullLoss(NaNLoss):
     def forward(self, y_pred, y_true):
 
         # convert to float32
-        y_pred = y_pred.type(torch.float32)
-        y_true = y_true.type(torch.float32).squeeze()
+        y_pred = y_pred.type(torch.float64)
+        y_true = y_true.type(torch.float64).squeeze()
 
         # missing values
         mask = ~torch.isnan(y_true)
@@ -181,24 +181,30 @@ class BernoulliWeibullLoss(NaNLoss):
         # clip shape and scale to (0, 50)
         # NOTE: in general shape, scale in (0, +infinity), but clipping to
         #       finite numbers is required for numerical stability
-        gshape = torch.clamp(
-            torch.exp(y_pred[:, 1, ...].squeeze()[mask]), max=50)
-        gscale = torch.clamp(
-            torch.exp(y_pred[:, 2, ...].squeeze()[mask]), max=50)
+        # gshape = torch.clamp(
+        #     torch.exp(y_pred[:, 1, ...].squeeze()[mask]), max=1)
+        # gscale = torch.clamp(
+        #     torch.exp(y_pred[:, 2, ...].squeeze()[mask]), max=20)
+        gshape = torch.sigmoid(y_pred[:, 1, ...].squeeze()[mask])
+        gscale = torch.exp(y_pred[:, 2, ...].squeeze()[mask])
 
-        # negative log-likelihood function of Bernoulli-GenPareto distribution
+        # negative log-likelihood function of Bernoulli-Weibull distribution
 
         # Bernoulli contribution
         loss = - (1 - p_true) * torch.log(1 - p_pred + self.epsilon)
 
+        # replace values <= min_amount with 1 (log(1) = 0): ensures numerical stability in backprop
+        y_true[y_true <= self.min_amount] = 1
+
         # Weibull contribution
         loss -= p_true * (torch.log(p_pred + self.epsilon) +
                           torch.log(gshape + self.epsilon) -
                           gshape * torch.log(gscale + self.epsilon) +
                           (gshape - 1) * torch.log(y_true + self.epsilon) -
-                          (y_true / (gscale + self.epsilon)) ** gshape)
+                          torch.pow(y_true / (gscale + self.epsilon), gshape)
+                          )
 
         # clip loss to finite values to ensure numerical stability
-        loss = torch.clamp(loss, max=100)
+        # loss = torch.clamp(loss, max=100)
 
         return self.reduce(loss)
-- 
GitLab