diff --git a/Notebooks/eval_precipitation.ipynb b/Notebooks/eval_precipitation.ipynb index 4bdac8d9b968b1fcbeee519b34fb67684df33631..ab986088d8e6a0354a9eeb9d9b1c8cd54e081833 100644 --- a/Notebooks/eval_precipitation.ipynb +++ b/Notebooks/eval_precipitation.ipynb @@ -56,12 +56,12 @@ "MODEL = 'USegNet'\n", "PPREDICTORS = 'ztuvq'\n", "PLEVELS = ['500', '850']\n", - "SPREDICTORS = 'p'\n", + "SPREDICTORS = 'mslp'\n", "DEM = 'dem'\n", "DEM_FEATURES = ''\n", "DOY = 'doy'\n", - "# WET_DAY_THRESHOLD = 1\n", - "# LOSS = 'MSELoss'\n", + "WET_DAY_THRESHOLD = ''\n", + "LOSS = 'MSELoss'\n", "# LOSS = 'BernoulliGammaLoss'" ] }, @@ -198,7 +198,8 @@ "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n", "PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN\n", "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN\n", - "# PATTERN = '_'.join([PATTERN, '{:0d}mm_{}'.format(WET_DAY_THRESHOLD, LOSS)])\n", + "PATTERN = '_'.join([PATTERN, '{}mm'.format(str(WET_DAY_THRESHOLD).replace('.', ''))]) if WET_DAY_THRESHOLD else PATTERN\n", + "PATTERN = '_'.join([PATTERN, LOSS])\n", "PATTERN" ] }, @@ -225,7 +226,8 @@ "source": [ "# model predictions and observations NetCDF \n", "y_pred = xr.open_dataset(search_files(TARGET_PATH.joinpath(PREDICTAND), '.'.join([PATTERN, 'nc$'])).pop())\n", - "y_true = xr.open_dataset(search_files(OBS_PATH.joinpath(PREDICTAND), 'OBS_pr(.*).nc$').pop())" + "y_true = xr.open_dataset(search_files(OBS_PATH.joinpath(PREDICTAND), 'OBS_pr(.*).nc$').pop())\n", + "y_pred = y_pred.rename({'pr': 'precipitation'})" ] }, { @@ -260,9 +262,10 @@ "outputs": [], "source": [ "# align datasets and mask missing values in model predictions\n", - "y_true, y_refe, y_pred_pr, y_pred_prob = xr.align(y_true, y_refe, y_pred.precipitation.to_dataset(), y_pred.prob.to_dataset(), join='override')\n", + "# y_true, y_refe, y_pred_pr, y_pred_prob = xr.align(y_true, y_refe, y_pred.precipitation.to_dataset(), y_pred.prob.to_dataset(), join='override')\n", + "y_true, y_refe, y_pred_pr = xr.align(y_true, y_refe, y_pred, join='override')\n", "y_pred_pr = y_pred_pr.where(~np.isnan(y_true.precipitation), other=np.nan)\n", - "y_pred_prob = y_pred_prob.where(~np.isnan(y_true.precipitation), other=np.nan)\n", + "# y_pred_prob = y_pred_prob.where(~np.isnan(y_true.precipitation), other=np.nan)\n", "y_refe = y_refe.where(~np.isnan(y_true.precipitation), other=np.nan)" ] }, @@ -302,8 +305,8 @@ "outputs": [], "source": [ "# calculate monthly mean precipitation (mm / month)\n", - "y_pred_values = y_pred_pr[NAMES[PREDICTAND]].resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').values\n", - "y_true_values = y_true[NAMES[PREDICTAND]].resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').values" + "y_pred_values = y_pred_pr.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').to_array().values\n", + "y_true_values = y_true.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').to_array().values" ] }, { @@ -339,8 +342,8 @@ "outputs": [], "source": [ "# group timeseries by month and calculate mean over time and space\n", - "y_pred_ac = y_pred_pr.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True)\n", - "y_true_ac = y_true.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True)" + "y_pred_ac = y_pred_pr.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True).to_array().values.squeeze()\n", + "y_true_ac = y_true.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True).to_array().values.squeeze()" ] }, { @@ -380,8 +383,8 @@ "\n", "# add axis for annual cycle\n", "axins = inset_axes(ax, width=\"30%\", height=\"40%\", loc=2, borderpad=0.25)\n", - "axins.plot(y_pred_ac[NAMES[PREDICTAND]].values, ls='--', color='k', label='Predicted')\n", - "axins.plot(y_true_ac[NAMES[PREDICTAND]].values, ls='-', color='k', label='Observed')\n", + "axins.plot(y_pred_ac, ls='--', color='k', label='Predicted')\n", + "axins.plot(y_true_ac, ls='-', color='k', label='Observed')\n", "axins.legend(frameon=False, fontsize=12, loc='lower center');\n", "axins.set_yticks(np.arange(0, 200, 50))\n", "axins.set_yticklabels(np.arange(0, 200, 50), fontsize=12)\n",