diff --git a/climax/core/loss.py b/climax/core/loss.py
index 80226a096cd0464c25c9ef61163b39f2d87a660d..9872f805dbeb617fad388bc9e459ed15d5c46b8b 100644
--- a/climax/core/loss.py
+++ b/climax/core/loss.py
@@ -55,8 +55,8 @@ class BernoulliGammaLoss(NaNLoss):
     def forward(self, y_pred, y_true):
 
         # convert to float32
-        y_pred = y_pred.type(torch.float32)
-        y_true = y_true.type(torch.float32)
+        y_pred = y_pred.type(torch.float32).squeeze()
+        y_true = y_true.type(torch.float32).squeeze()
 
         # missing values
         mask = ~torch.isnan(y_true)