From c17e9f3cb7388e34912555cbe1a474ade44e4168 Mon Sep 17 00:00:00 2001 From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu> Date: Mon, 27 Sep 2021 15:54:15 +0200 Subject: [PATCH] Implemented training pr with MSE loss. --- climax/core/predict.py | 4 ++-- climax/main/downscale_train.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/climax/core/predict.py b/climax/core/predict.py index d5686ce..42cc74d 100644 --- a/climax/core/predict.py +++ b/climax/core/predict.py @@ -57,11 +57,11 @@ def predict_ERA5(net, ERA5_ds, predictand, batch_size=16, **kwargs): LOGGER.info('Mini-batch: {:d}/{:d}'.format(batch + 1, len(dl))) # convert numpy array to xarray.Dataset - if predictand == 'tas': + if predictand == 'tas' and net.classifier.out_channels == 2: # in case of tas, the netwokr predicts both tasmax and tasmin ds = {'tasmax': target[:, 0, ...].squeeze(), 'tasmin': target[:, 1, ...].squeeze()} - elif predictand == 'pr': + elif predictand == 'pr' and net.classifier.out_channels == 3: # probability of precipitation prob = torch.sigmoid(torch.as_tensor(target[:, 0, ...].squeeze(), diff --git a/climax/main/downscale_train.py b/climax/main/downscale_train.py index 16c4b1c..69f5b6f 100644 --- a/climax/main/downscale_train.py +++ b/climax/main/downscale_train.py @@ -21,6 +21,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig from pysegcnn.core.models import Network from pysegcnn.core.logging import log_conf from climax.core.dataset import ERA5Dataset, NetCDFDataset +from climax.core.loss import BernoulliGammaLoss from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND, CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR, LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS, @@ -128,8 +129,13 @@ if __name__ == '__main__': valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE, drop_last=False) - # instanciate network from scratch - outputs = 3 if PREDICTAND == 'pr' else len(Obs_ds.data_vars) + # define number of output fields + # check whether modelling pr with probabilistic approach + outputs = (Obs_ds.data_vars) + if PREDICTAND == 'pr' and isinstance(LOSS, BernoulliGammaLoss): + outputs = 3 + + # instanciate network net = NET(state_file, train_ds.X.shape[1], outputs, filters=FILTERS) # initialize optimizer -- GitLab