diff --git a/climax/core/predict.py b/climax/core/predict.py index 850bdcec3b31793db2ebf3d123b07e5706027dd8..0a135a5c74e352aa34bbd9d94ab4ef38d5923990 100644 --- a/climax/core/predict.py +++ b/climax/core/predict.py @@ -64,7 +64,9 @@ def predict_ERA5(net, ERA5_ds, predictand, batch_size=16, **kwargs): elif predictand == 'pr': ds = { # probability of precipitation - 'prob': torch.sigmoid(target[:, 0, ...].squeeze()), + 'prob': torch.sigmoid( + torch.as_tensor(target[:, 0, ...].squeeze(), + dtype=torch.float32)).numpy(), # amount of precipitation: expected value of gamma distribution # pr = shape * scale