Skip to content
Snippets Groups Projects
Commit 3aaf11cd authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Implemented Bernoulli-Weibull loss.

parent aa62cddf
No related branches found
No related tags found
No related merge requests found
...@@ -147,3 +147,56 @@ class BernoulliGenParetoLoss(NaNLoss): ...@@ -147,3 +147,56 @@ class BernoulliGenParetoLoss(NaNLoss):
self.epsilon)) self.epsilon))
return self.reduce(loss) return self.reduce(loss)
class BernoulliWeibullLoss(NaNLoss):
def __init__(self, size_average=None, reduce=None, reduction='mean',
min_amount=0):
super().__init__(size_average, reduce, reduction)
# minimum amount of precipitation to be classified as precipitation
self.min_amount = min_amount
# small number ensuring numerical stability
self.epsilon = 1e-7
def forward(self, y_pred, y_true):
# convert to float32
y_pred = y_pred.type(torch.float32)
y_true = y_true.type(torch.float32).squeeze()
# missing values
mask = ~torch.isnan(y_true)
y_true = y_true[mask]
# mask values less than 0
y_true[y_true < 0] = 0
# calculate true probability of precipitation:
# 1 if y_true > min_amount else 0
p_true = (y_true > self.min_amount).type(torch.float32)
# estimates of precipitation probability and gamma shape and scale
# parameters: ensure numerical stability
# clip probabilities to (0, 1)
p_pred = torch.sigmoid(y_pred[:, 0, ...].squeeze()[mask])
# clip shape and scale to (0, +infinity)
gshape = torch.exp(y_pred[:, 1, ...].squeeze()[mask])
gscale = torch.exp(y_pred[:, 2, ...].squeeze()[mask])
# negative log-likelihood function of Bernoulli-GenPareto distribution
# Bernoulli contribution
loss = - (1 - p_true) * torch.log(1 - p_pred + self.epsilon)
# Weibull contribution
loss -= p_true * (torch.log(gshape + self.epsilon) -
gshape * torch.log(gscale + self.epsilon) +
(gshape - 1) * torch.log(y_true + self.epsilon) -
(y_true / (gscale + self.epsilon)) ** gshape)
return self.reduce(loss)
...@@ -11,9 +11,11 @@ import torch ...@@ -11,9 +11,11 @@ import torch
import numpy as np import numpy as np
import xarray as xr import xarray as xr
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from scipy.special import gamma
# locals # locals
from climax.core.dataset import EoDataset, NetCDFDataset from climax.core.dataset import EoDataset, NetCDFDataset
from climax.core.loss import BernoulliGammaLoss, BernoulliWeibullLoss
# module level logger # module level logger
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
...@@ -77,21 +79,20 @@ def predict_ERA5(net, ERA5_ds, predictand, loss, batch_size=16, **kwargs): ...@@ -77,21 +79,20 @@ def predict_ERA5(net, ERA5_ds, predictand, loss, batch_size=16, **kwargs):
# check which loss function is used # check which loss function is used
# Bernoulli-Gamma # Bernoulli-Gamma
if loss.__class__.__name__ == 'BernoulliGammaLoss': if isinstance(loss, BernoulliGammaLoss):
# precipitation amount: expected value of Bernoulli-Gamma # precipitation amount: expected value of Bernoulli-Gamma
# distribution # distribution
# pr = p * shape * scale # pr = p * shape * scale
pr = (prob * np.exp(target[:, 1, ...].squeeze()) * pr = (prob * np.exp(target[:, 1, ...].squeeze()) *
np.exp(target[:, 2, ...].squeeze())) np.exp(target[:, 2, ...].squeeze()))
# Bernoulli-GenPareto # Bernoulli-Weibull
if loss.__class__.__name__ == 'BernoulliGenParetoLoss': if isinstance(loss, BernoulliWeibullLoss):
# precipitation amount: expected value of Bernoulli-GenPareto # precipitation amount: expected value of Bernoulli-Weibull
# distribution # distribution
# pr = # pr = p * scale * tau(1 + 1 / shape)
# pr = (prob * np.exp(target[:, 1, ...].squeeze()) * pr = (prob * np.exp(target[:, 2, ...].squeeze()) *
# np.exp(target[:, 2, ...].squeeze())) gamma(1 + 1 / target[:, 1, ...]))
pass
ds = {'prob': prob, 'precipitation': pr} ds = {'prob': prob, 'precipitation': pr}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment