diff --git a/climax/core/loss.py b/climax/core/loss.py index c2cc811cb3d2ffaf3f8c81bbc46a1be6ac8be949..2147c5bc4a156ec957319b72b72a2cf2a3bcbe0b 100644 --- a/climax/core/loss.py +++ b/climax/core/loss.py @@ -26,6 +26,8 @@ class NaNLoss(_Loss): return tensor.mean() elif self.reduction == 'sum': return tensor.sum() + elif self.reduction == 'median': + return tensor.median() class MSELoss(NaNLoss): @@ -133,13 +135,9 @@ class BernoulliWeibullLoss(BernoulliLoss): # clip probabilities to (0, 1) p_pred = torch.sigmoid(y_pred[:, 0, ...].squeeze()[mask]) - # clip scale to (0, +infinity) - scale = torch.exp(y_pred[:, 2, ...].squeeze()[mask][~mask_p]) - - # clip shape to (0, 10) - # NOTE: in general shape in (0, +infinity), clipping is required for - # numerical stability + # clip shape and scale to (0, +infinity) shape = torch.exp(y_pred[:, 1, ...].squeeze()[mask][~mask_p]) + scale = torch.exp(y_pred[:, 2, ...].squeeze()[mask][~mask_p]) # negative log-likelihood function of Bernoulli-Weibull distribution loss = torch.zeros_like(y_true)