{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0ade2427-0020-414a-957b-4ed71281b61f",
   "metadata": {},
   "source": [
    "## Generalized linear regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfd6644a-9af1-4b71-8dac-1e423786e0a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# builtins\n",
    "import pathlib\n",
    "\n",
    "# externals\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import xarray as xr\n",
    "from sklearn.linear_model import TweedieRegressor\n",
    "\n",
    "# locals\n",
    "from climax.core.dataset import ERA5Dataset\n",
    "from climax.core.constants import ERA5_VARIABLES\n",
    "from climax.core.utils import search_files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "646406dd-8c0d-4cdc-b7c3-b0b55b2ec4bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# path to ERA5 reanalysis data\n",
    "ERA5_PATH = pathlib.Path('/mnt/CEPH_PROJECTS/FACT_CLIMAX/REANALYSIS/ERA5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18d222d5-fe39-4858-af39-04a14171cad6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# list of valid predictor variable names\n",
    "ERA5_VARIABLES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cdca239-c1c9-41c6-bce5-7929ff3b19de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the predictor variables you want to use\n",
    "ERA5_PREDICTORS = ['geopotential', 'temperature', 'mean_sea_level_pressure']  # use geopotential, temperature and pressure\n",
    "\n",
    "# you can change this list as you wish, e.g.:\n",
    "# ERA5_PREDICTORS = ['geopotential', 'temperature']  # use only geopotential and temperature\n",
    "# ERA5_PREDICTORS = ERA5_VARIABLES  # use all ERA5 variables as predictors\n",
    "\n",
    "# this checks if the variable names are correct\n",
    "assert all([p in ERA5_VARIABLES for p in ERA5_PREDICTORS]) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fd49c6d-a7dc-49cb-87ba-c1f539ff2127",
   "metadata": {},
   "source": [
    "### Use the climax package to load ERA5 predictors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72e4c65c-add4-47a8-aae0-e4ae3301301a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define which pressure levels you want to use: currently only 500 and 850 are available\n",
    "PLEVELS = [500, 850]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35979cb6-0a18-41bd-9eea-081b187be5ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create the xarray.Dataset of the specified predictor variables\n",
    "predictors = ERA5Dataset(ERA5_PATH, ERA5_PREDICTORS, plevels=PLEVELS)\n",
    "predictors = predictors.merge()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "155c0954-cb8a-40d1-99e9-6c46281ae899",
   "metadata": {},
   "outputs": [],
   "source": [
    "# check out the xarray.Dataset: you will see all the variables you specified\n",
    "predictors"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "727fc0f2-c8ef-480d-a273-a07d3dff0141",
   "metadata": {},
   "source": [
    "### Load target data: observations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "629d12b0-e4d1-4cc8-8c02-bd62bbb67cdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# path to observation data\n",
    "OBS_PATH = pathlib.Path('/mnt/CEPH_PROJECTS/FACT_CLIMAX/OBSERVATION')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c589cdea-09ed-415e-8613-a5ec311fdb2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the predictand, i.e. tasmax, tasmin or pr\n",
    "PREDICTAND = 'tasmax'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8301c21-c32b-4123-a405-bde4b4d80f48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the observation data\n",
    "predictand = xr.open_dataset(search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4760558e-b1ed-4b16-9609-cf301aa39f55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# check out the xarray.Dataset: you will see a single variable\n",
    "predictand"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "756e22e2-0b53-4d8f-877c-b9aee2659e17",
   "metadata": {},
   "source": [
    "### Prepare training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aec082a-1d6a-4932-8600-efc48f886cb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the training period and the validation period\n",
    "TRAIN_PERIOD = slice('1981-01-01', '1991-01-01')\n",
    "VALID_PERIOD = slice('1991-01-01', '2010-01-01')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0812673a-5a9d-48bf-ba87-3f96cb26ff49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# select training and validation data: predictors\n",
    "predictors_train = predictors.sel(time=TRAIN_PERIOD)\n",
    "predictors_valid = predictors.sel(time=VALID_PERIOD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c326326c-5621-4ed4-af2b-a942bc678e71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# select training and validation data: predictand\n",
    "predictand_train = predictand.sel(time=TRAIN_PERIOD)\n",
    "predictand_valid = predictand.sel(time=VALID_PERIOD)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a410594-714e-4124-b155-d2985ed5b6cb",
   "metadata": {},
   "source": [
    "### Train the generalized linear regression model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57da5a56-bcd7-44eb-9304-62697d779f7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# instanciate the GLM\n",
    "model = TweedieRegressor(power=0 if PREDICTAND in ['tasmax', 'tasmin'] else 2)\n",
    "model\n",
    "# power = 0: Normal distribution (tasmax, tasmin)\n",
    "# power = 1: Poisson distribution\n",
    "# power = (1, 2): Compound Poisson Gamma distribution\n",
    "# power = 2: Gamma distribution (pr)\n",
    "# power = 3: Inverse Gaussian"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "802169db-8a22-42ad-8d4f-e29f676a49e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# function to normalize predictors to [0, 1]\n",
    "def normalize(predictors):\n",
    "    predictors -= predictors.min(axis=1, keepdims=True)\n",
    "    predictors /= predictors.max(axis=1, keepdims=True)\n",
    "    return predictors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5382f429-d8e7-4636-8f47-922ae20a91cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# iterate over the grid points\n",
    "prediction = np.ones(shape=(len(predictors_valid.time), len(predictors_valid.y), len(predictors_valid.x))) * np.nan\n",
    "for i, _ in enumerate(predictors_train.x):\n",
    "    for j, _ in enumerate(predictors_train.y):\n",
    "        \n",
    "        # current grid point: xarray.Dataset, dimensions=(time)\n",
    "        point_predictors = predictors_train.isel(x=i, y=j)\n",
    "        point_predictand = predictand_train.isel(x=i, y=j)\n",
    "        \n",
    "        # convert xarray.Dataset to numpy.array: shape=(time, predictors)\n",
    "        point_predictors = point_predictors.to_array().values.swapaxes(0, 1)\n",
    "        point_predictand = point_predictand.to_array().values.squeeze()\n",
    "        \n",
    "        # check if the grid point is valid\n",
    "        if np.isnan(point_predictors).any() or np.isnan(point_predictand).any():\n",
    "            # move on to next grid point\n",
    "            continue\n",
    "            \n",
    "        # normalize each predictor variable to [0, 1]\n",
    "        point_predictors = normalize(point_predictors)\n",
    "        \n",
    "        # instanciate the model for the current grid point\n",
    "        model = TweedieRegressor(power=0 if PREDICTAND in ['tasmax', 'tasmin'] else 2)\n",
    "        \n",
    "        # train model on training data\n",
    "        model.fit(point_predictors, point_predictand)\n",
    "        print('Processing grid point: ({:d}, {:d}), score: {:.2f}'.format(j, i, model.score(point_predictors, point_predictand)))\n",
    "        \n",
    "        # prepare predictors of validation period\n",
    "        point_validation = predictors_valid.isel(x=i, y=j).to_array().values.swapaxes(0, 1)\n",
    "        point_validation = normalize(point_validation)\n",
    "        \n",
    "        # predict validation period\n",
    "        pred = model.predict(point_validation)\n",
    "        \n",
    "        # store predictions for current grid point\n",
    "        prediction[:, j, i] = pred\n",
    "    \n",
    "# store predictions in xarray.Dataset\n",
    "predictions = xr.DataArray(data=prediction, dims=['time', 'y', 'x'],\n",
    "                           coords=dict(time=pd.date_range(VALID_PERIOD.start, VALID_PERIOD.stop, freq='D'),\n",
    "                                       lat=predictand_valid.y, lon=predictand_valid.x))\n",
    "predictions = predictions.to_dataset(name=PREDICTAND)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "595ee11d-570d-42ed-8c40-226d724aa9fb",
   "metadata": {},
   "source": [
    "### Save predictions as NetCDF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb708c40-c0e3-4c62-bc4e-9d0b6bfc52d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify the output path, filename: PREDICTAND.nc\n",
    "OUTPUT_PATH = pathlib.Path('~/{}'.format(PREDICTAND + '.nc'))\n",
    "\n",
    "# save to NetCDF\n",
    "predictions.to_netcdf(OUTPUT_PATH, engine='h5netcdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc9aba1a-fde1-48c6-85e7-02cb5ebdcfa7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enjoy and have fun!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}