diff --git a/climax/core/dataset.py b/climax/core/dataset.py index 17f83668a4edffc83b04e54a7548b8ff5dae01b9..302960c22ae27eb31a3054cb679eacc06d3975f0 100644 --- a/climax/core/dataset.py +++ b/climax/core/dataset.py @@ -19,7 +19,8 @@ import dask.array as da from climax.core.constants import (ERA5_VARIABLES, ERA5_PRESSURE_LEVELS, ERA5_P_VARIABLE_NAME, ERA5_S_VARIABLE_NAME, PROJECTION) -from climax.core.loss import BernoulliGammaLoss, BernoulliGenParetoLoss +from climax.core.loss import (BernoulliGammaLoss, BernoulliGenParetoLoss, + BernoulliWeibullLoss) from pysegcnn.core.utils import search_files, img2np from pysegcnn.core.trainer import LogConfig @@ -126,7 +127,8 @@ class EoDataset(torch.utils.data.Dataset): # check which loss function is used if predictand == 'pr': if (isinstance(loss, BernoulliGammaLoss) or - isinstance(loss, BernoulliGenParetoLoss)): + isinstance(loss, BernoulliGenParetoLoss) or + isinstance(loss, BernoulliWeibullLoss)): # adjust state file for precipitation state_file = '_'.join([state_file, '{}mm'.format( str(loss.min_amount).replace('.', '')), diff --git a/climax/core/loss.py b/climax/core/loss.py index 7fdb74b0a4030aac0a61cd18d75df8a64a8ede73..abb9ca3f930e84df9c3a0af41c5a9784583e799a 100644 --- a/climax/core/loss.py +++ b/climax/core/loss.py @@ -178,15 +178,14 @@ class BernoulliWeibullLoss(NaNLoss): # clip probabilities to (0, 1) p_pred = torch.sigmoid(y_pred[:, 0, ...].squeeze()[mask]) - # clip shape and scale to (0, 50) - # NOTE: in general shape, scale in (0, +infinity), but clipping to - # finite numbers is required for numerical stability - # gshape = torch.clamp( - # torch.exp(y_pred[:, 1, ...].squeeze()[mask]), max=1) - # gscale = torch.clamp( - # torch.exp(y_pred[:, 2, ...].squeeze()[mask]), max=20) - gshape = torch.sigmoid(y_pred[:, 1, ...].squeeze()[mask]) - gscale = torch.exp(y_pred[:, 2, ...].squeeze()[mask]) + # clip scale to (0, +infinity) + scale = torch.exp(y_pred[:, 2, ...].squeeze()[mask]) + + # clip shape to (0, 1) + # NOTE: in general shape in (0, +infinity), but for precipitation the + # shape parameter of the Weibull distribution < 1 + shape = torch.clamp( + torch.exp(y_pred[:, 1, ...].squeeze()[mask]), max=1) # negative log-likelihood function of Bernoulli-Weibull distribution @@ -198,10 +197,10 @@ class BernoulliWeibullLoss(NaNLoss): # Weibull contribution loss -= p_true * (torch.log(p_pred + self.epsilon) + - torch.log(gshape + self.epsilon) - - gshape * torch.log(gscale + self.epsilon) + - (gshape - 1) * torch.log(y_true + self.epsilon) - - torch.pow(y_true / (gscale + self.epsilon), gshape) + torch.log(shape + self.epsilon) - + shape * torch.log(scale + self.epsilon) + + (shape - 1) * torch.log(y_true + self.epsilon) - + torch.pow(y_true / (scale + self.epsilon), shape) ) # clip loss to finite values to ensure numerical stability diff --git a/climax/core/predict.py b/climax/core/predict.py index b09bf22a524dd75287f732c2c445a20a90d90ffb..4ece682b3bd52484dc77f970d0497a0a67118411 100644 --- a/climax/core/predict.py +++ b/climax/core/predict.py @@ -92,7 +92,7 @@ def predict_ERA5(net, ERA5_ds, predictand, loss, batch_size=16, **kwargs): # distribution # pr = p * scale * tau(1 + 1 / shape) pr = (prob * np.exp(target[:, 2, ...].squeeze()) * - gamma(1 + 1 / torch.sigmoid(target[:, 1, ...]))) + gamma(1 + 1 / np.exp(target[:, 1, ...]))) ds = {'prob': prob, 'precipitation': pr}