diff --git a/climax/core/predict.py b/climax/core/predict.py index 42cc74d447087f2b7c370cabaf68d1d77967bfd0..06de9e0f53cacc419ef9af763499dbd81c332baf 100644 --- a/climax/core/predict.py +++ b/climax/core/predict.py @@ -19,7 +19,7 @@ from climax.core.dataset import EoDataset, NetCDFDataset LOGGER = logging.getLogger(__name__) -def predict_ERA5(net, ERA5_ds, predictand, batch_size=16, **kwargs): +def predict_ERA5(net, ERA5_ds, predictand, loss, batch_size=16, **kwargs): # create PyTorch compliant dataset and dataloader instances for model # inference ds = NetCDFDataset(ERA5_ds, **kwargs) @@ -56,27 +56,44 @@ def predict_ERA5(net, ERA5_ds, predictand, batch_size=16, **kwargs): outputs) LOGGER.info('Mini-batch: {:d}/{:d}'.format(batch + 1, len(dl))) - # convert numpy array to xarray.Dataset - if predictand == 'tas' and net.classifier.out_channels == 2: - # in case of tas, the netwokr predicts both tasmax and tasmin - ds = {'tasmax': target[:, 0, ...].squeeze(), - 'tasmin': target[:, 1, ...].squeeze()} - elif predictand == 'pr' and net.classifier.out_channels == 3: - - # probability of precipitation - prob = torch.sigmoid(torch.as_tensor(target[:, 0, ...].squeeze(), - dtype=torch.float32)).numpy() - - # precipitation amount: expected value of Bernoulli-Gamma distribution - # pr = p * shape * scale - pr = (prob * np.exp(target[:, 1, ...].squeeze()) * - np.exp(target[:, 2, ...].squeeze())) - - ds = {'prob': prob, 'precipitation': pr} - - else: - # single predictand + # check how many output fields are modelled + if net.classifier.out_channels == 1: + # single output field ds = {predictand: target} + else: + # joint optimization of min and max temperature + if predictand == 'tas' and net.classifier.out_channels == 2: + # in case of tas, the network predicts both tasmax and tasmin + ds = {'tasmax': target[:, 0, ...].squeeze(), + 'tasmin': target[:, 1, ...].squeeze()} + + # probabilistic precipitation modelling + elif predictand == 'pr' and net.classifier.out_channels == 3: + + # probability of precipitation + prob = torch.sigmoid(torch.as_tensor(target[:, 0, ...].squeeze(), + dtype=torch.float32)).numpy() + + # check which loss function is used + + # Bernoulli-Gamma + if loss.__class__.__name__ == 'BernoulliGammaLoss': + # precipitation amount: expected value of Bernoulli-Gamma + # distribution + # pr = p * shape * scale + pr = (prob * np.exp(target[:, 1, ...].squeeze()) * + np.exp(target[:, 2, ...].squeeze())) + + # Bernoulli-GenPareto + if loss.__class__.__name__ == 'BernoulliGenParetoLoss': + # precipitation amount: expected value of Bernoulli-GenPareto + # distribution + # pr = + # pr = (prob * np.exp(target[:, 1, ...].squeeze()) * + # np.exp(target[:, 2, ...].squeeze())) + pass + + ds = {'prob': prob, 'precipitation': pr} # add coordinates to arrays ds = {k: EoDataset.add_coordinates(v) for k, v in ds.items()}