## Generalized linear regression

In [None]:
# builtins
import pathlib

# externals
import numpy as np
import pandas as pd
import xarray as xr
from sklearn.linear_model import TweedieRegressor

# locals
from climax.core.dataset import ERA5Dataset
from climax.core.constants import ERA5_VARIABLES
from climax.core.utils import search_files

In [None]:
# root directory of the ERA5 reanalysis data (passed to ERA5Dataset)
ERA5_PATH = pathlib.Path("/mnt/CEPH_PROJECTS/FACT_CLIMAX/REANALYSIS/ERA5")

In [None]:
# display the valid predictor variable names (imported from climax.core.constants)
ERA5_VARIABLES

In [None]:
# define the predictor variables you want to use
ERA5_PREDICTORS = ['geopotential', 'temperature', 'mean_sea_level_pressure'] # use geopotential, temperature and pressure

# you can change this list as you wish, e.g.:
# ERA5_PREDICTORS = ['geopotential', 'temperature'] # use only geopotential and temperature
# ERA5_PREDICTORS = ERA5_VARIABLES # use all ERA5 variables as predictors

# this checks if the variable names are correct; unlike an assert, this check is
# not stripped when running with `python -O`, and the error names the bad entries
invalid_predictors = [p for p in ERA5_PREDICTORS if p not in ERA5_VARIABLES]
if invalid_predictors:
    raise ValueError('Invalid ERA5 predictor(s): {}. Valid names are: {}'.format(
        invalid_predictors, ERA5_VARIABLES))

### Use the climax package to load ERA5 predictors

In [None]:
# pressure levels passed to ERA5Dataset; currently only 500 and 850 are available
PLEVELS = [
    500,
    850,
]

In [None]:
# load the requested predictor variables at the chosen pressure levels and
# merge them into a single xarray.Dataset
predictors = ERA5Dataset(ERA5_PATH, ERA5_PREDICTORS, plevels=PLEVELS).merge()

In [None]:
# display the merged predictor xarray.Dataset: you should see all the variables you specified
predictors

### Load target data: observations

In [None]:
# root directory of the observation data; the predictand name is used as the
# subdirectory when the observations are loaded
OBS_PATH = pathlib.Path("/mnt/CEPH_PROJECTS/FACT_CLIMAX/OBSERVATION")

In [None]:
# define the predictand, i.e. tasmax, tasmin or pr
PREDICTAND = 'tasmax'

# this checks if the predictand name is correct: the model selection only tests
# for 'tasmax'/'tasmin', so any other (e.g. misspelled) name would otherwise
# silently be treated like 'pr'
if PREDICTAND not in ('tasmax', 'tasmin', 'pr'):
    raise ValueError("Invalid predictand: {!r}. Valid names are: 'tasmax', 'tasmin', 'pr'".format(PREDICTAND))

In [None]:
# load the observation data: exactly one NetCDF file is expected in the
# predictand's subdirectory; fail loudly instead of raising an opaque
# IndexError (no file) or silently picking an arbitrary file (several files)
nc_files = search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$')
if len(nc_files) != 1:
    raise ValueError('Expected exactly one NetCDF file in {}, found {}: {}'.format(
        OBS_PATH.joinpath(PREDICTAND), len(nc_files), nc_files))
predictand = xr.open_dataset(nc_files.pop())

In [None]:
# display the observation xarray.Dataset: you should see a single variable, the predictand
predictand

### Prepare training and validation data

In [None]:
# define the training period and the validation period
# NOTE: label-based slicing with xarray's .sel() includes BOTH endpoints, so the
# training period must end the day before the validation period starts;
# otherwise 1991-01-01 would be used for training and validation alike
TRAIN_PERIOD = slice('1981-01-01', '1990-12-31')
VALID_PERIOD = slice('1991-01-01', '2010-01-01')

In [None]:
# split the predictors into the training and the validation period
predictors_train, predictors_valid = (predictors.sel(time=period) for period in (TRAIN_PERIOD, VALID_PERIOD))

In [None]:
# split the observed predictand into the training and the validation period
predictand_train, predictand_valid = [predictand.sel(time=period) for period in (TRAIN_PERIOD, VALID_PERIOD)]

### Train the generalized linear regression model

In [None]:
# instantiate the GLM and display its configuration; the Tweedie power selects
# the distribution assumed for the predictand:
#   power = 0:      Normal distribution (tasmax, tasmin)
#   power = 1:      Poisson distribution
#   power = (1, 2): Compound Poisson Gamma distribution
#   power = 2:      Gamma distribution (pr)
#   power = 3:      Inverse Gaussian
# NOTE(review): a Gamma model (power=2) only accepts strictly positive targets,
# so dry days (pr == 0) would make fitting fail; confirm how zeros are handled.
model = TweedieRegressor(power=2 if PREDICTAND not in ['tasmax', 'tasmin'] else 0)
model

In [None]:
# function to normalize predictors to [0, 1]
def normalize(predictors, axis=0):
    """Min-max scale ``predictors`` to [0, 1] along ``axis``.

    Parameters
    ----------
    predictors : array_like
        Predictor matrix. With the default ``axis=0`` and a layout of shape
        (time, predictors), every predictor variable (column) is scaled
        independently over time.
    axis : int, optional
        Axis along which the minimum and maximum are computed. Default: 0.

    Returns
    -------
    numpy.ndarray
        A new float array of the same shape; the input is not modified.
        Slices with a constant value (zero range) are mapped to 0 instead of
        producing NaN from a division by zero.
    """
    # work on a float copy: the caller's array is never changed in place
    predictors = np.asarray(predictors, dtype=float)
    vmin = predictors.min(axis=axis, keepdims=True)
    vrange = predictors.max(axis=axis, keepdims=True) - vmin
    # a constant slice has zero range: divide by 1 instead, which maps it to 0
    vrange[vrange == 0] = 1
    return (predictors - vmin) / vrange

In [None]:
# iterate over the grid points: fit one GLM per grid point on the training
# period and predict the validation period

# the Tweedie power only depends on the predictand: compute it once
# NOTE(review): power=2 (Gamma) requires strictly positive targets; for 'pr',
# dry days (pr == 0) make TweedieRegressor.fit raise a ValueError. A power in
# (1, 2) (compound Poisson-Gamma) accepts zeros - confirm the intended model.
power = 0 if PREDICTAND in ['tasmax', 'tasmin'] else 2

# predictions of the validation period, NaN wherever no prediction is possible
prediction = np.full(
    (predictors_valid.sizes['time'], predictors_valid.sizes['y'], predictors_valid.sizes['x']),
    np.nan)

for i in range(predictors_train.sizes['x']):
    for j in range(predictors_train.sizes['y']):

        # current grid point: xarray.Dataset, dimensions=(time)
        point_predictors = predictors_train.isel(x=i, y=j)
        point_predictand = predictand_train.isel(x=i, y=j)

        # convert xarray.Dataset to numpy.array: shape=(time, predictors)
        point_predictors = point_predictors.to_array().values.swapaxes(0, 1)
        point_predictand = point_predictand.to_array().values.squeeze()

        # check if the grid point is valid, otherwise move on to the next one
        if np.isnan(point_predictors).any() or np.isnan(point_predictand).any():
            continue

        # min-max scaling parameters of each predictor variable (column),
        # computed on the training period only
        pmin = point_predictors.min(axis=0, keepdims=True)
        prange = point_predictors.max(axis=0, keepdims=True) - pmin
        prange[prange == 0] = 1  # constant predictor: avoid 0/0

        # normalize each predictor variable to [0, 1]
        point_predictors = (point_predictors - pmin) / prange

        # instantiate and train a fresh model for the current grid point
        model = TweedieRegressor(power=power)
        model.fit(point_predictors, point_predictand)
        print('Processing grid point: ({:d}, {:d}), score: {:.2f}'.format(j, i, model.score(point_predictors, point_predictand)))

        # prepare predictors of the validation period: scale them with the
        # TRAINING statistics so they are on the scale the model was fitted on
        point_validation = predictors_valid.isel(x=i, y=j).to_array().values.swapaxes(0, 1)
        point_validation = (point_validation - pmin) / prange

        # predict only time steps with complete predictors; others stay NaN
        complete = ~np.isnan(point_validation).any(axis=1)
        if complete.any():
            prediction[complete, j, i] = model.predict(point_validation[complete])

# store predictions in xarray.Dataset; the time axis is taken from the data that
# sized `prediction`, so its length always matches
predictions = xr.DataArray(
    data=prediction,
    dims=['time', 'y', 'x'],
    coords={
        'time': predictors_valid.time.values,
        'y': predictand_valid.y.values,
        'x': predictand_valid.x.values,
        # lat/lon hold the same values as y/x; kept so the output keeps the
        # coordinate names it had before
        'lat': ('y', predictand_valid.y.values),
        'lon': ('x', predictand_valid.x.values),
    },
)
predictions = predictions.to_dataset(name=PREDICTAND)

### Save predictions as NetCDF

In [None]:
# specify the output path, filename: PREDICTAND.nc in the user's home directory;
# pathlib does not expand '~' by itself, hence the explicit expanduser()
OUTPUT_PATH = pathlib.Path('~').expanduser() / '{}.nc'.format(PREDICTAND)

# save to NetCDF
predictions.to_netcdf(OUTPUT_PATH, engine='h5netcdf')

In [None]:
# Enjoy and have fun!