diff --git a/climax/core/predict.py b/climax/core/predict.py index 849e5f6bb0baa859dccbfb973e52228bc7c3be1f..850bdcec3b31793db2ebf3d123b07e5706027dd8 100644 --- a/climax/core/predict.py +++ b/climax/core/predict.py @@ -59,19 +59,24 @@ def predict_ERA5(net, ERA5_ds, predictand, batch_size=16, **kwargs): # convert numpy array to xarray.Dataset if predictand == 'tas': # in case of tas, the netwokr predicts both tasmax and tasmin - ds = {'tasmax': EoDataset.add_coordinates(target[:, 0, ...].squeeze()), - 'tasmin': EoDataset.add_coordinates(target[:, 1, ...].squeeze())} + ds = {'tasmax': target[:, 0, ...].squeeze(), + 'tasmin': target[:, 1, ...].squeeze()} elif predictand == 'pr': - ds = {'prob': EoDataset.add_coordinates(target[:, 0, ...].squeeze()), - - # amount of precipitation: expected value of gamma distribution - # pr = shape * scale - 'precipitation': EoDataset.add_coordinates( - (np.exp(target[:, 1, ...]) * - np.exp(target[:, 2, ...])).squeeze())} + ds = { + # probability of precipitation + 'prob': torch.sigmoid(target[:, 0, ...].squeeze()), + + # amount of precipitation: expected value of gamma distribution + # pr = shape * scale + 'precipitation': (np.exp(target[:, 1, ...]) * + np.exp(target[:, 2, ...])).squeeze() + } else: # single predictand - ds = {predictand: EoDataset.add_coordinates(target)} + ds = {predictand: target} + + # add coordinates to arrays + ds = {k: EoDataset.add_coordinates(v) for k, v in ds.items()} # create xarray dataset: dtype=Float32 ds = xr.Dataset(data_vars=ds,