Commit a692cd41 authored by Frisinghelli Daniel

Changed handling of NaNs.

parent cc567b5e
@@ -14,16 +14,13 @@ class NaNLoss(_Loss):
     def __init__(self, size_average=None, reduce=None, reduction='mean'):
         super().__init__(size_average, reduce, reduction)
 
-    def nan_reduce(self, tensor):
-        # mask missing values
-        tensor = tensor[~torch.isnan(tensor)]
-
-        if self.reduction is None:
+    def reduce(self, tensor):
+        if self.reduction == 'none':
             return tensor
-        if self.reduction == 'sum':
-            return tensor.sum()
-        if self.reduction == 'mean':
+        elif self.reduction == 'mean':
             return tensor.mean()
+        elif self.reduction == 'sum':
+            return tensor.sum()
 
 
 class MSELoss(NaNLoss):
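The reduce method introduced here mirrors PyTorch's standard reduction semantics ('none', 'mean', 'sum') but, unlike the removed nan_reduce, no longer strips NaNs itself; the subclasses are now responsible for masking. A minimal stand-alone sketch of the new behaviour (the free function below is illustrative only, not part of the repository):

import torch

def reduce(tensor, reduction='mean'):
    # apply the configured reduction; NaNs are assumed to be masked out already
    if reduction == 'none':
        return tensor
    elif reduction == 'mean':
        return tensor.mean()
    elif reduction == 'sum':
        return tensor.sum()

loss = torch.tensor([1.0, 2.0, 3.0])
print(reduce(loss, 'none'))   # tensor([1., 2., 3.])
print(reduce(loss, 'mean'))   # tensor(2.)
print(reduce(loss, 'sum'))    # tensor(6.)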
@@ -32,7 +29,8 @@ class MSELoss(NaNLoss):
         super().__init__(size_average, reduce, reduction)
 
     def forward(self, y_pred, y_true):
-        return self.nan_reduce(F.mse_loss(y_pred, y_true, reduction='none'))
+        mask = ~torch.isnan(y_true)
+        return F.mse_loss(y_pred[mask], y_true[mask], reduction=self.reduction)
 
 
 class L1Loss(NaNLoss):
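Why the masking moved into forward: the torch.nn.functional losses propagate NaN targets through the reduction, so the updated subclasses drop missing cells before calling the functional. A small illustration with made-up tensors:

import torch
import torch.nn.functional as F

y_pred = torch.tensor([0.5, 1.0, 2.0])
y_true = torch.tensor([1.0, float('nan'), 2.0])

# without masking, the NaN target contaminates the mean
print(F.mse_loss(y_pred, y_true, reduction='mean'))              # tensor(nan)

# with the mask used in the updated forward(), missing cells are ignored
mask = ~torch.isnan(y_true)
print(F.mse_loss(y_pred[mask], y_true[mask], reduction='mean'))  # tensor(0.1250)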
@@ -41,40 +39,41 @@ class L1Loss(NaNLoss):
         super().__init__(size_average, reduce, reduction)
 
     def forward(self, y_pred, y_true):
-        return self.nan_reduce(F.l1_loss(y_pred, y_true, reduction='none'))
+        mask = ~torch.isnan(y_true)
+        return F.l1_loss(y_pred[mask], y_true[mask], reduction=self.reduction)
 
 
 class BernoulliGammaLoss(NaNLoss):
 
     def __init__(self, size_average=None, reduce=None, reduction='mean',
-                 min_amount=0, epsilon=1e-7):
+                 min_amount=0):
         super().__init__(size_average, reduce, reduction)
 
         # minimum amount of precipitation to be classified as precipitation
         self.min_amount = min_amount
 
-        # small number ensuring numerical stability
-        self.epsilon = epsilon
-
     def forward(self, y_pred, y_true):
 
         # convert to float32
-        y_pred = y_pred.type(torch.float32) + self.epsilon
+        y_pred = y_pred.type(torch.float32)
         y_true = y_true.type(torch.float32)
 
+        # missing values
+        mask = ~torch.isnan(y_true)
+
         # calculate true probability of precipitation:
         # 1 if y_true > min_amount else 0
-        p_true = (y_true > self.min_amount).type(torch.float32)
+        p_true = (y_true > self.min_amount).type(torch.float32)[mask]
 
         # estimates of precipitation probability and gamma shape and scale
         # parameters: ensure numerical stability
 
         # clip probabilities to (0, 1)
-        p_pred = torch.sigmoid(y_pred[:, 0, ...])
+        p_pred = torch.sigmoid(y_pred[:, 0, ...])[mask]
 
         # clip shape and scale to (0, +infinity)
-        gshape = torch.exp(y_pred[:, 1, ...])
-        gscale = torch.exp(y_pred[:, 2, ...])
+        gshape = torch.exp(y_pred[:, 1, ...])[mask]
+        gscale = torch.exp(y_pred[:, 2, ...])[mask]
 
 
         # negative log-likelihood function of Bernoulli-Gamma distribution
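To make the new indexing concrete: forward appears to expect y_pred with three channels along dimension 1 (precipitation logit, log-shape, log-scale) and y_true with NaN marking missing cells, so each [mask] selection yields a flat vector of valid cells. A shape sketch with invented dimensions:

import torch

y_pred = torch.randn(2, 3, 4, 4)      # [logit(p), log(shape), log(scale)] per cell
y_true = torch.rand(2, 4, 4)
y_true[0, 0, 0] = float('nan')        # simulate one missing observation

mask = ~torch.isnan(y_true)           # (2, 4, 4) boolean mask

p_pred = torch.sigmoid(y_pred[:, 0, ...])[mask]   # probability of precipitation
gshape = torch.exp(y_pred[:, 1, ...])[mask]       # gamma shape > 0
gscale = torch.exp(y_pred[:, 2, ...])[mask]       # gamma scale > 0

print(p_pred.shape, gshape.shape, gscale.shape)   # each torch.Size([31])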
@@ -86,4 +85,4 @@ class BernoulliGammaLoss(NaNLoss):
                          - (y_true / gscale) - gshape * torch.log(gscale) -
                          torch.lgamma(gshape))
 
-        return - self.nan_reduce(loss)
+        return - self.reduce(loss)
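The terms visible in the hunk above are consistent with the standard Bernoulli-Gamma negative log-likelihood used for precipitation modelling. A sketch of that standard form as a stand-alone function, assuming all inputs are the already-masked 1-D tensors from the forward pass (the eps terms are added here only to keep the logarithms finite and are not taken from the repository):

import torch

def bernoulli_gamma_nll(p_pred, gshape, gscale, p_true, y_true, eps=1e-7):
    # Bernoulli term for dry cells plus Gamma log-likelihood for wet cells
    log_lik = ((1 - p_true) * torch.log(1 - p_pred + eps) +
               p_true * (torch.log(p_pred + eps) +
                         (gshape - 1) * torch.log(y_true + eps) -
                         (y_true / gscale) -
                         gshape * torch.log(gscale) -
                         torch.lgamma(gshape)))
    # negative mean log-likelihood
    return - log_lik.mean()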