Skip to content
Snippets Groups Projects
Commit 335e78fa authored by Frisinghelli Daniel
Browse files

Implemented preprocessing of ERA5 predictor dataset.

parent 7584c276
Branches
No related tags found
No related merge requests found
......@@ -40,3 +40,6 @@ ERA5_S_VARIABLES = ['mean_sea_level_pressure']
# ERA5 variables
ERA5_VARIABLES = ERA5_P_VARIABLES + ERA5_S_VARIABLES
# name of target projection
PROJECTION = 'lambert_azimuthal_equal_area'
......@@ -6,6 +6,11 @@
# externals
import torch
import numpy as np
import xarray as xr
# locals
from climax.core.constants import ERA5_VARIABLES, PROJECTION
from pysegcnn.core.utils import search_files
class EoDataset(torch.utils.data.Dataset):
......@@ -31,7 +36,7 @@ class EoDataset(torch.utils.data.Dataset):
return torch.tensor(np.asarray(x).copy(), dtype=dtype)
def NetCDFDataset(EoDataset):
class NetCDFDataset(EoDataset):
def __init__(self, X, y, dim='time'):
......@@ -48,5 +53,52 @@ def NetCDFDataset(EoDataset):
return len(self.X[self.dim])
def __getitem__(self, idx):
return (self.to_tensor(self.X.isel(time=idx), dtype=torch.float32),
self.to_tensor(self.y.isel(time=idx), dtype=torch.float32))
return (self.to_tensor(self.X.isel({self.dim: idx}), torch.float32),
self.to_tensor(self.y.isel({self.dim: idx}), torch.float32))
@staticmethod
def clip_period(ds, period):
    """Restrict a dataset to a period along its ``time`` coordinate.

    Parameters
    ----------
    ds :
        Object exposing ``sel`` with a ``time`` coordinate (presumably an
        xarray dataset -- confirm against the callers).
    period :
        Indexable whose first two items are the start and the end of the
        period to keep.

    Returns
    -------
    The result of ``ds.sel`` for the slice from start to end.

    """
    start, end = period[0], period[1]
    return ds.sel(time=slice(start, end))
class ERA5Dataset(EoDataset):
    """Assemble ERA5 predictor variables from per-variable NetCDF files.

    Parameters
    ----------
    root_dir :
        Directory searched for ERA5 files. ``merge`` calls
        ``root_dir.joinpath(var)``, so a ``pathlib.Path``-like object is
        expected, with one sub-directory per variable name.
    variables :
        Iterable of variable names. Names not listed in ``ERA5_VARIABLES``
        are silently discarded.
    pressure_levels :
        Pressure levels to extract from datasets that have a ``level``
        dimension. ``None`` (the default) selects every level found in the
        respective dataset.

    """

    def __init__(self, root_dir, variables, pressure_levels=None):
        # root directory: search for ERA5 files
        self.root_dir = root_dir

        # valid variable names
        self.variables = [var for var in variables if var in ERA5_VARIABLES]

        # pressure levels
        self.pressure_levels = pressure_levels

    def merge(self, **kwargs):
        """Merge the selected variables and levels into a single dataset.

        Parameters
        ----------
        **kwargs :
            Passed on to ``xarray.open_dataset``.

        Returns
        -------
        The ``xarray.merge`` of all predictors. Variables taken from a
        dataset with a ``level`` dimension are renamed to
        ``<variable>_<level>``; datasets without that dimension are added
        under their original variable names.

        """
        # search dataset for each variable in root directory
        # NOTE(review): ``pop`` raises IndexError if no file matches.
        datasets = [search_files(self.root_dir.joinpath(var), '.nc$').pop()
                    for var in self.variables]

        # iterate over input datasets and select pressure levels
        predictors = []
        for path in datasets:
            # read dataset
            ds = xr.open_dataset(path, **kwargs)

            # drop the projection variable, it is not a predictor
            if PROJECTION in ds.data_vars:
                ds = ds.drop_vars(PROJECTION)

            # datasets without a level dimension (presumably the
            # single-level variables) cannot be indexed by level: use as is
            if 'level' not in ds.dims:
                predictors.append(ds)
                continue

            # pressure levels to use: all available ones if none were given
            if self.pressure_levels is None:
                levels = ds['level'].values.tolist()
            else:
                levels = self.pressure_levels

            # iterate over pressure levels to use
            for pl in levels:
                # select the level from the full dataset: ``ds`` must not be
                # rebound here, otherwise the next level would be selected
                # from a dataset that no longer has a level coordinate
                subset = ds.sel(level=pl).drop_vars('level')

                # rename variable including corresponding pressure level
                subset = subset.rename({k: '_'.join([k, str(pl)])
                                        for k in subset.data_vars})

                # append to list of predictors
                predictors.append(subset)

        # merge final predictor dataset
        return xr.merge(predictors)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment