Skip to content
Snippets Groups Projects
Commit 77e1b3b0 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Fixed wrong initialization.

parent 1676fe09
No related branches found
No related tags found
No related merge requests found
...@@ -48,7 +48,7 @@ class L1Loss(NaNLoss): ...@@ -48,7 +48,7 @@ class L1Loss(NaNLoss):
return F.l1_loss(y_pred[mask], y_true[mask], reduction=self.reduction) return F.l1_loss(y_pred[mask], y_true[mask], reduction=self.reduction)
class BernoulliLoss(NaNLoss): class BernoulliGammaLoss(NaNLoss):
def __init__(self, size_average=None, reduce=None, reduction='mean', def __init__(self, size_average=None, reduce=None, reduction='mean',
min_amount=0): min_amount=0):
...@@ -57,13 +57,6 @@ class BernoulliLoss(NaNLoss): ...@@ -57,13 +57,6 @@ class BernoulliLoss(NaNLoss):
# minimum amount of precipitation to be classified as precipitation # minimum amount of precipitation to be classified as precipitation
self.min_amount = min_amount self.min_amount = min_amount
class BernoulliGammaLoss(BernoulliLoss):
def __init__(self, size_average=None, reduce=None, reduction='mean',
min_amount=0):
super().__init__(size_average, reduce, reduction)
def forward(self, y_pred, y_true): def forward(self, y_pred, y_true):
# convert to float32 # convert to float32
...@@ -111,12 +104,15 @@ class BernoulliGammaLoss(BernoulliLoss): ...@@ -111,12 +104,15 @@ class BernoulliGammaLoss(BernoulliLoss):
return p * shape * scale return p * shape * scale
class BernoulliGenParetoLoss(BernoulliLoss): class BernoulliGenParetoLoss(NaNLoss):
def __init__(self, size_average=None, reduce=None, reduction='mean', def __init__(self, size_average=None, reduce=None, reduction='mean',
min_amount=0): min_amount=0):
super().__init__(size_average, reduce, reduction) super().__init__(size_average, reduce, reduction)
# minimum amount of precipitation to be classified as precipitation
self.min_amount = min_amount
def forward(self, y_pred, y_true): def forward(self, y_pred, y_true):
# convert to float32 # convert to float32
...@@ -157,12 +153,15 @@ class BernoulliGenParetoLoss(BernoulliLoss): ...@@ -157,12 +153,15 @@ class BernoulliGenParetoLoss(BernoulliLoss):
return self.reduce(loss) return self.reduce(loss)
class BernoulliWeibullLoss(BernoulliLoss): class BernoulliWeibullLoss(NaNLoss):
def __init__(self, size_average=None, reduce=None, reduction='mean', def __init__(self, size_average=None, reduce=None, reduction='mean',
min_amount=0): min_amount=0):
super().__init__(size_average, reduce, reduction) super().__init__(size_average, reduce, reduction)
# minimum amount of precipitation to be classified as precipitation
self.min_amount = min_amount
def forward(self, y_pred, y_true): def forward(self, y_pred, y_true):
# convert to float32 # convert to float32
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment