Skip to content
Snippets Groups Projects
Commit 6e3d8716 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Clean implementation of model predictions.

parent 1175c093
No related branches found
No related tags found
No related merge requests found
...@@ -4,9 +4,11 @@ ...@@ -4,9 +4,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# externals # externals
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
from scipy.special import gamma
class NaNLoss(_Loss): class NaNLoss(_Loss):
...@@ -46,7 +48,7 @@ class L1Loss(NaNLoss): ...@@ -46,7 +48,7 @@ class L1Loss(NaNLoss):
return F.l1_loss(y_pred[mask], y_true[mask], reduction=self.reduction) return F.l1_loss(y_pred[mask], y_true[mask], reduction=self.reduction)
class BernoulliGammaLoss(NaNLoss): class BernoulliLoss(NaNLoss):
def __init__(self, size_average=None, reduce=None, reduction='mean', def __init__(self, size_average=None, reduce=None, reduction='mean',
min_amount=0): min_amount=0):
...@@ -55,6 +57,13 @@ class BernoulliGammaLoss(NaNLoss): ...@@ -55,6 +57,13 @@ class BernoulliGammaLoss(NaNLoss):
# minimum amount of precipitation to be classified as precipitation # minimum amount of precipitation to be classified as precipitation
self.min_amount = min_amount self.min_amount = min_amount
class BernoulliGammaLoss(BernoulliLoss):
def __init__(self, size_average=None, reduce=None, reduction='mean',
min_amount=0):
super().__init__(size_average, reduce, reduction)
def forward(self, y_pred, y_true): def forward(self, y_pred, y_true):
# convert to float32 # convert to float32
...@@ -96,16 +105,18 @@ class BernoulliGammaLoss(NaNLoss): ...@@ -96,16 +105,18 @@ class BernoulliGammaLoss(NaNLoss):
return self.reduce(loss) return self.reduce(loss)
@staticmethod
def predict(p, shape, scale):
# pr = p * shape * scale
return p * shape * scale
class BernoulliGenParetoLoss(NaNLoss): class BernoulliGenParetoLoss(BernoulliLoss):
def __init__(self, size_average=None, reduce=None, reduction='mean', def __init__(self, size_average=None, reduce=None, reduction='mean',
min_amount=0): min_amount=0):
super().__init__(size_average, reduce, reduction) super().__init__(size_average, reduce, reduction)
# minimum amount of precipitation to be classified as precipitation
self.min_amount = min_amount
def forward(self, y_pred, y_true): def forward(self, y_pred, y_true):
# convert to float32 # convert to float32
...@@ -146,15 +157,12 @@ class BernoulliGenParetoLoss(NaNLoss): ...@@ -146,15 +157,12 @@ class BernoulliGenParetoLoss(NaNLoss):
return self.reduce(loss) return self.reduce(loss)
class BernoulliWeibullLoss(NaNLoss): class BernoulliWeibullLoss(BernoulliLoss):
def __init__(self, size_average=None, reduce=None, reduction='mean', def __init__(self, size_average=None, reduce=None, reduction='mean',
min_amount=0): min_amount=0):
super().__init__(size_average, reduce, reduction) super().__init__(size_average, reduce, reduction)
# minimum amount of precipitation to be classified as precipitation
self.min_amount = min_amount
def forward(self, y_pred, y_true): def forward(self, y_pred, y_true):
# convert to float32 # convert to float32
...@@ -207,3 +215,8 @@ class BernoulliWeibullLoss(NaNLoss): ...@@ -207,3 +215,8 @@ class BernoulliWeibullLoss(NaNLoss):
# loss = torch.clamp(loss, max=100) # loss = torch.clamp(loss, max=100)
return self.reduce(loss) return self.reduce(loss)
@staticmethod
def predict(p, shape, scale):
# pr = p * scale * gamma(1 + 1 / shape)
return p * scale * gamma(1 + 1 / shape)
...@@ -11,11 +11,9 @@ import torch ...@@ -11,11 +11,9 @@ import torch
import numpy as np import numpy as np
import xarray as xr import xarray as xr
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from scipy.special import gamma
# locals # locals
from climax.core.dataset import EoDataset, NetCDFDataset from climax.core.dataset import EoDataset, NetCDFDataset
from climax.core.loss import BernoulliGammaLoss, BernoulliWeibullLoss
# module level logger # module level logger
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
...@@ -76,24 +74,14 @@ def predict_ERA5(net, ERA5_ds, predictand, loss, batch_size=16, **kwargs): ...@@ -76,24 +74,14 @@ def predict_ERA5(net, ERA5_ds, predictand, loss, batch_size=16, **kwargs):
prob = torch.sigmoid(torch.as_tensor(target[:, 0, ...].squeeze(), prob = torch.sigmoid(torch.as_tensor(target[:, 0, ...].squeeze(),
dtype=torch.float32)).numpy() dtype=torch.float32)).numpy()
# check which loss function is used # shape and scale parameters
shape = np.exp(target[:, 1, ...].squeeze())
scale = np.exp(target[:, 2, ...].squeeze())
# Bernoulli-Gamma # precipitation amount
if isinstance(loss, BernoulliGammaLoss): pr = loss.predict(prob, shape, scale)
# precipitation amount: expected value of Bernoulli-Gamma
# distribution
# pr = p * shape * scale
pr = (prob * np.exp(target[:, 1, ...].squeeze()) *
np.exp(target[:, 2, ...].squeeze()))
# Bernoulli-Weibull
if isinstance(loss, BernoulliWeibullLoss):
# precipitation amount: expected value of Bernoulli-Weibull
# distribution
# pr = p * scale * tau(1 + 1 / shape)
pr = (prob * np.exp(target[:, 2, ...].squeeze()) *
gamma(1 + 1 / np.exp(target[:, 1, ...])))
# precipitation probability and amount
ds = {'prob': prob, 'precipitation': pr} ds = {'prob': prob, 'precipitation': pr}
# add coordinates to arrays # add coordinates to arrays
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment