From 432286c127ea987f7967fd9ab49f6c60a3b32d73 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 6 Oct 2021 14:54:25 +0200
Subject: [PATCH] Instanciate correct number of output fields.

---
 climax/main/downscale_train.py        | 8 ++++----
 climax/main/downscale_train_season.py | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index f5de32b..937776e 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -22,7 +22,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig
 from pysegcnn.core.models import Network
 from pysegcnn.core.logging import log_conf
 from climax.core.dataset import ERA5Dataset, NetCDFDataset
-from climax.core.loss import BernoulliGammaLoss, BernoulliGenParetoLoss
+from climax.core.loss import MSELoss, L1Loss
 from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
                                 CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
                                 LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
@@ -110,9 +110,9 @@ if __name__ == '__main__':
     # define number of output fields
     # check whether modelling pr with probabilistic approach
     outputs = len(Obs_ds.data_vars)
-    if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or
-                               isinstance(LOSS, BernoulliGenParetoLoss)):
-        outputs = 3
+    if PREDICTAND == 'pr':
+        outputs = (1 if (isinstance(LOSS, MSELoss) or isinstance(LOSS, L1Loss))
+                   else 3)
 
     # instanciate network
     inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
diff --git a/climax/main/downscale_train_season.py b/climax/main/downscale_train_season.py
index ccac375..19fc89b 100644
--- a/climax/main/downscale_train_season.py
+++ b/climax/main/downscale_train_season.py
@@ -22,7 +22,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig
 from pysegcnn.core.models import Network
 from pysegcnn.core.logging import log_conf
 from climax.core.dataset import ERA5Dataset, NetCDFDataset
-from climax.core.loss import BernoulliGammaLoss, BernoulliGenParetoLoss
+from climax.core.loss import MSELoss, L1Loss
 from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
                                 CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
                                 LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
@@ -161,9 +161,9 @@ if __name__ == '__main__':
         # define number of output fields
         # check whether modelling pr with probabilistic approach
         outputs = len(Obs_train.data_vars)
-        if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or
-                                   isinstance(LOSS, BernoulliGenParetoLoss)):
-            outputs = 3
+        if PREDICTAND == 'pr':
+            outputs = (1 if (isinstance(LOSS, MSELoss) or
+                       isinstance(LOSS, L1Loss)) else 3)
 
         # instanciate network
         inputs = (len(Era5_train.data_vars) + 2 if DOY else
-- 
GitLab