From 8df419810a4cb7b3fafc99aa56ef2bf8317d55f3 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Tue, 28 Sep 2021 14:41:11 +0000
Subject: [PATCH] Implemented evaluation of PR with MSE.

---
 Notebooks/eval_precipitation.ipynb | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/Notebooks/eval_precipitation.ipynb b/Notebooks/eval_precipitation.ipynb
index 4bdac8d..ab98608 100644
--- a/Notebooks/eval_precipitation.ipynb
+++ b/Notebooks/eval_precipitation.ipynb
@@ -56,12 +56,12 @@
     "MODEL = 'USegNet'\n",
     "PPREDICTORS = 'ztuvq'\n",
     "PLEVELS = ['500', '850']\n",
-    "SPREDICTORS = 'p'\n",
+    "SPREDICTORS = 'mslp'\n",
     "DEM = 'dem'\n",
     "DEM_FEATURES = ''\n",
     "DOY = 'doy'\n",
-    "# WET_DAY_THRESHOLD = 1\n",
-    "# LOSS = 'MSELoss'\n",
+    "WET_DAY_THRESHOLD = ''\n",
+    "LOSS = 'MSELoss'\n",
     "# LOSS = 'BernoulliGammaLoss'"
    ]
   },
@@ -198,7 +198,8 @@
     "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n",
     "PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN\n",
     "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN\n",
-    "# PATTERN = '_'.join([PATTERN, '{:0d}mm_{}'.format(WET_DAY_THRESHOLD, LOSS)])\n",
+    "PATTERN = '_'.join([PATTERN, '{}mm'.format(str(WET_DAY_THRESHOLD).replace('.', ''))]) if WET_DAY_THRESHOLD else PATTERN\n",
+    "PATTERN = '_'.join([PATTERN, LOSS])\n",
     "PATTERN"
    ]
   },
@@ -225,7 +226,8 @@
    "source": [
     "# model predictions and observations NetCDF \n",
     "y_pred = xr.open_dataset(search_files(TARGET_PATH.joinpath(PREDICTAND), '.'.join([PATTERN, 'nc$'])).pop())\n",
-    "y_true = xr.open_dataset(search_files(OBS_PATH.joinpath(PREDICTAND), 'OBS_pr(.*).nc$').pop())"
+    "y_true = xr.open_dataset(search_files(OBS_PATH.joinpath(PREDICTAND), 'OBS_pr(.*).nc$').pop())\n",
+    "y_pred = y_pred.rename({'pr': 'precipitation'})"
    ]
   },
   {
@@ -260,9 +262,10 @@
    "outputs": [],
    "source": [
     "# align datasets and mask missing values in model predictions\n",
-    "y_true, y_refe, y_pred_pr, y_pred_prob = xr.align(y_true, y_refe, y_pred.precipitation.to_dataset(), y_pred.prob.to_dataset(), join='override')\n",
+    "# y_true, y_refe, y_pred_pr, y_pred_prob = xr.align(y_true, y_refe, y_pred.precipitation.to_dataset(), y_pred.prob.to_dataset(), join='override')\n",
+    "y_true, y_refe, y_pred_pr = xr.align(y_true, y_refe, y_pred, join='override')\n",
     "y_pred_pr = y_pred_pr.where(~np.isnan(y_true.precipitation), other=np.nan)\n",
-    "y_pred_prob = y_pred_prob.where(~np.isnan(y_true.precipitation), other=np.nan)\n",
+    "# y_pred_prob = y_pred_prob.where(~np.isnan(y_true.precipitation), other=np.nan)\n",
     "y_refe = y_refe.where(~np.isnan(y_true.precipitation), other=np.nan)"
    ]
   },
@@ -302,8 +305,8 @@
    "outputs": [],
    "source": [
     "# calculate monthly mean precipitation (mm / month)\n",
-    "y_pred_values = y_pred_pr[NAMES[PREDICTAND]].resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').values\n",
-    "y_true_values = y_true[NAMES[PREDICTAND]].resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').values"
+    "y_pred_values = y_pred_pr.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').to_array().values\n",
+    "y_true_values = y_true.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').to_array().values"
    ]
   },
   {
@@ -339,8 +342,8 @@
    "outputs": [],
    "source": [
     "# group timeseries by month and calculate mean over time and space\n",
-    "y_pred_ac = y_pred_pr.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True)\n",
-    "y_true_ac = y_true.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True)"
+    "y_pred_ac = y_pred_pr.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True).to_array().values.squeeze()\n",
+    "y_true_ac = y_true.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True).to_array().values.squeeze()"
    ]
   },
   {
@@ -380,8 +383,8 @@
     "\n",
     "# add axis for annual cycle\n",
     "axins = inset_axes(ax, width=\"30%\", height=\"40%\", loc=2, borderpad=0.25)\n",
-    "axins.plot(y_pred_ac[NAMES[PREDICTAND]].values, ls='--', color='k', label='Predicted')\n",
-    "axins.plot(y_true_ac[NAMES[PREDICTAND]].values, ls='-', color='k', label='Observed')\n",
+    "axins.plot(y_pred_ac, ls='--', color='k', label='Predicted')\n",
+    "axins.plot(y_true_ac, ls='-', color='k', label='Observed')\n",
     "axins.legend(frameon=False, fontsize=12, loc='lower center');\n",
     "axins.set_yticks(np.arange(0, 200, 50))\n",
     "axins.set_yticklabels(np.arange(0, 200, 50), fontsize=12)\n",
-- 
GitLab