diff --git a/climax/core/loss.py b/climax/core/loss.py index 02e7a5b8897eee228ce9e6384b3a2d07656f5750..5ed3af283569b6e48eaddd3e1b47643c94de8af8 100644 --- a/climax/core/loss.py +++ b/climax/core/loss.py @@ -92,6 +92,7 @@ class BernoulliGammaLoss(NaNLoss): torch.log(y_true + self.epsilon) - y_true / (gscale + self.epsilon) - gshape * torch.log(gscale + self.epsilon) - - torch.lgamma(gshape + self.epsilon)) + torch.log(torch.lgamma(gshape + self.epsilon) + + self.epsilon)) return - self.reduce(loss)