Skip to content
Snippets Groups Projects
Commit 7584c276 authored by Frisinghelli Daniel's avatar Frisinghelli Daniel
Browse files

Vanilla implemenation of a PyTorch compliant NetCDF dataset.

parent efa18e94
No related branches found
No related tags found
No related merge requests found
"""Dataset classes compliant to the Pytorch standard."""
# !/usr/bin/env python
# -*- coding: utf-8 -*-
# externals
import torch
import numpy as np
class EoDataset(torch.utils.data.Dataset):
@staticmethod
def to_tensor(x, dtype):
"""Convert ``x`` to :py:class:`torch.Tensor`.
Parameters
----------
x : array_like
The input data.
dtype : :py:class:`torch.dtype`
The data type used to convert ``x``.
The modified class labels.
Returns
-------
x : `torch.Tensor`
The input data tensor.
"""
return torch.tensor(np.asarray(x).copy(), dtype=dtype)
def NetCDFDataset(EoDataset):
def __init__(self, X, y, dim='time'):
# NetCDF dataset containing predictor variables (ERA5)
self.X = X
# NetCDF dataset containing target variable (OBS)
self.y = y
# dimension along which to sample
self.dim = dim
def __len__(self):
return len(self.X[self.dim])
def __getitem__(self, idx):
return (self.to_tensor(self.X.isel(time=idx), dtype=torch.float32),
self.to_tensor(self.y.isel(time=idx), dtype=torch.float32))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment