Skip to content
Snippets Groups Projects
Commit 432286c1 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Instanciate correct number of output fields.

parent 3aaf11cd
No related branches found
No related tags found
No related merge requests found
...@@ -22,7 +22,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig ...@@ -22,7 +22,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig
from pysegcnn.core.models import Network from pysegcnn.core.models import Network
from pysegcnn.core.logging import log_conf from pysegcnn.core.logging import log_conf
from climax.core.dataset import ERA5Dataset, NetCDFDataset from climax.core.dataset import ERA5Dataset, NetCDFDataset
from climax.core.loss import BernoulliGammaLoss, BernoulliGenParetoLoss from climax.core.loss import MSELoss, L1Loss
from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND, from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR, CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS, LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
...@@ -110,9 +110,9 @@ if __name__ == '__main__': ...@@ -110,9 +110,9 @@ if __name__ == '__main__':
# define number of output fields # define number of output fields
# check whether modelling pr with probabilistic approach # check whether modelling pr with probabilistic approach
outputs = len(Obs_ds.data_vars) outputs = len(Obs_ds.data_vars)
if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or if PREDICTAND == 'pr':
isinstance(LOSS, BernoulliGenParetoLoss)): outputs = (1 if (isinstance(LOSS, MSELoss) or isinstance(LOSS, L1Loss))
outputs = 3 else 3)
# instanciate network # instanciate network
inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars) inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)
......
...@@ -22,7 +22,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig ...@@ -22,7 +22,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig
from pysegcnn.core.models import Network from pysegcnn.core.models import Network
from pysegcnn.core.logging import log_conf from pysegcnn.core.logging import log_conf
from climax.core.dataset import ERA5Dataset, NetCDFDataset from climax.core.dataset import ERA5Dataset, NetCDFDataset
from climax.core.loss import BernoulliGammaLoss, BernoulliGenParetoLoss from climax.core.loss import MSELoss, L1Loss
from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND, from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR, CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS, LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
...@@ -161,9 +161,9 @@ if __name__ == '__main__': ...@@ -161,9 +161,9 @@ if __name__ == '__main__':
# define number of output fields # define number of output fields
# check whether modelling pr with probabilistic approach # check whether modelling pr with probabilistic approach
outputs = len(Obs_train.data_vars) outputs = len(Obs_train.data_vars)
if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or if PREDICTAND == 'pr':
isinstance(LOSS, BernoulliGenParetoLoss)): outputs = (1 if (isinstance(LOSS, MSELoss) or
outputs = 3 isinstance(LOSS, L1Loss)) else 3)
# instanciate network # instanciate network
inputs = (len(Era5_train.data_vars) + 2 if DOY else inputs = (len(Era5_train.data_vars) + 2 if DOY else
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment