Skip to content
Snippets Groups Projects
Commit c17e9f3c authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Implemented training pr with MSE loss.

parent 5d1420a5
No related branches found
No related tags found
No related merge requests found
......@@ -57,11 +57,11 @@ def predict_ERA5(net, ERA5_ds, predictand, batch_size=16, **kwargs):
LOGGER.info('Mini-batch: {:d}/{:d}'.format(batch + 1, len(dl)))
# convert numpy array to xarray.Dataset
if predictand == 'tas':
if predictand == 'tas' and net.classifier.out_channels == 2:
# in case of tas, the netwokr predicts both tasmax and tasmin
ds = {'tasmax': target[:, 0, ...].squeeze(),
'tasmin': target[:, 1, ...].squeeze()}
elif predictand == 'pr':
elif predictand == 'pr' and net.classifier.out_channels == 3:
# probability of precipitation
prob = torch.sigmoid(torch.as_tensor(target[:, 0, ...].squeeze(),
......
......@@ -21,6 +21,7 @@ from pysegcnn.core.trainer import NetworkTrainer, LogConfig
from pysegcnn.core.models import Network
from pysegcnn.core.logging import log_conf
from climax.core.dataset import ERA5Dataset, NetCDFDataset
from climax.core.loss import BernoulliGammaLoss
from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,
CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,
LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,
......@@ -128,8 +129,13 @@ if __name__ == '__main__':
valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
drop_last=False)
# instanciate network from scratch
outputs = 3 if PREDICTAND == 'pr' else len(Obs_ds.data_vars)
# define number of output fields
# check whether modelling pr with probabilistic approach
outputs = (Obs_ds.data_vars)
if PREDICTAND == 'pr' and isinstance(LOSS, BernoulliGammaLoss):
outputs = 3
# instanciate network
net = NET(state_file, train_ds.X.shape[1], outputs, filters=FILTERS)
# initialize optimizer
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment