From 3c9b388b04e0c3f6317c60d4fed81e671862fc95 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Thu, 7 Oct 2021 15:40:58 +0200 Subject: [PATCH] Implemented Bernoulli-Weibull loss. --- climax/core/dataset.py | 6 ++++-- climax/core/loss.py | 25 ++++++++++++------------- climax/core/predict.py | 2 +- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/climax/core/dataset.py b/climax/core/dataset.py index 17f8366..302960c 100644 --- a/climax/core/dataset.py +++ b/climax/core/dataset.py @@ -19,7 +19,8 @@ import dask.array as da from climax.core.constants import (ERA5_VARIABLES, ERA5_PRESSURE_LEVELS, ERA5_P_VARIABLE_NAME, ERA5_S_VARIABLE_NAME, PROJECTION) -from climax.core.loss import BernoulliGammaLoss, BernoulliGenParetoLoss +from climax.core.loss import (BernoulliGammaLoss, BernoulliGenParetoLoss, + BernoulliWeibullLoss) from pysegcnn.core.utils import search_files, img2np from pysegcnn.core.trainer import LogConfig @@ -126,7 +127,8 @@ class EoDataset(torch.utils.data.Dataset): # check which loss function is used if predictand == 'pr': if (isinstance(loss, BernoulliGammaLoss) or - isinstance(loss, BernoulliGenParetoLoss)): + isinstance(loss, BernoulliGenParetoLoss) or + isinstance(loss, BernoulliWeibullLoss)): # adjust state file for precipitation state_file = '_'.join([state_file, '{}mm'.format( str(loss.min_amount).replace('.', '')), diff --git a/climax/core/loss.py b/climax/core/loss.py index 7fdb74b..abb9ca3 100644 --- a/climax/core/loss.py +++ b/climax/core/loss.py @@ -178,15 +178,14 @@ class BernoulliWeibullLoss(NaNLoss): # clip probabilities to (0, 1) p_pred = torch.sigmoid(y_pred[:, 0, ...].squeeze()[mask]) - # clip shape and scale to (0, 50) - # NOTE: in general shape, scale in (0, +infinity), but clipping to - # finite numbers is required for numerical stability - # gshape = torch.clamp( - # torch.exp(y_pred[:, 1, ...].squeeze()[mask]), max=1) - # gscale = torch.clamp( - # torch.exp(y_pred[:, 2, ...].squeeze()[mask]), max=20) - gshape = torch.sigmoid(y_pred[:, 1, ...].squeeze()[mask]) - gscale = torch.exp(y_pred[:, 2, ...].squeeze()[mask]) + # clip scale to (0, +infinity) + scale = torch.exp(y_pred[:, 2, ...].squeeze()[mask]) + + # clip shape to (0, 1) + # NOTE: in general shape in (0, +infinity), but for precipitation the + # shape parameter of the Weibull distribution < 1 + shape = torch.clamp( + torch.exp(y_pred[:, 1, ...].squeeze()[mask]), max=1) # negative log-likelihood function of Bernoulli-Weibull distribution @@ -198,10 +197,10 @@ class BernoulliWeibullLoss(NaNLoss): # Weibull contribution loss -= p_true * (torch.log(p_pred + self.epsilon) + - torch.log(gshape + self.epsilon) - - gshape * torch.log(gscale + self.epsilon) + - (gshape - 1) * torch.log(y_true + self.epsilon) - - torch.pow(y_true / (gscale + self.epsilon), gshape) + torch.log(shape + self.epsilon) - + shape * torch.log(scale + self.epsilon) + + (shape - 1) * torch.log(y_true + self.epsilon) - + torch.pow(y_true / (scale + self.epsilon), shape) ) # clip loss to finite values to ensure numerical stability diff --git a/climax/core/predict.py b/climax/core/predict.py index b09bf22..4ece682 100644 --- a/climax/core/predict.py +++ b/climax/core/predict.py @@ -92,7 +92,7 @@ def predict_ERA5(net, ERA5_ds, predictand, loss, batch_size=16, **kwargs): # distribution # pr = p * scale * tau(1 + 1 / shape) pr = (prob * np.exp(target[:, 2, ...].squeeze()) * - gamma(1 + 1 / torch.sigmoid(target[:, 1, ...]))) + gamma(1 + 1 / np.exp(target[:, 1, ...]))) ds = {'prob': prob, 'precipitation': pr} -- GitLab