diff --git a/climax/core/loss.py b/climax/core/loss.py index c2cc811cb3d2ffaf3f8c81bbc46a1be6ac8be949..c747989fcca66e13ec46cb4f0fa68669719a83b8 100644 --- a/climax/core/loss.py +++ b/climax/core/loss.py @@ -148,10 +148,15 @@ class BernoulliWeibullLoss(BernoulliLoss): loss_bern = torch.log(1 - p_pred[mask_p] + self.epsilon) # Weibull contribution + # loss_weib = (torch.log(p_pred[~mask_p] + self.epsilon) + + # torch.log(shape + self.epsilon) - + # shape * torch.log(scale + self.epsilon) + + # (shape - 1) * torch.log(y_weib + self.epsilon) - + # torch.pow(y_weib / (scale + self.epsilon), shape) + # ) loss_weib = (torch.log(p_pred[~mask_p] + self.epsilon) + - torch.log(shape + self.epsilon) - - shape * torch.log(scale + self.epsilon) + - (shape - 1) * torch.log(y_weib + self.epsilon) - + torch.log(shape / (scale + self.epsilon)) - + (shape - 1) * torch.log(y_weib / (scale + self.epsilon)) - torch.pow(y_weib / (scale + self.epsilon), shape) )