From 1a42b5b52a0e6512e844b79cdb0034c0b1a24c44 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Fri, 23 Jul 2021 17:36:21 +0200 Subject: [PATCH] Correct shape. --- climax/core/loss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/climax/core/loss.py b/climax/core/loss.py index 5d7ab40..02e7a5b 100644 --- a/climax/core/loss.py +++ b/climax/core/loss.py @@ -58,7 +58,7 @@ class BernoulliGammaLoss(NaNLoss): def forward(self, y_pred, y_true): # convert to float32 - y_pred = y_pred.type(torch.float32).squeeze() + y_pred = y_pred.type(torch.float32) y_true = y_true.type(torch.float32).squeeze() # missing values @@ -76,11 +76,11 @@ class BernoulliGammaLoss(NaNLoss): # parameters: ensure numerical stability # clip probabilities to (0, 1) - p_pred = torch.sigmoid(y_pred[:, 0, ...][mask]) + p_pred = torch.sigmoid(y_pred[:, 0, ...].squeeze()[mask]) # clip shape and scale to (0, +infinity) - gshape = torch.exp(y_pred[:, 1, ...][mask]) - gscale = torch.exp(y_pred[:, 2, ...][mask]) + gshape = torch.exp(y_pred[:, 1, ...].squeeze()[mask]) + gscale = torch.exp(y_pred[:, 2, ...].squeeze()[mask]) # negative log-likelihood function of Bernoulli-Gamma distribution -- GitLab