Skip to content
Snippets Groups Projects
Commit abe479b8 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Refactor.

parent f1d00246
No related branches found
No related tags found
No related merge requests found
......@@ -19,7 +19,7 @@ from climax.core.dataset import EoDataset, NetCDFDataset
LOGGER = logging.getLogger(__name__)
def predict_ERA5(net, ERA5_ds, predictand, batch_size=16, **kwargs):
def predict_ERA5(net, ERA5_ds, predictand, loss, batch_size=16, **kwargs):
# create PyTorch compliant dataset and dataloader instances for model
# inference
ds = NetCDFDataset(ERA5_ds, **kwargs)
......@@ -56,27 +56,44 @@ def predict_ERA5(net, ERA5_ds, predictand, batch_size=16, **kwargs):
outputs)
LOGGER.info('Mini-batch: {:d}/{:d}'.format(batch + 1, len(dl)))
# convert numpy array to xarray.Dataset
if predictand == 'tas' and net.classifier.out_channels == 2:
# in case of tas, the netwokr predicts both tasmax and tasmin
ds = {'tasmax': target[:, 0, ...].squeeze(),
'tasmin': target[:, 1, ...].squeeze()}
elif predictand == 'pr' and net.classifier.out_channels == 3:
# probability of precipitation
prob = torch.sigmoid(torch.as_tensor(target[:, 0, ...].squeeze(),
dtype=torch.float32)).numpy()
# precipitation amount: expected value of Bernoulli-Gamma distribution
# pr = p * shape * scale
pr = (prob * np.exp(target[:, 1, ...].squeeze()) *
np.exp(target[:, 2, ...].squeeze()))
ds = {'prob': prob, 'precipitation': pr}
else:
# single predictand
# check how many output fields are modelled
if net.classifier.out_channels == 1:
# single output field
ds = {predictand: target}
else:
# joint optimization of min and max temperature
if predictand == 'tas' and net.classifier.out_channels == 2:
# in case of tas, the network predicts both tasmax and tasmin
ds = {'tasmax': target[:, 0, ...].squeeze(),
'tasmin': target[:, 1, ...].squeeze()}
# probabilistic precipitation modelling
elif predictand == 'pr' and net.classifier.out_channels == 3:
# probability of precipitation
prob = torch.sigmoid(torch.as_tensor(target[:, 0, ...].squeeze(),
dtype=torch.float32)).numpy()
# check which loss function is used
# Bernoulli-Gamma
if loss.__class__.__name__ == 'BernoulliGammaLoss':
# precipitation amount: expected value of Bernoulli-Gamma
# distribution
# pr = p * shape * scale
pr = (prob * np.exp(target[:, 1, ...].squeeze()) *
np.exp(target[:, 2, ...].squeeze()))
# Bernoulli-GenPareto
if loss.__class__.__name__ == 'BernoulliGenParetoLoss':
# precipitation amount: expected value of Bernoulli-GenPareto
# distribution
# pr =
# pr = (prob * np.exp(target[:, 1, ...].squeeze()) *
# np.exp(target[:, 2, ...].squeeze()))
pass
ds = {'prob': prob, 'precipitation': pr}
# add coordinates to arrays
ds = {k: EoDataset.add_coordinates(v) for k, v in ds.items()}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment