From 432286c127ea987f7967fd9ab49f6c60a3b32d73 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Wed, 6 Oct 2021 14:54:25 +0200 Subject: [PATCH] Instanciate correct number of output fields. --- climax/main/downscale_train.py | 8 ++++---- climax/main/downscale_train_season.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py index f5de32b..937776e 100644 --- a/climax/main/downscale_train.py +++ b/climax/main/downscale_train.py @@ -22,7 +22,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig from pysegcnn.core.models import Network from pysegcnn.core.logging import log_conf from climax.core.dataset import ERA5Dataset, NetCDFDataset -from climax.core.loss import BernoulliGammaLoss, BernoulliGenParetoLoss +from climax.core.loss import MSELoss, L1Loss from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND, CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR, LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS, @@ -110,9 +110,9 @@ if __name__ == '__main__': # define number of output fields # check whether modelling pr with probabilistic approach outputs = len(Obs_ds.data_vars) - if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or - isinstance(LOSS, BernoulliGenParetoLoss)): - outputs = 3 + if PREDICTAND == 'pr': + outputs = (1 if (isinstance(LOSS, MSELoss) or isinstance(LOSS, L1Loss)) + else 3) # instanciate network inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars) diff --git a/climax/main/downscale_train_season.py b/climax/main/downscale_train_season.py index ccac375..19fc89b 100644 --- a/climax/main/downscale_train_season.py +++ b/climax/main/downscale_train_season.py @@ -22,7 +22,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig from pysegcnn.core.models import Network from pysegcnn.core.logging import log_conf from climax.core.dataset import ERA5Dataset, NetCDFDataset -from climax.core.loss import BernoulliGammaLoss, BernoulliGenParetoLoss +from climax.core.loss import MSELoss, L1Loss from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND, CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR, LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS, @@ -161,9 +161,9 @@ if __name__ == '__main__': # define number of output fields # check whether modelling pr with probabilistic approach outputs = len(Obs_train.data_vars) - if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or - isinstance(LOSS, BernoulliGenParetoLoss)): - outputs = 3 + if PREDICTAND == 'pr': + outputs = (1 if (isinstance(LOSS, MSELoss) or + isinstance(LOSS, L1Loss)) else 3) # instanciate network inputs = (len(Era5_train.data_vars) + 2 if DOY else -- GitLab