From 77e1b3b065eb307f58511aba0f4b31011a1fbf2c Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Thu, 7 Oct 2021 18:47:59 +0200
Subject: [PATCH] Fixed wrong initialization.

---
 climax/core/loss.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/climax/core/loss.py b/climax/core/loss.py
index 5eb2fd5..accb9ab 100644
--- a/climax/core/loss.py
+++ b/climax/core/loss.py
@@ -48,7 +48,7 @@ class L1Loss(NaNLoss):
         return F.l1_loss(y_pred[mask], y_true[mask], reduction=self.reduction)
 
 
-class BernoulliLoss(NaNLoss):
+class BernoulliGammaLoss(NaNLoss):
 
     def __init__(self, size_average=None, reduce=None, reduction='mean',
                  min_amount=0):
@@ -57,13 +57,6 @@ class BernoulliLoss(NaNLoss):
         # minimum amount of precipitation to be classified as precipitation
         self.min_amount = min_amount
 
-
-class BernoulliGammaLoss(BernoulliLoss):
-
-    def __init__(self, size_average=None, reduce=None, reduction='mean',
-                 min_amount=0):
-        super().__init__(size_average, reduce, reduction)
-
     def forward(self, y_pred, y_true):
 
         # convert to float32
@@ -111,12 +104,15 @@ class BernoulliGammaLoss(BernoulliLoss):
         return p * shape * scale
 
 
-class BernoulliGenParetoLoss(BernoulliLoss):
+class BernoulliGenParetoLoss(NaNLoss):
 
     def __init__(self, size_average=None, reduce=None, reduction='mean',
                  min_amount=0):
         super().__init__(size_average, reduce, reduction)
 
+        # minimum amount of precipitation to be classified as precipitation
+        self.min_amount = min_amount
+
     def forward(self, y_pred, y_true):
 
         # convert to float32
@@ -157,12 +153,15 @@ class BernoulliGenParetoLoss(BernoulliLoss):
         return self.reduce(loss)
 
 
-class BernoulliWeibullLoss(BernoulliLoss):
+class BernoulliWeibullLoss(NaNLoss):
 
     def __init__(self, size_average=None, reduce=None, reduction='mean',
                  min_amount=0):
         super().__init__(size_average, reduce, reduction)
 
+        # minimum amount of precipitation to be classified as precipitation
+        self.min_amount = min_amount
+
     def forward(self, y_pred, y_true):
 
         # convert to float32
-- 
GitLab