diff --git a/climax/core/dataset.py b/climax/core/dataset.py index d8435cbe5ea02060b32ad50f50a22e4a5d76b5b4..4e4ad63211d64414e735613b3a7512ecb2c81823 100644 --- a/climax/core/dataset.py +++ b/climax/core/dataset.py @@ -78,9 +78,10 @@ class EoDataset(torch.utils.data.Dataset): return da.repeat(da.array(array), repeats, axis) @staticmethod - def encode_doys(ds, **kwargs): + def encode_doys(ds, dims=('time', 'y', 'x'), chunks=None): # compute day of the year + LOGGER.info('Encoding day of the year to cyclical feature ...') doys = ds.time.values.astype('datetime64[D]') doys = da.asarray( [date.timetuple(doy.astype(date)).tm_yday for doy in doys]) @@ -98,8 +99,15 @@ class EoDataset(torch.utils.data.Dataset): EoDataset.repeat_along_axis(sin_doy, repeat, 0).reshape(target), EoDataset.repeat_along_axis(cos_doy, repeat, 0).reshape(target)) - return {'sin_doy': EoDataset.add_coordinates(sin_doy, **kwargs), - 'cos_doy': EoDataset.add_coordinates(cos_doy, **kwargs)} + # chunk data for parallel loading + if chunks is not None: + sin_doy = sin_doy.rechunk( + {dims.index(k): v for k, v in chunks.items()}) + cos_doy = cos_doy.rechunk( + {dims.index(k): v for k, v in chunks.items()}) + + return {'sin_doy': EoDataset.add_coordinates(sin_doy, dims), + 'cos_doy': EoDataset.add_coordinates(cos_doy, dims)} @staticmethod def state_file(model, predictand, predictors, plevels, dem=False, @@ -220,7 +228,7 @@ class NetCDFDataset(EoDataset): if doy: # add doy to set of predictor variables LOGGER.info('Adding day of the year to predictor variables ...') - X = X.assign(self.encode_doys(X)) + X = X.assign(self.encode_doys(X, chunks=X.chunks)) # NetCDF dataset containing predictor variables (ERA5) # shape: (t, vars, y, x)