diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py index b985de3a0c0df2dec1b8d91d92f68d73e41b9230..2b49ab3a58d537b9957b060bc509ef8cc5baf31d 100644 --- a/climax/main/downscale_train.py +++ b/climax/main/downscale_train.py @@ -21,7 +21,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig from pysegcnn.core.models import Network from pysegcnn.core.logging import log_conf from climax.core.dataset import ERA5Dataset, NetCDFDataset -from climax.core.loss import BernoulliGammaLoss +from climax.core.loss import BernoulliGammaLoss, BernoulliGenParetoLoss from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND, CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR, LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS, @@ -121,7 +121,8 @@ if __name__ == '__main__': # define number of output fields # check whether modelling pr with probabilistic approach outputs = len(Obs_ds.data_vars) - if PREDICTAND == 'pr' and isinstance(LOSS, BernoulliGammaLoss): + if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or + isinstance(LOSS, BernoulliGenParetoLoss)): outputs = 3 # instanciate network