Commit 71eb0978 authored by Frisinghelli Daniel

Use sigmoid activation for p; optimize logarithms of scale, shape.

parent bdceca43
@@ -11,8 +11,7 @@ from torch.nn.modules.loss import _Loss
 class NaNLoss(_Loss):
 
-    def __init__(self, size_average=None, reduce=None, reduction='mean',
-                 min_amount=0, epsilon=1e-7):
+    def __init__(self, size_average=None, reduce=None, reduction='mean'):
         super().__init__(size_average, reduce, reduction)
 
     def nan_reduce(self, tensor):
@@ -29,8 +28,7 @@ class NaNLoss(_Loss):
 class MSELoss(NaNLoss):
 
-    def __init__(self, size_average=None, reduce=None, reduction='mean',
-                 min_amount=0, epsilon=1e-7):
+    def __init__(self, size_average=None, reduce=None, reduction='mean'):
         super().__init__(size_average, reduce, reduction)
 
     def forward(self, y_pred, y_true):
@@ -39,8 +37,7 @@ class MSELoss(NaNLoss):
 class L1Loss(NaNLoss):
 
-    def __init__(self, size_average=None, reduce=None, reduction='mean',
-                 min_amount=0, epsilon=1e-7):
+    def __init__(self, size_average=None, reduce=None, reduction='mean'):
         super().__init__(size_average, reduce, reduction)
 
     def forward(self, y_pred, y_true):
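
All three constructors drop the min_amount and epsilon arguments. The nan_reduce helper the losses share is not part of this diff; purely for orientation, a hypothetical sketch of what such a masked reduction typically looks like (the body is an assumption; only the name nan_reduce and the reduction argument appear in the diff above):

    import torch

    def nan_reduce(tensor, reduction='mean'):
        # hypothetical sketch: discard NaN entries before reducing, so that
        # missing observations do not propagate into the loss value
        valid = tensor[~torch.isnan(tensor)]
        if reduction == 'mean':
            return valid.mean()
        if reduction == 'sum':
            return valid.sum()
        return valid  # reduction='none': return the masked values unreduced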
@@ -73,11 +70,11 @@ class BernoulliGammaLoss(NaNLoss):
         # parameters: ensure numerical stability
 
         # clip probabilities to (0, 1)
-        p_pred = y_pred[:, 0, ...].clamp(0, 1)
+        p_pred = F.sigmoid(y_pred[:, 0, ...])
 
         # clip shape and scale to (0, +infinity)
-        gshape = y_pred[:, 1, ...].clamp(min=self.epsilon)
-        gscale = y_pred[:, 2, ...].clamp(min=self.epsilon)
+        gshape = torch.exp(y_pred[:, 1, ...].clamp(min=self.epsilon))
+        gscale = torch.exp(y_pred[:, 2, ...].clamp(min=self.epsilon))
 
         # negative log-likelihood function of Bernoulli-Gamma distribution
@@ -89,4 +86,4 @@ class BernoulliGammaLoss(NaNLoss):
                 - (y_true / gscale) - gshape * torch.log(gscale) -
                 torch.lgamma(gshape))
 
-        return self.nan_reduce(loss)
+        return - self.nan_reduce(loss)
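
Taken together, the commit re-parameterizes the Bernoulli-Gamma loss: the probability channel is squashed through a sigmoid instead of being hard-clamped to [0, 1] (note that F.sigmoid has since been deprecated in PyTorch in favor of torch.sigmoid), the network now predicts the logarithms of the Gamma shape and scale (the exponential maps them onto (0, +infinity)), and the sign flip on the last line makes the function return a proper negative log-likelihood to minimize. A self-contained sketch of the resulting computation follows; the function name, the epsilon smoothing inside the logarithms, and the likelihood terms not visible in the hunks above are assumptions, while the sigmoid/exp transforms and the final sign come from the diff:

    import torch

    def bernoulli_gamma_nll(y_pred, y_true, epsilon=1e-7):
        # channel 0: logit of the precipitation probability p;
        # channels 1, 2: logarithms of the Gamma shape and scale
        p_pred = torch.sigmoid(y_pred[:, 0, ...])
        gshape = torch.exp(y_pred[:, 1, ...])  # exp maps the log-parameters
        gscale = torch.exp(y_pred[:, 2, ...])  # onto (0, +infinity)

        # Bernoulli-Gamma log-likelihood: dry points contribute log(1 - p),
        # wet points log(p) plus the Gamma log-density
        wet = (y_true > 0).float()
        loglik = (wet * (torch.log(p_pred + epsilon)
                         + (gshape - 1) * torch.log(y_true + epsilon)
                         - y_true / gscale
                         - gshape * torch.log(gscale)
                         - torch.lgamma(gshape))
                  + (1 - wet) * torch.log(1 - p_pred + epsilon))

        # mask missing observations, then negate: minimizing the negative
        # mean log-likelihood maximizes the likelihood (the same sign fix
        # as "return - self.nan_reduce(loss)" above)
        valid = ~torch.isnan(y_true)
        return -loglik[valid].mean()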