From bc0cd0f91628741933705d3e1d674cb88a68928e Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Mon, 4 Oct 2021 15:36:39 +0000
Subject: [PATCH] Evaluate different distributions for precipitation.

---
 Notebooks/pr_distribution.ipynb | 285 ++++++++++++++++++++++++++------
 1 file changed, 237 insertions(+), 48 deletions(-)

diff --git a/Notebooks/pr_distribution.ipynb b/Notebooks/pr_distribution.ipynb
index 59db30e..51d2351 100644
--- a/Notebooks/pr_distribution.ipynb
+++ b/Notebooks/pr_distribution.ipynb
@@ -40,6 +40,36 @@
     "from pysegcnn.core.graphics import plot_classification_report"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "de6ae734-3a6a-477e-a5a0-8b9ec5911369",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# entire reference period\n",
+    "REFERENCE_PERIOD = np.concatenate([CALIB_PERIOD, VALID_PERIOD], axis=0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "534d9565-4b58-4959-bef3-edde969e2364",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# empirical quantiles\n",
+    "quantiles = np.arange(0.01, 1, 0.005)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "12382efb-1a3a-4ede-a904-7f762bfe56c7",
+   "metadata": {},
+   "source": [
+    "### Load observations"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -53,6 +83,25 @@
     "y_true = xr.open_dataset(search_files(OBS_PATH.joinpath('pr'), 'OBS_pr(.*).nc$').pop())"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "5d30b543-aa3b-45f3-b8e8-90d72f4f6896",
+   "metadata": {},
+   "source": [
+    "### Select time period"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f902683a-a560-48f9-b2d1-ef9c341ca69a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# time period\n",
+    "PERIOD = REFERENCE_PERIOD"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -60,30 +109,42 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# subset to calibration and validation period\n",
-    "y_calib = y_true.sel(time=CALIB_PERIOD).precipitation.values"
+    "# subset to time period\n",
+    "y = y_true.sel(time=PERIOD)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f6d01e1e-9dc2-4c31-a31a-a6c91abc7fb4",
+   "metadata": {},
+   "source": [
+    "### Fit distributions: annually"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ed7d1686-968e-49e9-ba34-d03658ba3b32",
+   "id": "0ffce851-50fc-4795-84b9-972e4f1a5169",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# mask missing values\n",
-    "y_calib = y_calib[~np.isnan(y_calib)]"
+    "# helper function retrieving only valid observations\n",
+    "def valid(ds):\n",
+    "    valid = ds.precipitation.values\n",
+    "    valid = valid[~np.isnan(valid)]  # mask missing values\n",
+    "    valid = valid[valid > 0]  # only consider pr > 0\n",
+    "    return valid"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "1b570e39-f242-49ef-8aef-eff8fbcf7c4d",
+   "id": "6f68803b-4dbc-4d43-99c0-a32e482b647a",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# only use values greater 0\n",
-    "y_calib = y_calib[y_calib > 0]"
+    "# valid observations\n",
+    "y_valid = valid(y)"
    ]
   },
   {
@@ -94,8 +155,8 @@
    "outputs": [],
    "source": [
     "# fit gamma distribution to data\n",
-    "alpha, loc, beta = stats.gamma.fit(y_calib, loc=0.1)\n",
-    "gamma_calib = stats.gamma(alpha, loc=loc, scale=beta)"
+    "alpha, loc, beta = stats.gamma.fit(y_valid, floc=0)\n",
+    "gamma = stats.gamma(alpha, loc=loc, scale=beta)"
    ]
   },
   {
@@ -106,8 +167,8 @@
    "outputs": [],
    "source": [
     "# fit generalized pareto distribution to data\n",
-    "alpha, loc, beta = stats.genpareto.fit(y_calib, loc=0.1)\n",
-    "genpareto_calib = stats.genpareto(alpha, loc=loc, scale=beta)"
+    "alpha, loc, beta = stats.genpareto.fit(y_valid, floc=0)\n",
+    "genpareto = stats.genpareto(alpha, loc=loc, scale=beta)"
    ]
   },
   {
@@ -117,13 +178,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# compute empirical quantiles\n",
-    "quantiles = np.arange(0.01, 1, 0.005)\n",
-    "\n",
     "# empirical quantiles and theoretical quantiles\n",
-    "eq = np.quantile(y_calib, quantiles)\n",
-    "tq_gamma = gamma_calib.ppf(quantiles)\n",
-    "tq_genpareto = genpareto_calib.ppf(quantiles)\n",
+    "eq = np.quantile(y_valid, quantiles)\n",
+    "tq_gamma = gamma.ppf(quantiles)\n",
+    "tq_genpareto = genpareto.ppf(quantiles)\n",
     "\n",
     "# Q-Q plot \n",
     "RANGE = 40\n",
@@ -133,38 +191,121 @@
     "ax.plot(np.arange(0, RANGE), np.arange(0, RANGE), '--k')\n",
     "ax.set_xlim(0, RANGE)\n",
     "ax.set_ylim(0, RANGE)\n",
-    "ax.set_ylabel('Theoretical quantiles');\n",
-    "ax.set_xlabel('Empirical quantiles');\n",
-    "ax.legend(frameon=False, fontsize=12);"
+    "ax.set_xticks(np.arange(0, RANGE + 5, 5))\n",
+    "ax.set_yticks(np.arange(0, RANGE + 5, 5))\n",
+    "ax.set_xticklabels([str(t) for t in np.arange(0, RANGE + 5, 5)], fontsize=12)\n",
+    "ax.set_yticklabels([str(t) for t in np.arange(0, RANGE + 5, 5)], fontsize=12)\n",
+    "ax.set_ylabel('Theoretical quantiles', fontsize=14);\n",
+    "ax.set_xlabel('Empirical quantiles', fontsize=14);\n",
+    "ax.legend(frameon=False, fontsize=14);\n",
+    "ax.set_title('Reference period: {} - {}'.format(str(PERIOD[0]), str(PERIOD[-1])), fontsize=14)\n",
+    "\n",
+    "# save figure\n",
+    "fig.savefig('./Figures/pr_distribution.png', bbox_inches='tight', dpi=300)"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "c0fea8ac-bac0-4096-bc81-90d799f8ab94",
+   "id": "5fd0e9d8-759d-45ee-bb1f-9c749ac23e8e",
    "metadata": {},
    "source": [
-    "### Empirical quantiles per grid point"
+    "### Fit distributions: monthly"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "4dcb3348-5d22-4324-b840-2c305983e826",
+   "id": "156e5415-4065-4887-b759-0e665d671b38",
    "metadata": {},
    "outputs": [],
    "source": [
-    "quantiles = np.arange(0.01, 1, 0.01)"
+    "# get the indices of the observations for each month\n",
+    "month_idx = y.groupby('time.month').groups"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "18ce44b4-d6c0-4950-9cd2-7a3af1095b24",
+   "id": "092e865d-f033-4f60-8098-86ae5068e045",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# subset to calibration period\n",
-    "y_calib_p = y_true.sel(time=CALIB_PERIOD).precipitation"
+    "# fit distribution to observations for each month\n",
+    "month_gamma = {}\n",
+    "month_genpareto = {}\n",
+    "for month, idx in month_idx.items():\n",
+    "    print('Month: {}'.format(calendar.month_name[month]))\n",
+    "    # select the data of the current month\n",
+    "    data = y.isel(time=idx)\n",
+    "    data = valid(data)\n",
+    "    \n",
+    "    # fit distributions\n",
+    "    \n",
+    "    # gamma\n",
+    "    alpha, loc, beta = stats.gamma.fit(data, floc=0)\n",
+    "    gamma = stats.gamma(alpha, loc=loc, scale=beta)\n",
+    "    month_gamma[month] = gamma\n",
+    "    \n",
+    "    # genpareto\n",
+    "    alpha, loc, beta = stats.genpareto.fit(data, floc=0)\n",
+    "    genpareto = stats.genpareto(alpha, loc=loc, scale=beta)\n",
+    "    month_genpareto[month] = genpareto  "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "396e5ee4-1632-4591-b93b-91fa6ac1d373",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# plot empirical vs. theoretical quantiles for each month\n",
+    "fig, axes = plt.subplots(4, 3, figsize=(12, 12), sharex=True, sharey=True)\n",
+    "axes = axes.flatten()\n",
+    "\n",
+    "RANGE = 40\n",
+    "for month, idx in month_idx.items():\n",
+    "    # axis to plot\n",
+    "    ax = axes[month - 1]\n",
+    "    \n",
+    "    # compute empirical quantiles\n",
+    "    data = y.isel(time=idx)\n",
+    "    data = valid(data)\n",
+    "    eq = np.quantile(data, quantiles)\n",
+    "    \n",
+    "    # compute theoretical quantiles\n",
+    "    tq_gamma = month_gamma[month].ppf(quantiles)\n",
+    "    tq_gpare = month_genpareto[month].ppf(quantiles)\n",
+    "    \n",
+    "    # plot empirical vs. theoretical quantiles\n",
+    "    ax.scatter(eq, tq_gamma, color='grey', label='Gamma')\n",
+    "    ax.scatter(eq, tq_gpare, color='k', label='GenPareto')\n",
+    "    ax.plot(np.arange(0, RANGE), np.arange(0, RANGE), '-k')\n",
+    "    ax.set_title(calendar.month_name[month], fontsize=14)\n",
+    "    ax.set_xlim(0, RANGE)\n",
+    "    ax.set_ylim(0, RANGE)\n",
+    "    ax.set_xticks(np.arange(0, RANGE + 5, 5))\n",
+    "    ax.set_yticks(np.arange(0, RANGE + 5, 5))\n",
+    "    ax.set_xticklabels([str(t) for t in np.arange(0, RANGE + 5, 5)], fontsize=12)\n",
+    "    ax.set_yticklabels([str(t) for t in np.arange(0, RANGE + 5, 5)], fontsize=12)\n",
+    "\n",
+    "# add legend\n",
+    "axes[0].legend(frameon=False, fontsize=12)\n",
+    "\n",
+    "# add figure title\n",
+    "fig.suptitle('Reference period: {} - {}'.format(str(PERIOD[0]), str(PERIOD[-1])), fontsize=14)\n",
+    "\n",
+    "# adjust subplots\n",
+    "fig.subplots_adjust(wspace=0.1)\n",
+    "fig.savefig('./Figures/pr_distribution_m.png', bbox_inches='tight', dpi=300)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c0fea8ac-bac0-4096-bc81-90d799f8ab94",
+   "metadata": {},
+   "source": [
+    "### Empirical quantiles per grid point"
    ]
   },
   {
@@ -175,7 +316,7 @@
    "outputs": [],
    "source": [
     "# compute empirical quantiles over time\n",
-    "equantiles = y_calib_p.quantile(quantiles, dim='time')\n",
+    "equantiles = y.precipitation.quantile(quantiles, dim='time')\n",
     "equantiles = equantiles.rename({'quantile': 'q'})"
    ]
   },
@@ -187,39 +328,42 @@
    "outputs": [],
    "source": [
     "# iterate over the grid points\n",
-    "gammaq = np.ones(shape=(len(y_calib_p.q), len(y_calib_p.y), len(y_calib_p.x))) * np.nan\n",
-    "for i, _ in enumerate(y_calib_p.x):\n",
-    "    print('Rows: {}/{}'.format(i + 1, len(y_calib_p.x)))\n",
-    "    for j, _ in enumerate(y_calib_p.y):\n",
+    "gammaq = np.ones(shape=(len(equantiles.q), len(equantiles.y), len(equantiles.x))) * np.nan\n",
+    "genpaq = np.ones(shape=(len(equantiles.q), len(equantiles.y), len(equantiles.x))) * np.nan\n",
+    "for i, _ in enumerate(y.x):\n",
+    "    print('Rows: {}/{}'.format(i + 1, len(y.x)))\n",
+    "    for j, _ in enumerate(y.y):\n",
     "    \n",
     "        # current grid point: xarray.Dataset, dimensions=(time)\n",
-    "        point = y_calib_p.isel(x=i, y=j).values\n",
-    "        \n",
-    "        # mask missing values\n",
-    "        point = point[~np.isnan(point)]\n",
+    "        point = y.isel(x=i, y=j)\n",
+    "        point = valid(point)\n",
     "        \n",
     "        # check if the grid point is valid\n",
     "        if point.size < 1:\n",
     "            # move on to next grid point\n",
     "            continue\n",
     "            \n",
-    "        # consider only values > 0\n",
-    "        point = point[point > 0]\n",
-    "            \n",
     "        # fit Gamma distribution to grid point\n",
-    "        alpha, loc, beta = stats.gamma.fit(point)\n",
+    "        alpha, loc, beta = stats.gamma.fit(point, floc=0)\n",
     "        gamma = stats.gamma(alpha, loc=loc, scale=beta)\n",
     "        \n",
-    "        # compute theoretical quantiles of fitted gamma distribution\n",
-    "        tq = gamma.ppf(quantiles)\n",
+    "        # fit GenPareto distribution to grid point\n",
+    "        alpha, loc, beta = stats.genpareto.fit(point, floc=0)\n",
+    "        genpa = stats.genpareto(alpha, loc=loc, scale=beta)\n",
+    "        \n",
+    "        # compute theoretical quantiles of fitted distributions\n",
+    "        tq_gamma = gamma.ppf(quantiles)\n",
+    "        tq_genpa = genpa.ppf(quantiles)\n",
     "        \n",
     "        # store theoretical quantiles for current grid point\n",
-    "        gammaq[:, j, i] = tq\n",
+    "        gammaq[:, j, i] = tq_gamma\n",
+    "        genpaq[:, j, i] = tq_genpa\n",
     "\n",
     "# store theoretical quantiles in xarray.DataArray\n",
-    "tquantiles = xr.DataArray(data=gammaq, dims=['q', 'y', 'x'],\n",
-    "                          coords=dict(q=quantiles, lat=y_calib_p.y, lon=y_calib_p.x),\n",
-    "                          name='precipitation')"
+    "gammaq = xr.DataArray(data=gammaq, dims=['q', 'y', 'x'], coords=dict(q=quantiles, y=y.y, x=y.x),\n",
+    "                      name='precipitation')\n",
+    "genpaq = xr.DataArray(data=genpaq, dims=['q', 'y', 'x'], coords=dict(q=quantiles, y=y.y, x=y.x),\n",
+    "                      name='precipitation')"
    ]
   },
   {
@@ -230,8 +374,53 @@
    "outputs": [],
    "source": [
     "# compute bias in theoretical quantiles\n",
-    "biasq = tquantiles - equantiles"
+    "bias_gamma = gammaq - equantiles  # predicted - observed\n",
+    "bias_genpa = genpaq - equantiles"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b8089c11-a48d-4028-9d4b-e03101ff5e55",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# plot spatial bias in different quantiles\n",
+    "plot_quantiles = quantiles[18::20]\n",
+    "fig, axes = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(12, 12))\n",
+    "axes = axes.flatten()\n",
+    "\n",
+    "for dist in ['gamma', 'genpareto']:\n",
+    "    biasq = bias_gamma if dist == 'gamma' else bias_genpa\n",
+    "\n",
+    "    # iterate over quantiles to plot\n",
+    "    for ax, q in zip(axes, plot_quantiles):\n",
+    "        im = ax.imshow(biasq.sel(q=q).values, origin='lower', vmin=0, vmax=5, cmap='viridis_r')\n",
+    "        ax.set_title(str('P{:.0f}'.format(q * 100)), fontsize=14)\n",
+    "\n",
+    "    # adjust subplots\n",
+    "    fig.subplots_adjust(wspace=0.1, hspace=0.1)\n",
+    "\n",
+    "    # add colorbar for bias\n",
+    "    axes = axes.flatten()\n",
+    "    cbar_ax_bias = fig.add_axes([axes[2].get_position().x1 + 0.01, axes[2].get_position().y0,\n",
+    "                                 0.01, axes[2].get_position().y1 - axes[2].get_position().y0])\n",
+    "    cbar_bias = fig.colorbar(im, cax=cbar_ax_bias)\n",
+    "    cbar_bias.set_label(label='Bias (mm)', fontsize=14)\n",
+    "    cbar_bias.ax.tick_params(labelsize=14, pad=10)\n",
+    "\n",
+    "    # save figure\n",
+    "    fig\n",
+    "    fig.savefig('./Figures/pr_distribution_{}_grid.png'.format(dist), bbox_inches='tight', dpi=300)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a5ee4d3f-608b-4598-b235-3cd20a184aff",
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
-- 
GitLab