From ba859495adc372b24386dd6de8f94cb438a586c8 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Fri, 23 Jul 2021 17:03:29 +0200
Subject: [PATCH] Changed handling of Nans.

---
 climax/core/loss.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/climax/core/loss.py b/climax/core/loss.py
index 80226a0..9872f80 100644
--- a/climax/core/loss.py
+++ b/climax/core/loss.py
@@ -55,8 +55,8 @@ class BernoulliGammaLoss(NaNLoss):
     def forward(self, y_pred, y_true):
 
         # convert to float32
-        y_pred = y_pred.type(torch.float32)
-        y_true = y_true.type(torch.float32)
+        y_pred = y_pred.type(torch.float32).squeeze()
+        y_true = y_true.type(torch.float32).squeeze()
 
         # missing values
         mask = ~torch.isnan(y_true)
-- 
GitLab