From 3c9b388b04e0c3f6317c60d4fed81e671862fc95 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 7 Oct 2021 15:40:58 +0200
Subject: [PATCH] Implemented Bernoulli-Weibull loss.

---
 climax/core/dataset.py |  6 ++++--
 climax/core/loss.py    | 25 ++++++++++++-------------
 climax/core/predict.py |  2 +-
 3 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/climax/core/dataset.py b/climax/core/dataset.py
index 17f8366..302960c 100644
--- a/climax/core/dataset.py
+++ b/climax/core/dataset.py
@@ -19,7 +19,8 @@ import dask.array as da
 from climax.core.constants import (ERA5_VARIABLES, ERA5_PRESSURE_LEVELS,
                                    ERA5_P_VARIABLE_NAME, ERA5_S_VARIABLE_NAME,
                                    PROJECTION)
-from climax.core.loss import BernoulliGammaLoss, BernoulliGenParetoLoss
+from climax.core.loss import (BernoulliGammaLoss, BernoulliGenParetoLoss,
+                              BernoulliWeibullLoss)
 from pysegcnn.core.utils import search_files, img2np
 from pysegcnn.core.trainer import LogConfig
 
@@ -126,7 +127,8 @@ class EoDataset(torch.utils.data.Dataset):
         # check which loss function is used
         if predictand == 'pr':
             if (isinstance(loss, BernoulliGammaLoss) or
-                isinstance(loss, BernoulliGenParetoLoss)):
+                isinstance(loss, BernoulliGenParetoLoss) or
+                isinstance(loss, BernoulliWeibullLoss)):
                 # adjust state file for precipitation
                 state_file = '_'.join([state_file, '{}mm'.format(
                     str(loss.min_amount).replace('.', '')),
diff --git a/climax/core/loss.py b/climax/core/loss.py
index 7fdb74b..abb9ca3 100644
--- a/climax/core/loss.py
+++ b/climax/core/loss.py
@@ -178,15 +178,14 @@ class BernoulliWeibullLoss(NaNLoss):
         # clip probabilities to (0, 1)
         p_pred = torch.sigmoid(y_pred[:, 0, ...].squeeze()[mask])
 
-        # clip shape and scale to (0, 50)
-        # NOTE: in general shape, scale in (0, +infinity), but clipping to
-        #       finite numbers is required for numerical stability
-        # gshape = torch.clamp(
-        #     torch.exp(y_pred[:, 1, ...].squeeze()[mask]), max=1)
-        # gscale = torch.clamp(
-        #     torch.exp(y_pred[:, 2, ...].squeeze()[mask]), max=20)
-        gshape = torch.sigmoid(y_pred[:, 1, ...].squeeze()[mask])
-        gscale = torch.exp(y_pred[:, 2, ...].squeeze()[mask])
+        # clip scale to (0, +infinity)
+        scale = torch.exp(y_pred[:, 2, ...].squeeze()[mask])
+
+        # clip shape to (0, 1)
+        # NOTE: in general shape in (0, +infinity), but for precipitation the
+        #       shape parameter of the Weibull distribution < 1
+        shape = torch.clamp(
+            torch.exp(y_pred[:, 1, ...].squeeze()[mask]), max=1)
 
         # negative log-likelihood function of Bernoulli-Weibull distribution
 
@@ -198,10 +197,10 @@ class BernoulliWeibullLoss(NaNLoss):
 
         # Weibull contribution
         loss -= p_true * (torch.log(p_pred + self.epsilon) +
-                          torch.log(gshape + self.epsilon) -
-                          gshape * torch.log(gscale + self.epsilon) +
-                          (gshape - 1) * torch.log(y_true + self.epsilon) -
-                          torch.pow(y_true / (gscale + self.epsilon), gshape)
+                          torch.log(shape + self.epsilon) -
+                          shape * torch.log(scale + self.epsilon) +
+                          (shape - 1) * torch.log(y_true + self.epsilon) -
+                          torch.pow(y_true / (scale + self.epsilon), shape)
                           )
 
         # clip loss to finite values to ensure numerical stability
diff --git a/climax/core/predict.py b/climax/core/predict.py
index b09bf22..4ece682 100644
--- a/climax/core/predict.py
+++ b/climax/core/predict.py
@@ -92,7 +92,7 @@ def predict_ERA5(net, ERA5_ds, predictand, loss, batch_size=16, **kwargs):
                 # distribution
                 # pr = p * scale * tau(1 + 1 / shape)
                 pr = (prob * np.exp(target[:, 2, ...].squeeze()) *
-                      gamma(1 + 1 / torch.sigmoid(target[:, 1, ...])))
+                      gamma(1 + 1 / np.exp(target[:, 1, ...])))
 
             ds = {'prob': prob, 'precipitation': pr}
 
-- 
GitLab