diff --git a/climax/core/predict.py b/climax/core/predict.py index c6a93ab98fb4609fea9c7640c0bf85ffd929d262..d5686ce25708e01d005294b9fed691513040b646 100644 --- a/climax/core/predict.py +++ b/climax/core/predict.py @@ -59,19 +59,27 @@ def predict_ERA5(net, ERA5_ds, predictand, batch_size=16, **kwargs): # convert numpy array to xarray.Dataset if predictand == 'tas': # in case of tas, the netwokr predicts both tasmax and tasmin - ds = {'tasmax': EoDataset.add_coordinates(target[:, 0, ...].squeeze()), - 'tasmin': EoDataset.add_coordinates(target[:, 1, ...].squeeze())} + ds = {'tasmax': target[:, 0, ...].squeeze(), + 'tasmin': target[:, 1, ...].squeeze()} elif predictand == 'pr': - ds = {'prob': EoDataset.add_coordinates(target[:, 0, ...].squeeze()), - # amount of precipitation: expected value of gamma distribution - # pr = shape * scale - 'pr': EoDataset.add_coordinates( - (np.exp(target[:, 1, ...]) * - np.exp(target[:, 2, ...])).squeeze())} + # probability of precipitation + prob = torch.sigmoid(torch.as_tensor(target[:, 0, ...].squeeze(), + dtype=torch.float32)).numpy() + + # precipitation amount: expected value of Bernoulli-Gamma distribution + # pr = p * shape * scale + pr = (prob * np.exp(target[:, 1, ...].squeeze()) * + np.exp(target[:, 2, ...].squeeze())) + + ds = {'prob': prob, 'precipitation': pr} + else: # single predictand - ds = {predictand: EoDataset.add_coordinates(target)} + ds = {predictand: target} + + # add coordinates to arrays + ds = {k: EoDataset.add_coordinates(v) for k, v in ds.items()} # create xarray dataset: dtype=Float32 ds = xr.Dataset(data_vars=ds,