diff --git a/climax/core/loss.py b/climax/core/loss.py
index 5d7ab4045059c0f7e96d2a09f5bace6e246102ae..02e7a5b8897eee228ce9e6384b3a2d07656f5750 100644
--- a/climax/core/loss.py
+++ b/climax/core/loss.py
@@ -58,7 +58,7 @@ class BernoulliGammaLoss(NaNLoss):
     def forward(self, y_pred, y_true):
 
         # convert to float32
-        y_pred = y_pred.type(torch.float32).squeeze()
+        y_pred = y_pred.type(torch.float32)
         y_true = y_true.type(torch.float32).squeeze()
 
         # missing values
@@ -76,11 +76,11 @@ class BernoulliGammaLoss(NaNLoss):
         # parameters: ensure numerical stability
 
         # clip probabilities to (0, 1)
-        p_pred = torch.sigmoid(y_pred[:, 0, ...][mask])
+        p_pred = torch.sigmoid(y_pred[:, 0, ...].squeeze()[mask])
 
         # clip shape and scale to (0, +infinity)
-        gshape = torch.exp(y_pred[:, 1, ...][mask])
-        gscale = torch.exp(y_pred[:, 2, ...][mask])
+        gshape = torch.exp(y_pred[:, 1, ...].squeeze()[mask])
+        gscale = torch.exp(y_pred[:, 2, ...].squeeze()[mask])
 
         # negative log-likelihood function of Bernoulli-Gamma distribution