diff --git a/climax/core/predict.py b/climax/core/predict.py index d5686ce25708e01d005294b9fed691513040b646..42cc74d447087f2b7c370cabaf68d1d77967bfd0 100644 --- a/climax/core/predict.py +++ b/climax/core/predict.py @@ -57,11 +57,11 @@ def predict_ERA5(net, ERA5_ds, predictand, batch_size=16, **kwargs): LOGGER.info('Mini-batch: {:d}/{:d}'.format(batch + 1, len(dl))) # convert numpy array to xarray.Dataset - if predictand == 'tas': + if predictand == 'tas' and net.classifier.out_channels == 2: # in case of tas, the netwokr predicts both tasmax and tasmin ds = {'tasmax': target[:, 0, ...].squeeze(), 'tasmin': target[:, 1, ...].squeeze()} - elif predictand == 'pr': + elif predictand == 'pr' and net.classifier.out_channels == 3: # probability of precipitation prob = torch.sigmoid(torch.as_tensor(target[:, 0, ...].squeeze(), diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py index 16c4b1c23720732db7d51e57d4054514b76caa02..69f5b6f93bc354392be3ad346cb31da58cf8ed4e 100644 --- a/climax/main/downscale_train.py +++ b/climax/main/downscale_train.py @@ -21,6 +21,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig from pysegcnn.core.models import Network from pysegcnn.core.logging import log_conf from climax.core.dataset import ERA5Dataset, NetCDFDataset +from climax.core.loss import BernoulliGammaLoss from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND, CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR, LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS, @@ -128,8 +129,13 @@ if __name__ == '__main__': valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE, drop_last=False) - # instanciate network from scratch - outputs = 3 if PREDICTAND == 'pr' else len(Obs_ds.data_vars) + # define number of output fields + # check whether modelling pr with probabilistic approach + outputs = (Obs_ds.data_vars) + if PREDICTAND == 'pr' and isinstance(LOSS, BernoulliGammaLoss): + outputs = 3 + + # instanciate network net = NET(state_file, train_ds.X.shape[1], outputs, filters=FILTERS) # initialize optimizer