diff --git a/climax/core/predict.py b/climax/core/predict.py index bbefa411878b8d7eb2a92ea0d764bb8948aca266..bd29e2f099b45862f938ba6575ecd3af3ec0c351 100644 --- a/climax/core/predict.py +++ b/climax/core/predict.py @@ -74,12 +74,10 @@ def predict_ERA5(net, ERA5_ds, predictand, loss, batch_size=16, **kwargs): prob = torch.sigmoid(torch.as_tensor(target[:, 0, ...].squeeze(), dtype=torch.float32)).numpy() - # shape and scale parameters - shape = target[:, 1, ...].squeeze() - scale = target[:, 2, ...].squeeze() - # precipitation amount - pr = loss.predict(prob, shape, scale) + pr = loss.predict(prob, target[:, 1, ...].squeeze(), + target[:, 2, ...].squeeze()) + del target # clear memory # precipitation probability and amount ds = {'prob': prob, 'precipitation': pr}