{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "032a6c39-ca0c-44d7-bf26-b4ea0ed87739",
   "metadata": {},
   "source": [
    "## ERA5 predictors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1daf69a6-f61b-484d-9b07-101da88f4b28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# builtins\n",
    "import pathlib\n",
    "\n",
    "# externals\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import xarray as xr\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# locals\n",
    "from pysegcnn.core.utils import search_files\n",
    "from climax.core.dataset import ERA5Dataset\n",
    "from climax.core.constants import ERA5_VARIABLES\n",
    "from climax.core.utils import search_files\n",
    "from climax.main.config import CALIB_PERIOD\n",
    "from climax.main.io import OBS_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e64afc53-097d-403c-b6ee-e6865b26a245",
   "metadata": {},
   "outputs": [],
   "source": [
    "# root directory of the ERA5 reanalysis archive read by ERA5Dataset below\n",
    "ERA5_PATH = pathlib.Path(\n",
    "    '/mnt/CEPH_PROJECTS/FACT_CLIMAX/REANALYSIS/ERA5'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2a127b5-faa0-4cf5-8d7c-bfb747d38b51",
   "metadata": {},
   "outputs": [],
   "source": [
    "# names accepted in ERA5_PREDICTORS (checked by the assertion in the next cell)\n",
    "ERA5_VARIABLES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10c6c7e1-6ff2-4aa3-ace6-3adb7edf2583",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the predictor variables you want to use\n",
    "ERA5_PREDICTORS = ['geopotential', 'temperature', 'mean_sea_level_pressure']  # use geopotential, temperature and mean sea level pressure\n",
    "\n",
    "# you can change this list as you wish, e.g.:\n",
    "# ERA5_PREDICTORS = ['geopotential', 'temperature']  # use only geopotential and temperature\n",
    "# ERA5_PREDICTORS = ERA5_VARIABLES  # use all ERA5 variables as predictors\n",
    "\n",
    "# fail early, naming the offending entries, if any predictor is not a valid ERA5 variable\n",
    "invalid_predictors = [p for p in ERA5_PREDICTORS if p not in ERA5_VARIABLES]\n",
    "assert not invalid_predictors, 'Invalid ERA5 predictor(s): {}; valid names are: {}'.format(\n",
    "    invalid_predictors, ERA5_VARIABLES)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f500399-4d61-41e8-9013-eaac69b3e49a",
   "metadata": {},
   "source": [
    "### Use the climax package to load ERA5 predictors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57223fb2-efec-46c9-a9f7-ca71524627ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pressure levels that are currently available in the ERA5 archive\n",
    "AVAILABLE_PLEVELS = [500, 850]\n",
    "\n",
    "# define which pressure levels you want to use (subset of AVAILABLE_PLEVELS)\n",
    "PLEVELS = [500, 850]\n",
    "\n",
    "# fail early, naming the offending entries, if a requested level is not available\n",
    "invalid_plevels = [p for p in PLEVELS if p not in AVAILABLE_PLEVELS]\n",
    "assert not invalid_plevels, 'Unavailable pressure level(s): {}; available levels are: {}'.format(\n",
    "    invalid_plevels, AVAILABLE_PLEVELS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5c75041-a40d-45e5-b940-2847fee80926",
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the ERA5 dataset for the selected predictors and pressure levels,\n",
    "# then merge it into a single xarray.Dataset\n",
    "era5 = ERA5Dataset(ERA5_PATH, ERA5_PREDICTORS, plevels=PLEVELS)\n",
    "predictors = era5.merge(chunks=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea5c1c43-fbbf-4db1-926b-1102e8079348",
   "metadata": {},
   "outputs": [],
   "source": [
    "# inspect the merged predictors: every variable selected in ERA5_PREDICTORS should be listed\n",
    "predictors"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14668e55-3b47-44ea-b3f7-596b0e62eec6",
   "metadata": {},
   "source": [
    "## DEM predictors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c89e1ea-3a2d-45aa-bc0d-6123239d58b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# directory searched for the digital elevation model file\n",
    "DEM_PATH = pathlib.Path('/mnt/CEPH_PROJECTS/FACT_CLIMAX/DEM/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c39941c-2edc-428c-bd8e-70dc43115238",
   "metadata": {},
   "outputs": [],
   "source": [
    "# digital elevation model: Copernicus EU-DEM v1.1\n",
    "# exactly one file must match; an empty or ambiguous search result is an error\n",
    "dem_files = search_files(DEM_PATH, r'^eu_dem_v11_stt\\.nc$')\n",
    "assert len(dem_files) == 1, 'Expected exactly one DEM file in {}, found: {}'.format(DEM_PATH, dem_files)\n",
    "dem_file = dem_files.pop()\n",
    "\n",
    "# read elevation and compute slope and aspect, using the predictors' y/x (and time) coordinates\n",
    "dem = ERA5Dataset.dem_features(dem_file, {'y': predictors.y, 'x': predictors.x}, add_coord={'time': predictors.time})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3a04b66-72bb-4b16-a527-080c543cbeab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# inspect the DEM-derived features\n",
    "dem"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d55ca9b1-46a9-462a-ba42-fb7a38c823fa",
   "metadata": {},
   "source": [
    "## Merge predictors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b96ad03c-6065-4609-986a-66a5cb799526",
   "metadata": {},
   "outputs": [],
   "source": [
    "# combine the ERA5 predictors and the DEM features into a single dataset\n",
    "predictor_sets = [predictors, dem]\n",
    "Era5_ds = xr.merge(predictor_sets)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd584313-8913-45b5-a379-2724d42bc99f",
   "metadata": {},
   "source": [
    "## Read observations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e982d6f9-ae6b-435e-917d-44118c6a09b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# predictand variable name; also selects the observation sub-directory and file pattern below\n",
    "PREDICTAND = 'pr'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe45a13d-d95d-4260-a3bc-0e6fd53db732",
   "metadata": {},
   "outputs": [],
   "source": [
    "# read in-situ gridded observations\n",
    "# exactly one file must match; an empty or ambiguous search result is an error\n",
    "obs_files = search_files(OBS_PATH.joinpath(PREDICTAND), r'OBS_{}(.*)\\.nc$'.format(PREDICTAND))\n",
    "assert len(obs_files) == 1, 'Expected exactly one observation file for {}, found: {}'.format(PREDICTAND, obs_files)\n",
    "Obs_ds = xr.open_dataset(obs_files.pop())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "caaf0557-74a9-4c30-9697-bb58d68553aa",
   "metadata": {},
   "source": [
    "## Split into training and validation sets and group by season"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c98896f5-7953-4bb3-af97-fd98740514d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# split CALIB_PERIOD without shuffling: its final ~10% (in original order) becomes the validation set\n",
    "train, valid = train_test_split(CALIB_PERIOD, test_size=0.1, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a7295fd-3c57-4ca4-9c47-23a179db4829",
   "metadata": {},
   "outputs": [],
   "source": [
    "# select the training and validation time steps from predictors and observations\n",
    "Era5_train = Era5_ds.sel(time=train)\n",
    "Obs_train = Obs_ds.sel(time=train)\n",
    "Era5_valid = Era5_ds.sel(time=valid)\n",
    "Obs_valid = Obs_ds.sel(time=valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82a88472-7837-4fa6-a591-1f1952ea442b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# group predictors and predictand by season; each dataset is grouped on its own\n",
    "# time coordinate instead of reusing positional indices computed from ERA5 only\n",
    "Era_season_train = dict(Era5_train.groupby('time.season'))\n",
    "Obs_season_train = dict(Obs_train.groupby('time.season'))\n",
    "Era_season_valid = dict(Era5_valid.groupby('time.season'))\n",
    "Obs_season_valid = dict(Obs_valid.groupby('time.season'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dde5735a-4633-452f-aada-22da3a15696e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check: for every season, predictors and predictand must cover identical time steps\n",
    "for era_seasons, obs_seasons in [(Era_season_train, Obs_season_train),\n",
    "                                 (Era_season_valid, Obs_season_valid)]:\n",
    "    assert set(era_seasons) == set(obs_seasons), 'Seasons differ between predictors and predictand'\n",
    "    for season, era in era_seasons.items():\n",
    "        assert np.array_equal(era.time.values, obs_seasons[season].time.values), (\n",
    "            'Predictor and predictand time steps differ in season {}'.format(season))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}