From 79c421721dd8aaec8ba6a4b48df9aa9cefb61d3b Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Tue, 5 Oct 2021 14:53:02 +0200
Subject: [PATCH] Implemented BernoulliGenPareto Loss.

---
 climax/main/downscale_train.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index b985de3..2b49ab3 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -21,7 +21,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig
 from pysegcnn.core.models import Network
 from pysegcnn.core.logging import log_conf
 from climax.core.dataset import ERA5Dataset, NetCDFDataset
-from climax.core.loss import BernoulliGammaLoss
+from climax.core.loss import BernoulliGammaLoss, BernoulliGenParetoLoss
 from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
                                 CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
                                 LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
@@ -121,7 +121,8 @@ if __name__ == '__main__':
         # define number of output fields
         # check whether modelling pr with probabilistic approach
         outputs = len(Obs_ds.data_vars)
-        if PREDICTAND == 'pr' and isinstance(LOSS, BernoulliGammaLoss):
+        if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or
+                                   isinstance(LOSS, BernoulliGenParetoLoss)):
             outputs = 3
 
         # instanciate network
-- 
GitLab