From 68635fab73cffdece0ee88a352d966979709ea92 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Wed, 6 Oct 2021 15:11:22 +0000
Subject: [PATCH] Refactor.

---
 Notebooks/check_predictors.ipynb | 103 ++++++++++++++++++++++++++++++-
 Notebooks/eval_temperature.ipynb |   2 +-
 2 files changed, 101 insertions(+), 4 deletions(-)

diff --git a/Notebooks/check_predictors.ipynb b/Notebooks/check_predictors.ipynb
index 4cb2f92..dd442c4 100644
--- a/Notebooks/check_predictors.ipynb
+++ b/Notebooks/check_predictors.ipynb
@@ -22,12 +22,15 @@
     "import numpy as np\n",
     "import pandas as pd\n",
     "import xarray as xr\n",
-    "from sklearn.linear_model import TweedieRegressor\n",
+    "from sklearn.model_selection import train_test_split\n",
     "\n",
     "# locals\n",
+    "from pysegcnn.core.utils import search_files\n",
     "from climax.core.dataset import ERA5Dataset\n",
     "from climax.core.constants import ERA5_VARIABLES\n",
-    "from climax.core.utils import search_files"
+    "from climax.core.utils import search_files\n",
+    "from climax.main.config import CALIB_PERIOD\n",
+    "from climax.main.io import OBS_PATH"
    ]
   },
   {
@@ -169,7 +172,101 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "xr.merge([predictors, dem])"
+    "Era5_ds = xr.merge([predictors, dem])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "dd584313-8913-45b5-a379-2724d42bc99f",
+   "metadata": {},
+   "source": [
+    "## Read observations"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e982d6f9-ae6b-435e-917d-44118c6a09b1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "PREDICTAND = 'pr'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fe45a13d-d95d-4260-a3bc-0e6fd53db732",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# read in-situ gridded observations\n",
+    "Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), 'OBS_{}(.*).nc$'.format(PREDICTAND)).pop()\n",
+    "Obs_ds = xr.open_dataset(Obs_ds)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "caaf0557-74a9-4c30-9697-bb58d68553aa",
+   "metadata": {},
+   "source": [
+    "## Group predictors by season"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c98896f5-7953-4bb3-af97-fd98740514d1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# split into training and validation set\n",
+    "train, valid = train_test_split(CALIB_PERIOD, shuffle=False, test_size=0.1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1a7295fd-3c57-4ca4-9c47-23a179db4829",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# training and validation dataset\n",
+    "Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)\n",
+    "Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "82a88472-7837-4fa6-a591-1f1952ea442b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# group predictors and predictand by season\n",
+    "season_indices_train = Era5_train.groupby('time.season').groups\n",
+    "season_indices_valid = Era5_valid.groupby('time.season').groups\n",
+    "\n",
+    "# group training and validation set by season\n",
+    "Era_season_train = {k: Era5_train.isel(time=v) for k, v in\n",
+    "                    season_indices_train.items()}\n",
+    "Obs_season_train = {k: Obs_train.isel(time=v) for k, v in\n",
+    "                    season_indices_train.items()}\n",
+    "Era_season_valid = {k: Era5_valid.isel(time=v) for k, v in\n",
+    "                    season_indices_valid.items()}\n",
+    "Obs_season_valid = {k: Obs_valid.isel(time=v) for k, v in\n",
+    "                    season_indices_valid.items()}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dde5735a-4633-452f-aada-22da3a15696e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "Era_season_train = {k: Era5_train.isel(time=v) for k, v in\n",
+    "                    season_indices_train.items()}"
    ]
   }
  ],
diff --git a/Notebooks/eval_temperature.ipynb b/Notebooks/eval_temperature.ipynb
index 1193bc3..f315da9 100644
--- a/Notebooks/eval_temperature.ipynb
+++ b/Notebooks/eval_temperature.ipynb
@@ -498,7 +498,7 @@
     "\n",
     "# adjust figure\n",
     "fig.suptitle('Average {}: 1991 - 2010'.format(NAMES[PREDICTAND]), fontsize=20);\n",
-    "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)\n",
+    "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)c\n",
     "\n",
     "# add colorbar for bias\n",
     "axes = axes.flatten()\n",
-- 
GitLab