Skip to content
Snippets Groups Projects
Commit 79c42172 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Implemented BernoulliGenPareto Loss.

parent d459b2cf
No related branches found
No related tags found
No related merge requests found
...@@ -21,7 +21,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig ...@@ -21,7 +21,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig
from pysegcnn.core.models import Network from pysegcnn.core.models import Network
from pysegcnn.core.logging import log_conf from pysegcnn.core.logging import log_conf
from climax.core.dataset import ERA5Dataset, NetCDFDataset from climax.core.dataset import ERA5Dataset, NetCDFDataset
from climax.core.loss import BernoulliGammaLoss from climax.core.loss import BernoulliGammaLoss, BernoulliGenParetoLoss
from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND, from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR, CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS, LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
...@@ -121,7 +121,8 @@ if __name__ == '__main__': ...@@ -121,7 +121,8 @@ if __name__ == '__main__':
# define number of output fields # define number of output fields
# check whether modelling pr with probabilistic approach # check whether modelling pr with probabilistic approach
outputs = len(Obs_ds.data_vars) outputs = len(Obs_ds.data_vars)
if PREDICTAND == 'pr' and isinstance(LOSS, BernoulliGammaLoss): if PREDICTAND == 'pr' and (isinstance(LOSS, BernoulliGammaLoss) or
isinstance(LOSS, BernoulliGenParetoLoss)):
outputs = 3 outputs = 3
# instanciate network # instanciate network
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment