diff --git a/climax/core/predict.py b/climax/core/predict.py
index d5686ce25708e01d005294b9fed691513040b646..42cc74d447087f2b7c370cabaf68d1d77967bfd0 100644
--- a/climax/core/predict.py
+++ b/climax/core/predict.py
@@ -57,11 +57,11 @@ def predict_ERA5(net, ERA5_ds, predictand, batch_size=16, **kwargs):
         LOGGER.info('Mini-batch: {:d}/{:d}'.format(batch + 1, len(dl)))
 
     # convert numpy array to xarray.Dataset
-    if predictand == 'tas':
+    if predictand == 'tas' and net.classifier.out_channels == 2:
         # in case of tas, the netwokr predicts both tasmax and tasmin
         ds = {'tasmax': target[:, 0, ...].squeeze(),
               'tasmin': target[:, 1, ...].squeeze()}
-    elif predictand == 'pr':
+    elif predictand == 'pr' and net.classifier.out_channels == 3:
 
         # probability of precipitation
         prob = torch.sigmoid(torch.as_tensor(target[:, 0, ...].squeeze(),
diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py
index 16c4b1c23720732db7d51e57d4054514b76caa02..69f5b6f93bc354392be3ad346cb31da58cf8ed4e 100644
--- a/climax/main/downscale_train.py
+++ b/climax/main/downscale_train.py
@@ -21,6 +21,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig
 from pysegcnn.core.models import Network
 from pysegcnn.core.logging import log_conf
 from climax.core.dataset import ERA5Dataset, NetCDFDataset
+from climax.core.loss import BernoulliGammaLoss
 from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
                                 CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
                                 LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
@@ -128,8 +129,13 @@ if __name__ == '__main__':
         valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
                               drop_last=False)
 
-        # instanciate network from scratch
-        outputs = 3 if PREDICTAND == 'pr' else len(Obs_ds.data_vars)
+        # define number of output fields
+        # check whether modelling pr with probabilistic approach
+        outputs = (Obs_ds.data_vars)
+        if PREDICTAND == 'pr' and isinstance(LOSS, BernoulliGammaLoss):
+            outputs = 3
+
+        # instanciate network
         net = NET(state_file, train_ds.X.shape[1], outputs, filters=FILTERS)
 
         	# initialize optimizer