diff --git a/climax/core/loss.py b/climax/core/loss.py index c747989fcca66e13ec46cb4f0fa68669719a83b8..9a6e99f86640e895de3e55dbabd0b9f3abc2dbdb 100644 --- a/climax/core/loss.py +++ b/climax/core/loss.py @@ -155,10 +155,9 @@ class BernoulliWeibullLoss(BernoulliLoss): # torch.pow(y_weib / (scale + self.epsilon), shape) # ) loss_weib = (torch.log(p_pred[~mask_p] + self.epsilon) + - torch.log(shape / (scale + self.epsilon)) - - (shape - 1) * torch.log(y_weib / (scale + self.epsilon)) - - torch.pow(y_weib / (scale + self.epsilon), shape) - ) + torch.log(shape / scale) - + (shape - 1) * torch.log(y_weib / scale) - + torch.pow(y_weib / scale, shape)) # fill loss array loss[torch.where(mask_p)] = - loss_bern