diff --git a/climax/core/loss.py b/climax/core/loss.py index 5eb2fd5bd7c5900c35d67b2aa01afe5e0f06a53b..accb9abf9d283895659ce7373dd0ab08abb37381 100644 --- a/climax/core/loss.py +++ b/climax/core/loss.py @@ -48,7 +48,7 @@ class L1Loss(NaNLoss): return F.l1_loss(y_pred[mask], y_true[mask], reduction=self.reduction) -class BernoulliLoss(NaNLoss): +class BernoulliGammaLoss(NaNLoss): def __init__(self, size_average=None, reduce=None, reduction='mean', min_amount=0): @@ -57,13 +57,6 @@ class BernoulliLoss(NaNLoss): # minimum amount of precipitation to be classified as precipitation self.min_amount = min_amount - -class BernoulliGammaLoss(BernoulliLoss): - - def __init__(self, size_average=None, reduce=None, reduction='mean', - min_amount=0): - super().__init__(size_average, reduce, reduction) - def forward(self, y_pred, y_true): # convert to float32 @@ -111,12 +104,15 @@ class BernoulliGammaLoss(BernoulliLoss): return p * shape * scale -class BernoulliGenParetoLoss(BernoulliLoss): +class BernoulliGenParetoLoss(NaNLoss): def __init__(self, size_average=None, reduce=None, reduction='mean', min_amount=0): super().__init__(size_average, reduce, reduction) + # minimum amount of precipitation to be classified as precipitation + self.min_amount = min_amount + def forward(self, y_pred, y_true): # convert to float32 @@ -157,12 +153,15 @@ class BernoulliGenParetoLoss(BernoulliLoss): return self.reduce(loss) -class BernoulliWeibullLoss(BernoulliLoss): +class BernoulliWeibullLoss(NaNLoss): def __init__(self, size_average=None, reduce=None, reduction='mean', min_amount=0): super().__init__(size_average, reduce, reduction) + # minimum amount of precipitation to be classified as precipitation + self.min_amount = min_amount + def forward(self, y_pred, y_true): # convert to float32