From ba859495adc372b24386dd6de8f94cb438a586c8 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Fri, 23 Jul 2021 17:03:29 +0200 Subject: [PATCH] Changed handling of Nans. --- climax/core/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/climax/core/loss.py b/climax/core/loss.py index 80226a0..9872f80 100644 --- a/climax/core/loss.py +++ b/climax/core/loss.py @@ -55,8 +55,8 @@ class BernoulliGammaLoss(NaNLoss): def forward(self, y_pred, y_true): # convert to float32 - y_pred = y_pred.type(torch.float32) - y_true = y_true.type(torch.float32) + y_pred = y_pred.type(torch.float32).squeeze() + y_true = y_true.type(torch.float32).squeeze() # missing values mask = ~torch.isnan(y_true) -- GitLab