diff --git a/climax/core/dataset.py b/climax/core/dataset.py index 35fa8d8183790d25aa39b64ab25c84aeda8c261c..1d393dff4c9d9ba6dcae990a8dbf4a232501d4a6 100644 --- a/climax/core/dataset.py +++ b/climax/core/dataset.py @@ -19,7 +19,7 @@ from pysegcnn.core.utils import search_files class EoDataset(torch.utils.data.Dataset): @staticmethod - def to_tensor(x, dtype): + def to_tensor(x, dtype=torch.float32): """Convert ``x`` to :py:class:`torch.Tensor`. Parameters @@ -43,8 +43,6 @@ class NetCDFDataset(EoDataset): def __init__(self, X, y, dim='time'): - # TODO: check if conversion to array is more efficient - # NetCDF dataset containing predictor variables (ERA5) self.X = X @@ -58,12 +56,8 @@ class NetCDFDataset(EoDataset): return len(self.X[self.dim]) def __getitem__(self, idx): - return (self.to_tensor(self.X.isel({self.dim: idx}), torch.float32), - self.to_tensor(self.y.isel({self.dim: idx}), torch.float32)) - - @staticmethod - def clip_period(ds, period): - return ds.sel(time=slice(period[0], period[1])) + return (self.to_tensor(self.X.isel({self.dim: idx}).to_array().values), + self.to_tensor(self.y.isel({self.dim: idx}).to_array().values)) class ERA5Dataset(EoDataset):