From 7e20108b58bf828e190805ed32555a5794cf4297 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Fri, 8 Oct 2021 10:24:44 +0200
Subject: [PATCH] Move transformations to corresponding loss functions.

---
 climax/core/loss.py    | 4 ++--
 climax/core/predict.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/climax/core/loss.py b/climax/core/loss.py
index 5eb2fd5..8895ab1 100644
--- a/climax/core/loss.py
+++ b/climax/core/loss.py
@@ -108,7 +108,7 @@ class BernoulliGammaLoss(BernoulliLoss):
     @staticmethod
     def predict(p, shape, scale):
         # pr = p * shape * scale
-        return p * shape * scale
+        return p * np.exp(shape) * np.exp(scale)
 
 
 class BernoulliGenParetoLoss(BernoulliLoss):
@@ -219,4 +219,4 @@ class BernoulliWeibullLoss(BernoulliLoss):
     @staticmethod
     def predict(p, shape, scale):
         # pr = p * scale * gamma(1 + 1 / shape)
-        return p * scale *  gamma(1 + 1 / shape)
+        return p * np.exp(scale) *  gamma(1 + 1 / np.exp(shape))
diff --git a/climax/core/predict.py b/climax/core/predict.py
index af1e38b..bbefa41 100644
--- a/climax/core/predict.py
+++ b/climax/core/predict.py
@@ -75,8 +75,8 @@ def predict_ERA5(net, ERA5_ds, predictand, loss, batch_size=16, **kwargs):
                                                  dtype=torch.float32)).numpy()
 
             # shape and scale parameters
-            shape = np.exp(target[:, 1, ...].squeeze())
-            scale = np.exp(target[:, 2, ...].squeeze())
+            shape = target[:, 1, ...].squeeze()
+            scale = target[:, 2, ...].squeeze()
 
             # precipitation amount
             pr = loss.predict(prob, shape, scale)
-- 
GitLab