diff --git a/climax/core/loss.py b/climax/core/loss.py index 9a6e99f86640e895de3e55dbabd0b9f3abc2dbdb..c2cc811cb3d2ffaf3f8c81bbc46a1be6ac8be949 100644 --- a/climax/core/loss.py +++ b/climax/core/loss.py @@ -148,16 +148,12 @@ class BernoulliWeibullLoss(BernoulliLoss): loss_bern = torch.log(1 - p_pred[mask_p] + self.epsilon) # Weibull contribution - # loss_weib = (torch.log(p_pred[~mask_p] + self.epsilon) + - # torch.log(shape + self.epsilon) - - # shape * torch.log(scale + self.epsilon) + - # (shape - 1) * torch.log(y_weib + self.epsilon) - - # torch.pow(y_weib / (scale + self.epsilon), shape) - # ) loss_weib = (torch.log(p_pred[~mask_p] + self.epsilon) + - torch.log(shape / scale) - - (shape - 1) * torch.log(y_weib / scale) - - torch.pow(y_weib / scale, shape)) + torch.log(shape + self.epsilon) - + shape * torch.log(scale + self.epsilon) + + (shape - 1) * torch.log(y_weib + self.epsilon) - + torch.pow(y_weib / (scale + self.epsilon), shape) + ) # fill loss array loss[torch.where(mask_p)] = - loss_bern