diff --git a/.gitignore b/.gitignore index 51026ddbf4c7110b8d769f14cc075a126cc11318..91f1eff04fa35e0a39b8631358a5f97dc4cb493a 100644 --- a/.gitignore +++ b/.gitignore @@ -8,5 +8,5 @@ __pycache__/ # jupyter .ipynb_checkpoints/ -Figures/ +Notebooks/Figures/ *.nc diff --git a/Notebooks/Figures/pr_ROC.png b/Notebooks/Figures/pr_ROC.png index 3d794b96feeac9ff9b61ffd5d1293be0b9eab27a..05e4ac64ffaae4b95bc5cd1ffd79e3984de104b8 100644 Binary files a/Notebooks/Figures/pr_ROC.png and b/Notebooks/Figures/pr_ROC.png differ diff --git a/Notebooks/Figures/pr_average_bias.png b/Notebooks/Figures/pr_average_bias.png index 2c4e8edc26db176fae1b874825e38a3d154551c8..dece9cee6b82b6b1895b99f18e0d37ea91ade836 100644 Binary files a/Notebooks/Figures/pr_average_bias.png and b/Notebooks/Figures/pr_average_bias.png differ diff --git a/Notebooks/Figures/pr_average_bias_p98.png b/Notebooks/Figures/pr_average_bias_p98.png index 75aedf9f96ebaf90868ae438e75c32f9b40755bf..389092dafa3e94c3b69719e2c3712319d8e0cb6b 100644 Binary files a/Notebooks/Figures/pr_average_bias_p98.png and b/Notebooks/Figures/pr_average_bias_p98.png differ diff --git a/Notebooks/Figures/pr_average_bias_seasonal.png b/Notebooks/Figures/pr_average_bias_seasonal.png index 6f9bccbaeab009ad09e0ac9676dbdf6e75726843..c0301ffb0e292ce7cea0cbe82f192577b8b8cbdc 100644 Binary files a/Notebooks/Figures/pr_average_bias_seasonal.png and b/Notebooks/Figures/pr_average_bias_seasonal.png differ diff --git a/Notebooks/Figures/pr_average_bias_seasonal_ex.png b/Notebooks/Figures/pr_average_bias_seasonal_ex.png index 623758a81477f6db21c14191e137c800c7453901..33274b62674ffddac36e802982e6c6e6b33b484e 100644 Binary files a/Notebooks/Figures/pr_average_bias_seasonal_ex.png and b/Notebooks/Figures/pr_average_bias_seasonal_ex.png differ diff --git a/Notebooks/Figures/pr_bias_wet_days.png b/Notebooks/Figures/pr_bias_wet_days.png index d440737d9e35f3c053986e9160ed1b899ca47e19..cce0d4a8f423e431d075d074409f578889122206 100644 Binary files a/Notebooks/Figures/pr_bias_wet_days.png and b/Notebooks/Figures/pr_bias_wet_days.png differ diff --git a/Notebooks/Figures/pr_bias_wet_days_p.png b/Notebooks/Figures/pr_bias_wet_days_p.png index d9d92768028899bcf93a911ee797e5de2c43a7b3..49b3b8455022912535b8df697273c94ec64ebb43 100644 Binary files a/Notebooks/Figures/pr_bias_wet_days_p.png and b/Notebooks/Figures/pr_bias_wet_days_p.png differ diff --git a/Notebooks/Figures/pr_distribution.png b/Notebooks/Figures/pr_distribution.png index 16bebc879e7c905559c876f36c14aef4ab79b4d9..c0e82110ebf9744fc5e4b3f56010d3000039c143 100644 Binary files a/Notebooks/Figures/pr_distribution.png and b/Notebooks/Figures/pr_distribution.png differ diff --git a/Notebooks/Figures/pr_r2.png b/Notebooks/Figures/pr_r2.png index c53560b67b8f4b23393cb629ecc5ad700b90a29a..afe8e83daab47b571515f3a66129f67bfb41e6ef 100644 Binary files a/Notebooks/Figures/pr_r2.png and b/Notebooks/Figures/pr_r2.png differ diff --git a/Notebooks/Figures/pr_rbias_ERA_vs_model.png b/Notebooks/Figures/pr_rbias_ERA_vs_model.png index 1e1d862f30f996697ce015d90b749eced48ac021..846e3ee7c4c9526b5a91914b298994a46072e555 100644 Binary files a/Notebooks/Figures/pr_rbias_ERA_vs_model.png and b/Notebooks/Figures/pr_rbias_ERA_vs_model.png differ diff --git a/Notebooks/Figures/tasmin_average_bias.png b/Notebooks/Figures/tasmin_average_bias.png index 5992128610af6a48020688efda2db8b3de06dc1b..457180f301641a28503c3351ed83f0a1127bd556 100644 Binary files a/Notebooks/Figures/tasmin_average_bias.png and b/Notebooks/Figures/tasmin_average_bias.png differ diff --git 
a/Notebooks/Figures/tasmin_average_bias_seasonal.png b/Notebooks/Figures/tasmin_average_bias_seasonal.png index 409f8619b787f46f453e09a84793c000f308719b..6b149b124fab7ab8923f78c9b28311d933c88063 100644 Binary files a/Notebooks/Figures/tasmin_average_bias_seasonal.png and b/Notebooks/Figures/tasmin_average_bias_seasonal.png differ diff --git a/Notebooks/Figures/tasmin_r2.png b/Notebooks/Figures/tasmin_r2.png index 5e6d822120b7de648334bb422cbcd2cdaec64105..397cf48d0de2fcd17f9a12aa5ea54c4c2c93936f 100644 Binary files a/Notebooks/Figures/tasmin_r2.png and b/Notebooks/Figures/tasmin_r2.png differ diff --git a/Notebooks/check_predictors.ipynb b/Notebooks/check_predictors.ipynb index dd442c4a0a7aefb114e7c77240247f0a6c1f6cdb..1c23caa16b5cd7e359f2e0e62bca699f8498aec1 100644 --- a/Notebooks/check_predictors.ipynb +++ b/Notebooks/check_predictors.ipynb @@ -17,8 +17,10 @@ "source": [ "# builtins\n", "import pathlib\n", + "import warnings\n", "\n", "# externals\n", + "import torch\n", "import numpy as np\n", "import pandas as pd\n", "import xarray as xr\n", @@ -63,11 +65,13 @@ "outputs": [], "source": [ "# define the predictor variables you want to use\n", - "ERA5_PREDICTORS = ['geopotential', 'temperature', 'mean_sea_level_pressure'] # use geopotential, temperature and pressure\n", + "# ERA5_PREDICTORS = ['geopotential', 'temperature', 'mean_sea_level_pressure'] # use geopotential, temperature and pressure\n", "\n", "# you can change this list as you wish, e.g.:\n", "# ERA5_PREDICTORS = ['geopotential', 'temperature'] # use only geopotential and temperature\n", "# ERA5_PREDICTORS = ERA5_VARIABLES # use all ERA5 variables as predictors\n", + "ERA5_PREDICTORS = ['geopotential', 'temperature', 'surface_pressure',\n", + " 'u_component_of_wind', 'v_component_of_wind', 'specific_humidity']\n", "\n", "# this checks if the variable names are correct\n", "assert all([p in ERA5_VARIABLES for p in ERA5_PREDICTORS]) " @@ -81,6 +85,16 @@ "### Use the climax package to load ERA5 predictors" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "3432b46f-29a9-4333-af13-05dc0643ce63", + "metadata": {}, + "outputs": [], + "source": [ + "CHUNKS = {'time': 16, 'x': 16, 'y': 16}" + ] + }, { "cell_type": "code", "execution_count": null, @@ -101,20 +115,48 @@ "source": [ "# create the xarray.Dataset of the specified predictor variables\n", "predictors = ERA5Dataset(ERA5_PATH, ERA5_PREDICTORS, plevels=PLEVELS)\n", - "predictors = predictors.merge(chunks=-1)" + "predictors = predictors.merge(chunks=CHUNKS)" ] }, { "cell_type": "code", "execution_count": null, - "id": "ea5c1c43-fbbf-4db1-926b-1102e8079348", + "id": "f4c6a816-a099-4c53-8d9f-79c07ca6626a", + "metadata": {}, + "outputs": [], + "source": [ + "# add the day of the year\n", + "predictors = predictors.assign(ERA5Dataset.encode_doys(predictors, chunks=predictors.chunks))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d453b92-818f-4d6c-bc61-bdc5869b49ea", "metadata": {}, "outputs": [], "source": [ - "# check out the xarray.Dataset: you will see all the variables you specified\n", "predictors" ] }, + { + "cell_type": "markdown", + "id": "9e804201-1fd4-44a7-b63b-6e4d55b77e65", + "metadata": {}, + "source": [ + "## Normalize predictors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "662e2639-2553-4711-b103-e6a94d4ab07c", + "metadata": {}, + "outputs": [], + "source": [ + "predictors = ERA5Dataset.normalize(predictors, period=CALIB_PERIOD)" + ] + }, { "cell_type": "markdown", "id": 
"14668e55-3b47-44ea-b3f7-596b0e62eec6", @@ -144,7 +186,8 @@ "dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()\n", "\n", "# read elevation and compute slope and aspect\n", - "dem = ERA5Dataset.dem_features(dem, {'y': predictors.y, 'x': predictors.x}, add_coord={'time': predictors.time})" + "dem = ERA5Dataset.dem_features(dem, {'y': predictors.y, 'x': predictors.x}, add_coord={'time': predictors.time})\n", + "dem = dem.drop_vars(['slope', 'aspect']).chunk(predictors.chunks)" ] }, { @@ -157,6 +200,16 @@ "dem" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a545ff3-ab4f-4e6f-8990-af2c4e63be30", + "metadata": {}, + "outputs": [], + "source": [ + "dem = ERA5Dataset.normalize(dem, dim=('y', 'x'))" + ] + }, { "cell_type": "markdown", "id": "d55ca9b1-46a9-462a-ba42-fb7a38c823fa", @@ -175,6 +228,62 @@ "Era5_ds = xr.merge([predictors, dem])" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "01557918-05e0-4a24-991a-daf6d6bf34de", + "metadata": {}, + "outputs": [], + "source": [ + "Era5_ds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4df9a0ec-8689-4efb-a753-938f70cf548a", + "metadata": {}, + "outputs": [], + "source": [ + "# select calibration period and load data\n", + "%time calib = Era5_ds.sel(time=CALIB_PERIOD).compute()" + ] + }, + { + "cell_type": "markdown", + "id": "86f24b93-07df-4f8c-8c9f-e6cbb8213d0d", + "metadata": {}, + "source": [ + "## Compute standardized anomalies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4da51abd-8786-47da-9b14-1a6a036172d3", + "metadata": {}, + "outputs": [], + "source": [ + "groups = predictors.groupby('time.dayofyear').groups\n", + "anomalies = {}\n", + "for doy, days in groups.items():\n", + " with warnings.catch_warnings():\n", + " warnings.simplefilter('ignore', category=RuntimeWarning)\n", + " anomalies[doy] = (predictors.isel(time=days) - predictors.isel(time=days).mean(dim='time')) / predictors.isel(time=days).std(dim='time')\n", + "anomalies = xr.concat(anomalies.values(), dim='time')\n", + "anomalies = anomalies.sortby(anomalies.time)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c500c934-25dd-4dcc-bac3-1aef4ac27181", + "metadata": {}, + "outputs": [], + "source": [ + "anomalies" + ] + }, { "cell_type": "markdown", "id": "dd584313-8913-45b5-a379-2724d42bc99f", @@ -190,7 +299,7 @@ "metadata": {}, "outputs": [], "source": [ - "PREDICTAND = 'pr'" + "PREDICTAND = 'tasmin'" ] }, { diff --git a/Notebooks/debug_loss.ipynb b/Notebooks/debug_loss.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..6e51b4574b55dd4afd570f343f1bd14ec13a5900 --- /dev/null +++ b/Notebooks/debug_loss.ipynb @@ -0,0 +1,298 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5b9c57b3-4748-4016-8990-840475c44392", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"Dynamical climate downscaling using deep convolutional neural networks.\"\"\"\n", + "\n", + "# !/usr/bin/env python\n", + "# -*- coding: utf-8 -*-\n", + "\n", + "# builtins\n", + "import sys\n", + "import time\n", + "import logging\n", + "from datetime import timedelta\n", + "from logging.config import dictConfig\n", + "\n", + "# externals\n", + "import torch\n", + "import xarray as xr\n", + "from sklearn.model_selection import train_test_split\n", + "from torch.utils.data import DataLoader\n", + "\n", + "# locals\n", + "from pysegcnn.core.utils import search_files\n", + "from pysegcnn.core.trainer import NetworkTrainer, LogConfig\n", + "from 
pysegcnn.core.models import Network\n", + "from pysegcnn.core.logging import log_conf\n", + "from climax.core.dataset import ERA5Dataset, NetCDFDataset\n", + "from climax.core.loss import MSELoss, L1Loss, BernoulliWeibullLoss\n", + "from climax.main.config import (ERA5_PLEVELS, ERA5_PREDICTORS, PREDICTAND,\n", + " CALIB_PERIOD, DOY, SHUFFLE, BATCH_SIZE, LR,\n", + " LAMBDA, NORM, TRAIN_CONFIG, NET, LOSS, FILTERS,\n", + " OVERWRITE, DEM, DEM_FEATURES, STRATIFY,\n", + " WET_DAY_THRESHOLD, VALID_SIZE)\n", + "from climax.main.io import ERA5_PATH, OBS_PATH, DEM_PATH, MODEL_PATH\n", + "\n", + "# module level logger\n", + "LOGGER = logging.getLogger(__name__)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "40c70d2f-2e31-4ea8-adfa-5009b950e1e0", + "metadata": {}, + "outputs": [], + "source": [ + "# initialize timing\n", + "start_time = time.monotonic()\n", + "\n", + "# initialize network filename\n", + "state_file = ERA5Dataset.state_file(\n", + " NET, PREDICTAND, ERA5_PREDICTORS, ERA5_PLEVELS, dem=DEM,\n", + " dem_features=DEM_FEATURES, doy=DOY, loss=LOSS)\n", + "\n", + "# path to model state\n", + "state_file = MODEL_PATH.joinpath(PREDICTAND, state_file)\n", + "\n", + "# initialize logging\n", + "log_file = MODEL_PATH.joinpath(PREDICTAND,\n", + " state_file.name.replace('.pt', '_log.txt'))\n", + "if log_file.exists():\n", + " log_file.unlink()\n", + "dictConfig(log_conf(log_file))\n", + "\n", + "# initialize downscaling\n", + "LogConfig.init_log('Initializing downscaling for period: {}'.format(\n", + " ' - '.join([str(CALIB_PERIOD[0]), str(CALIB_PERIOD[-1])])))\n", + "\n", + "# check if model exists\n", + "if state_file.exists() and not OVERWRITE:\n", + " # load pretrained network\n", + " net, _ = Network.load_pretrained_model(state_file, NET)\n", + " sys.exit()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "582a53c1-78ff-4c56-9caf-3b1dd40ce1b3", + "metadata": {}, + "outputs": [], + "source": [ + "# initialize ERA5 predictor dataset\n", + "LogConfig.init_log('Initializing ERA5 predictors.')\n", + "Era5 = ERA5Dataset(ERA5_PATH.joinpath('ERA5'), ERA5_PREDICTORS,\n", + " plevels=ERA5_PLEVELS)\n", + "Era5_ds = Era5.merge(chunks=-1)\n", + "\n", + "# initialize OBS predictand dataset\n", + "LogConfig.init_log('Initializing observations for predictand: {}'\n", + " .format(PREDICTAND))\n", + "\n", + "# check whether to jointly train tasmin and tasmax\n", + "if PREDICTAND == 'tas':\n", + " # read both tasmax and tasmin\n", + " tasmax = xr.open_dataset(\n", + " search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())\n", + " tasmin = xr.open_dataset(\n", + " search_files(OBS_PATH.joinpath('tasmin'), '.nc$').pop())\n", + " Obs_ds = xr.merge([tasmax, tasmin])\n", + "else:\n", + " # read in-situ gridded observations\n", + " Obs_ds = search_files(OBS_PATH.joinpath(PREDICTAND), 'OBS_pr(.*).nc$').pop()\n", + " Obs_ds = xr.open_dataset(Obs_ds)\n", + "\n", + "# whether to use digital elevation model\n", + "if DEM:\n", + " # digital elevation model: Copernicus EU-Dem v1.1\n", + " dem = search_files(DEM_PATH, '^eu_dem_v11_stt.nc$').pop()\n", + "\n", + " # read elevation and compute slope and aspect\n", + " dem = ERA5Dataset.dem_features(\n", + " dem, {'y': Era5_ds.y, 'x': Era5_ds.x},\n", + " add_coord={'time': Era5_ds.time})\n", + "\n", + " # check whether to use slope and aspect\n", + " if not DEM_FEATURES:\n", + " dem = dem.drop_vars(['slope', 'aspect'])\n", + "\n", + " # add dem to set of predictor variables\n", + " Era5_ds = xr.merge([Era5_ds, dem])" ] }, + 
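
For orientation, `ERA5Dataset.dem_features` above interpolates the DEM to the predictor grid and, via `add_coord`, broadcasts the static field along the time axis so it can be stacked with the time-varying predictors. A minimal sketch of that broadcast-and-merge pattern in plain xarray, on entirely hypothetical data (the project's own logic lives in `climax.core.dataset`):

```python
# Sketch of merging a static DEM with time-varying predictors (hypothetical
# data; the project's actual helper is ERA5Dataset.dem_features).
import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range('1981-01-01', periods=10, freq='D')
y, x = np.arange(5), np.arange(6)

# time-varying predictor field on the target grid
t850 = xr.DataArray(np.random.rand(time.size, y.size, x.size),
                    coords={'time': time, 'y': y, 'x': x},
                    dims=('time', 'y', 'x'), name='t850')

# static elevation field on the same (y, x) grid
dem = xr.DataArray(3000 * np.random.rand(y.size, x.size),
                   coords={'y': y, 'x': x}, dims=('y', 'x'), name='elevation')

# broadcast the DEM along time so it aligns with the predictor cube
dem = dem.expand_dims(time=time)
Era5_ds = xr.merge([t850, dem])
print(Era5_ds)
```
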
{ + "cell_type": "code", + "execution_count": null, + "id": "54afbbab-728a-449a-9856-51e55352fafc", + "metadata": {}, + "outputs": [], + "source": [ + "# initialize training data\n", + "LogConfig.init_log('Initializing training data.')\n", + "\n", + "# split calibration period into training and validation period\n", + "if PREDICTAND == 'pr' and STRATIFY:\n", + " # stratify training and validation dataset by number of\n", + " # observed wet days for precipitation\n", + " wet_days = (Obs_ds.sel(time=CALIB_PERIOD).mean(dim=('y', 'x'))\n", + " >= WET_DAY_THRESHOLD).to_array().values.squeeze()\n", + " train, valid = train_test_split(\n", + " CALIB_PERIOD, stratify=wet_days, test_size=VALID_SIZE)\n", + "\n", + " # sort chronologically\n", + " train, valid = sorted(train), sorted(valid)\n", + "else:\n", + " train, valid = train_test_split(CALIB_PERIOD, shuffle=False,\n", + " test_size=VALID_SIZE)\n", + "\n", + "# training and validation dataset\n", + "Era5_train, Obs_train = Era5_ds.sel(time=train), Obs_ds.sel(time=train)\n", + "Era5_valid, Obs_valid = Era5_ds.sel(time=valid), Obs_ds.sel(time=valid)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6b861a1-3b2a-49b8-8019-48fce4a42664", + "metadata": {}, + "outputs": [], + "source": [ + "# create PyTorch compliant dataset and dataloader instances for model\n", + "# training\n", + "train_ds = NetCDFDataset(Era5_train, Obs_train, normalize=NORM, doy=DOY)\n", + "# valid_ds = NetCDFDataset(Era5_valid, Obs_valid, normalize=NORM, doy=DOY)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79537498-bed8-4d2e-a3d4-363f943c57cf", + "metadata": {}, + "outputs": [], + "source": [ + "train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,\n", + " drop_last=False)\n", + "# valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE,\n", + "# drop_last=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4e87b23-9a12-492c-9113-d44b404ac9d7", + "metadata": {}, + "outputs": [], + "source": [ + "PREDICTAND = 'pr'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "192e1128-e02d-4ca1-8df4-bc4be57d3acc", + "metadata": {}, + "outputs": [], + "source": [ + "# initialize network and optimizer\n", + "LogConfig.init_log('Initializing network and optimizer.')\n", + "\n", + "# define number of output fields\n", + "# check whether modelling pr with probabilistic approach\n", + "outputs = len(Obs_ds.data_vars)\n", + "if PREDICTAND == 'pr':\n", + " outputs = (1 if (isinstance(LOSS, MSELoss) or isinstance(LOSS, L1Loss))\n", + " else 3)\n", + "\n", + "# instanciate network\n", + "inputs = len(Era5_ds.data_vars) + 2 if DOY else len(Era5_ds.data_vars)\n", + "net = NET(state_file, inputs, outputs, filters=FILTERS)\n", + "\n", + "# initialize optimizer\n", + "optimizer = torch.optim.Adam(net.parameters(), lr=LR,\n", + " weight_decay=LAMBDA)\n", + "# optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9,\n", + "# weight_decay=LAMBDA)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be030225-7d90-484b-81c8-325578c8648b", + "metadata": {}, + "outputs": [], + "source": [ + "LOSS = BernoulliWeibullLoss(min_amount=1, reduction='mean')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c84b9ef-158f-421a-8abd-d241b5dca35b", + "metadata": {}, + "outputs": [], + "source": [ + "# forward pass\n", + "torch.autograd.set_detect_anomaly(True)\n", + "for inputs, labels in train_dl:\n", + " outputs = net(inputs)\n", + " 
loss = LOSS(outputs, labels)\n", + " print(loss)\n", + " loss.backward()\n", + " optimizer.step()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d42ffa8d-e172-4fce-950c-2df7a371bc4c", + "metadata": {}, + "outputs": [], + "source": [ + "# initialize network trainer\n", + "trainer = NetworkTrainer(net, optimizer, net.state_file, train_dl,\n", + " valid_dl, loss_function=LOSS, **TRAIN_CONFIG)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aea217fb-b9fb-4a17-8135-2fb052101f39", + "metadata": {}, + "outputs": [], + "source": [ + "# train model\n", + "# check for anomalies when training\n", + "torch.autograd.set_detect_anomaly(True)\n", + "state = trainer.train()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Notebooks/eval_precipitation.ipynb b/Notebooks/eval_precipitation.ipynb index 40f103a120f92c99c415b1e1a0799083eb83e4b8..e1f85e50510435e9d9f838888664025d2dc7682e 100644 --- a/Notebooks/eval_precipitation.ipynb +++ b/Notebooks/eval_precipitation.ipynb @@ -29,7 +29,8 @@ "- Specific humidity (q)\n", "\n", "**Predictors on surface**:\n", - "- Mean sea level pressure (msl)\n", + "- Surface pressure (p)\n", + "- Total precipitation (pr)\n", "\n", "**Auxiliary predictors**:\n", "- Elevation from Copernicus EU-DEM v1.1 (dem)\n", @@ -59,13 +60,52 @@ "PLEVELS = ['500', '850']\n", "# PLEVELS = []\n", "SPREDICTORS = 'p'\n", + "# SPREDICTORS = 'ppr'\n", "DEM = 'dem'\n", "DEM_FEATURES = ''\n", - "DOY = ''\n", + "DOY = 'doy'\n", "WET_DAY_THRESHOLD = '1'\n", "# LOSS = 'MSELoss'\n", "LOSS = 'BernoulliGammaLoss'\n", - "SEASON = 'season'" + "# LOSS = 'BernoulliWeibullLoss'\n", + "OPTIM = 'SGD'\n", + "# OPTIM = 'Adam'\n", + "SEASON = ''\n", + "DECAY = 1e-02\n", + "LR = 5e-03" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f4b5c0e-670f-40ba-9e38-00c38e7e7855", + "metadata": {}, + "outputs": [], + "source": [ + "# construct file pattern to match\n", + "PATTERN = '_'.join([MODEL, PREDICTAND, PPREDICTORS, *PLEVELS])\n", + "PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN\n", + "PATTERN = '_'.join([PATTERN, '{}mm'.format(str(WET_DAY_THRESHOLD).replace('.', ''))]) if WET_DAY_THRESHOLD else PATTERN\n", + "PATTERN = '_'.join([PATTERN, LOSS])\n", + "PATTERN = '_'.join([PATTERN, OPTIM])\n", + "PATTERN = '_'.join([PATTERN, SEASON]) if SEASON else PATTERN\n", + "PATTERN = '_'.join([PATTERN, 'd{:.0e}'.format(DECAY)])\n", + "PATTERN = '_'.join([PATTERN, 'lr{:.0e}'.format(LR)])\n", + "PATTERN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e38fabfa-60ca-4eb3-80da-dc646404fb0f", + "metadata": {}, + "outputs": [], + "source": [ + "# whether to search for the defined pattern\n", + "SEARCH = False" ] }, { @@ -73,7 +113,7 @@ "id": "dd188df0-69ee-44b0-82b2-d212994dc271", "metadata": {}, "source": [ - "### Imports" + "## Imports" ] }, { @@ -89,18 +129,22 @@ "import calendar\n", "\n", "# externals\n", + "import torch\n", 
"import xarray as xr\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "from matplotlib import cm\n", + "from matplotlib.colors import ListedColormap\n", "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n", "import scipy.stats as stats\n", "from IPython.display import Image\n", - "from sklearn.metrics import r2_score, roc_curve, auc, classification_report\n", + "from sklearn.metrics import r2_score, roc_curve, auc\n", "\n", "# locals\n", - "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH\n", - "from pysegcnn.core.utils import search_files\n", - "from pysegcnn.core.graphics import plot_classification_report" + "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH, MODEL_PATH\n", + "from climax.core.utils import plot_loss\n", + "from climax.core.dataset import ERA5Dataset\n", + "from pysegcnn.core.utils import search_files" ] }, { @@ -116,7 +160,7 @@ }, { "cell_type": "markdown", - "id": "a3d92474-2bc1-4035-8938-c5fbb07ae891", + "id": "9ca58029-ed0a-4aab-ba67-b0d5d96a12c2", "metadata": {}, "source": [ "### Model architecture" @@ -125,61 +169,13 @@ { "cell_type": "code", "execution_count": null, - "id": "c68e022b-41c2-438b-bfb0-e27eddee89bf", + "id": "9b083287-34d0-4e31-aeb0-4c0005cf3e83", "metadata": {}, "outputs": [], "source": [ "Image(\"./Figures/architecture.png\", width=900, height=400)" ] }, - { - "cell_type": "markdown", - "id": "c8833efe-c715-4872-9aee-a0b5766f5c67", - "metadata": {}, - "source": [ - "### Loss function" - ] - }, - { - "cell_type": "markdown", - "id": "bb801367-6872-4a0b-bff5-70cb6746e057", - "metadata": {}, - "source": [ - "For precipitation, the network is optimizing the negative log-likelihood of a Bernoulli-Gamma distribution after [Cannon (2008)](http://journals.ametsoc.org/doi/10.1175/2008JHM960.1)." 
- ] - }, - { - "cell_type": "markdown", - "id": "a8775d5e-5ad4-47e4-8fef-6dc230e15dee", - "metadata": {}, - "source": [ - "Bernoulli-Gamma distribution:" - ] - }, - { - "cell_type": "markdown", - "id": "ab10f8de-d8d2-4427-b9c8-5d68803543c3", - "metadata": {}, - "source": [ - "$$P(y \\mid, p, \\alpha, \\beta) = \\begin{cases} 1 - p, & \\text{for } y = 0\\\\ p \\cdot \\frac{y^{\\alpha -1} \\exp(-y/\\beta)}{\\beta^{\\alpha} \\tau(\\alpha)}, & \\text{for } y > 0\\end{cases}$$" - ] - }, - { - "cell_type": "markdown", - "id": "6b6dbd06-1f0e-4c52-84f2-b7ff31c75726", - "metadata": {}, - "source": [ - "Log-likelihood function:" - ] - }, - { - "cell_type": "markdown", - "id": "e41c7b39-f352-4a98-820f-9a7345b3283c", - "metadata": {}, - "source": [ - "$$\\mathcal{J}(p, \\alpha, \\beta \\mid y) = \\underbrace{(1 - P(y > 0)) \\log(1 - p)}_{\\text{Bernoulli}} + \\underbrace{P(y > 0) \\cdot \\left(\\log(p) + (\\alpha - 1) \\log(y) - \\frac{y}{\\beta} - \\alpha \\log(\\beta) - \\log(\\tau(\\alpha))\\right)}_{\\text{Gamma}}$$" - ] - }, { "cell_type": "markdown", "id": "5a0c55f0-79fb-4501-b3cf-b5414399a3d9", @@ -191,32 +187,33 @@ { "cell_type": "code", "execution_count": null, - "id": "efa76f8e-c089-47ff-a001-d4c2a11c4d6d", + "id": "ecaba394-f802-481f-8274-c44e2f5fdf1a", "metadata": {}, "outputs": [], "source": [ - "# construct file pattern to match\n", - "PATTERN = '_'.join([MODEL, PREDICTAND, PPREDICTORS, *PLEVELS])\n", - "PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN\n", - "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n", - "PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN\n", - "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN\n", - "PATTERN = '_'.join([PATTERN, '{}mm'.format(str(WET_DAY_THRESHOLD).replace('.', ''))]) if WET_DAY_THRESHOLD else PATTERN\n", - "PATTERN = '_'.join([PATTERN, LOSS])\n", - "PATTERN = '_'.join([PATTERN, SEASON]) if SEASON else PATTERN\n", - "PATTERN" + "# digital elevation model\n", + "dem = xr.open_dataset(search_files(DEM_PATH, 'eu_dem_v11_stt.nc').pop())\n", + "dem = dem.Band1.to_dataset().rename({'Band1': 'elevation'})" ] }, { "cell_type": "code", "execution_count": null, - "id": "ecaba394-f802-481f-8274-c44e2f5fdf1a", + "id": "859ec647-6010-4526-bd96-edb22e336c65", "metadata": {}, "outputs": [], "source": [ - "# digital elevation model\n", - "dem = xr.open_dataset(search_files(DEM_PATH, 'eu_dem_v11_stt.nc').pop())\n", - "dem = dem.Band1.to_dataset().rename({'Band1': 'elevation'})" + "# model predictions\n", + "if SEARCH:\n", + " y_pred = xr.open_dataset(search_files(TARGET_PATH.joinpath(PREDICTAND), '.'.join([PATTERN, 'nc$'])).pop())\n", + "else:\n", + " filename = 'USegNet_tasmin_ztuvq_500_850_p_dem_doy_L1Loss_Adam_d1e-03_lr3e-04.nc'\n", + " file = search_files(TARGET_PATH, filename).pop()\n", + " y_pred = xr.open_dataset(file)\n", + "try:\n", + " y_pred = y_pred.rename({'pr': 'precipitation'})\n", + "except ValueError:\n", + " pass" ] }, { @@ -228,13 +225,8 @@ }, "outputs": [], "source": [ - "# model predictions and observations NetCDF \n", - "y_pred = xr.open_dataset(search_files(TARGET_PATH.joinpath(PREDICTAND), '.'.join([PATTERN, 'nc$'])).pop())\n", - "y_true = xr.open_dataset(search_files(OBS_PATH.joinpath(PREDICTAND), 'OBS_pr(.*).nc$').pop())\n", - "try:\n", - " y_pred = y_pred.rename({'pr': 'precipitation'})\n", - "except ValueError:\n", - " pass" + "# observations\n", + "y_true = xr.open_dataset(search_files(OBS_PATH.joinpath(PREDICTAND), 'OBS_pr(.*).nc$').pop())" ] }, { @@ -268,7 +260,7 @@ 
"metadata": {}, "outputs": [], "source": [ - "# align datasets\n", + "# aligndata_varsasets\n", "if len(y_pred.data_vars) > 1:\n", " y_true, y_refe, y_pred_pr, y_pred_prob = xr.align(y_true.precipitation, y_refe.precipitation, y_pred.precipitation, y_pred.prob, join='override')\n", "else:\n", @@ -304,67 +296,112 @@ }, { "cell_type": "markdown", - "id": "b269a131-cf5b-4c6c-9f8e-a5408250aa83", + "id": "6232ec70-f595-401e-bddd-015eeb606ab0", "metadata": {}, "source": [ - "## Model validation: precipitation amount" + "## Model convergence" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e3ee22c-7ec1-44c2-9152-c10726621848", + "metadata": {}, + "outputs": [], + "source": [ + "# plot model state\n", + "model_state = MODEL_PATH.joinpath(PREDICTAND, '.'.join([PATTERN, 'pt']))\n", + "try:\n", + " fig = plot_loss(model_state, step=5)\n", + "except FileNotFoundError:\n", + " pass" ] }, { "cell_type": "markdown", - "id": "0fa1fe82-0d6e-4676-b5b8-9eac2fd28ffb", + "id": "b269a131-cf5b-4c6c-9f8e-a5408250aa83", "metadata": {}, "source": [ - "### Coefficient of determination: monthly mean" + "## Model validation: precipitation amount" ] }, { "cell_type": "code", "execution_count": null, - "id": "4d11cee6-1ebd-4424-8ec6-92e5a196bac4", + "id": "b73658af-d994-40a9-a41c-5b434e45504a", "metadata": {}, "outputs": [], "source": [ - "# calculate monthly mean precipitation (mm / month)\n", - "y_pred_values = y_pred_pr.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').values\n", - "y_true_values = y_true.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim='time').values" + "# calculate monthly precipitation (mm / month)\n", + "y_pred_m = y_pred_pr.resample(time='1M').sum(skipna=False)\n", + "y_true_m = y_true.resample(time='1M').sum(skipna=False)\n", + "y_refe_m = y_refe.resample(time='1M').sum(skipna=False)" ] }, { "cell_type": "code", "execution_count": null, - "id": "6e55cdd5-8adb-40c5-9742-46f4fc3d4be9", + "id": "93bd8fd9-77c7-40ed-bd31-0d9a3ce86a12", "metadata": {}, "outputs": [], "source": [ - "# apply mask of valid pixels\n", - "mask = (~np.isnan(y_pred_values) & ~np.isnan(y_true_values))\n", - "y_pred_values = y_pred_values[mask]\n", - "y_true_values = y_true_values[mask]" + "# calculate mean annual cycle\n", + "y_pred_ac = y_pred_m.groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True).values.squeeze()\n", + "y_true_ac = y_true_m.groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True).values.squeeze()" ] }, { "cell_type": "code", "execution_count": null, - "id": "13a9ff21-34ea-4db1-9c0a-bcb7ea1001f7", + "id": "8ed529fb-ec3f-4c40-958a-4a7eaec8b25c", "metadata": {}, "outputs": [], "source": [ - "# calculate coefficient of determination\n", - "r2 = r2_score(y_true_values, y_pred_values)\n", - "r2" + "# compute daily anomalies\n", + "y_pred_anom = ERA5Dataset.anomalies(y_pred_pr, timescale='time.month')\n", + "y_true_anom = ERA5Dataset.anomalies(y_true, timescale='time.month')" + ] + }, + { + "cell_type": "markdown", + "id": "0fa1fe82-0d6e-4676-b5b8-9eac2fd28ffb", + "metadata": {}, + "source": [ + "### Coefficient of determination" ] }, { "cell_type": "code", "execution_count": null, - "id": "a1b33431-9f2b-4e42-bbb0-f4fa258edd98", + "id": "ea73bb99-886c-40e9-913e-7ec25418013e", "metadata": {}, "outputs": [], "source": [ - "# group timeseries by month and calculate mean over time and space\n", - "y_pred_ac = y_pred_pr.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), 
skipna=True).values.squeeze()\n", - "y_true_ac = y_true.resample(time='1M').sum(skipna=False).groupby('time.month').mean(dim=('y', 'x', 'time'), skipna=True).values.squeeze()" + "# get predicted and observed daily anomalies\n", + "y_pred_av = y_pred_anom.values.flatten()\n", + "y_true_av = y_true_anom.values.flatten()\n", + "\n", + "# apply mask of valid pixels\n", + "mask = (~np.isnan(y_pred_av) & ~np.isnan(y_true_av))\n", + "y_pred_av = y_pred_av[mask]\n", + "y_true_av = y_true_av[mask]\n", + "\n", + "# get predicted and observed monthly means\n", + "y_pred_mv = y_pred_m.values.flatten()\n", + "y_true_mv = y_true_m.values.flatten()\n", + "\n", + "# apply mask of valid pixels\n", + "mask = (~np.isnan(y_pred_mv) & ~np.isnan(y_true_mv))\n", + "y_pred_mv = y_pred_mv[mask]\n", + "y_true_mv = y_true_mv[mask]\n", + "\n", + "# calculate coefficient of determination on monthly means\n", + "r2_mm = r2_score(y_true_mv, y_pred_mv)\n", + "print('R2 on monthly means: {:.2f}'.format(r2_mm))\n", + "\n", + "# calculate coefficient of determination on daily anomalies\n", + "r2_anom = r2_score(y_true_av, y_pred_av)\n", + "print('R2 on daily anomalies: {:.2f}'.format(r2_anom))" ] }, { @@ -382,14 +419,15 @@ "# ax.plot(y_true_values[subset], y_pred_values[subset], 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n", "\n", "# plot entire dataset\n", - "ax.plot(y_true_values, y_pred_values, 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n", + "ax.plot(y_true_mv, y_pred_mv, 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n", "\n", "# plot 1:1 mapping line\n", - "interval = np.arange(0, 300, 50)\n", + "interval = np.arange(0, 750, 50)\n", "ax.plot(interval, interval, color='k', lw=2, ls='--')\n", "\n", - "# add coefficient of determination: calculated on entire dataset!\n", - "ax.text(interval[-1] - 2, interval[0] + 2, s='Coefficient of determination R$^2$ = {:.2f}'.format(r2), ha='right', fontsize=18)\n", + "# add coefficients of determination\n", + "ax.text(interval[-1] - 0.5, interval[0] + 2, s='R$^2$ (monthly means)= {:.2f}'.format(r2_mm), ha='right', fontsize=18)\n", + "ax.text(interval[-1] - 0.5, interval[0] + 27.5, s='R$^2$ (daily anomalies) = {:.2f}'.format(r2_anom), ha='right', fontsize=18)\n", "\n", "# format axes\n", "ax.set_ylim(interval[0], interval[-1])\n", @@ -403,7 +441,7 @@ "ax.set_title('Monthly mean {} (mm / month)'.format(NAMES[PREDICTAND]), fontsize=20, pad=10);\n", "\n", "# add axis for annual cycle\n", - "axins = inset_axes(ax, width=\"30%\", height=\"40%\", loc=2, borderpad=0.25)\n", + "axins = inset_axes(ax, width=\"30%\", height=\"40%\", loc=2, borderpad=0.5)\n", "axins.plot(y_pred_ac, ls='--', color='k', label='Predicted')\n", "axins.plot(y_true_ac, ls='-', color='k', label='Observed')\n", "axins.legend(frameon=False, fontsize=12, loc='lower center');\n", @@ -414,7 +452,91 @@ "axins.set_xticklabels([calendar.month_name[i + 1] for i in np.arange(0, 12)], rotation=90, fontsize=12)\n", "\n", "# save figure\n", - "fig.savefig('../Notebooks/Figures/{}_r2.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" + "fig.savefig('../Notebooks/Figures/{}_r2_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "id": "0f0b28e8-742d-4342-a01b-04fc5f736026", + "metadata": {}, + "source": [ + "### Coefficient of determination: Spatially" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "db5a5180-e1fb-4be6-8970-532064fc9298", + "metadata": 
{}, + "outputs": [], + "source": [ + "# iterate over the grid points\n", + "r2 = np.ones((2, len(y_pred_m.x), len(y_pred_m.y)), dtype=np.float32) * np.nan\n", + "for i, _ in enumerate(y_pred_m.x):\n", + " for j, _ in enumerate(y_pred_m.y):\n", + " # get observed and predicted monthly precipitation for current grid point\n", + " point_true = y_true_m.isel(x=i, y=j)\n", + " point_pred = y_pred_m.isel(x=i, y=j)\n", + " \n", + " # remove missing values\n", + " mask = ((~np.isnan(point_true)) & (~np.isnan(point_pred)))\n", + " point_true = point_true[mask].values\n", + " point_pred = point_pred[mask].values\n", + " if point_true.size < 1:\n", + " continue\n", + " \n", + " # get anomalies for current grid point\n", + " point_anom_true = y_true_anom.isel(x=i, y=j)\n", + " point_anom_pred = y_pred_anom.isel(x=i, y=j)\n", + " \n", + " # remove missing values\n", + " mask_anom = ((~np.isnan(point_anom_true)) & (~np.isnan(point_anom_pred)))\n", + " point_anom_true = point_anom_true[mask_anom].values\n", + " point_anom_pred = point_anom_pred[mask_anom].values\n", + "\n", + " # compute coefficient of determination\n", + " r2[0, j, i] = r2_score(point_true, point_pred)\n", + " r2[1, j, i] = r2_score(point_anom_true, point_anom_pred)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6a23548-77b7-4acb-b9ea-e2404c1c2256", + "metadata": {}, + "outputs": [], + "source": [ + "# define color map: red to green\n", + "grn = cm.get_cmap('Greens', 128)\n", + "red = cm.get_cmap('Reds_r', 128)\n", + "red2green = ListedColormap(np.vstack((red(np.linspace(0, 1, 128)),\n", + " grn(np.linspace(0, 1, 128)))))\n", + "\n", + "# plot coefficients of determination\n", + "vmin, vmax = -1, 1\n", + "fig, ax = plt.subplots(1, 2, figsize=(10, 10))\n", + "\n", + "# monthly means\n", + "im0 = ax[0].imshow(r2[0, :], origin='lower', cmap=red2green, vmin=vmin, vmax=vmax)\n", + "ax[0].text(x=r2.shape[1] - 2, y=2, s='Average R$^2$: {:.2f}'.format(np.nanmean(r2[0, :])), fontsize=14, ha='right');\n", + "ax[0].set_axis_off()\n", + "ax[0].set_title('Monthly mean {} (mm / month)'.format(NAMES[PREDICTAND]), fontsize=14, pad=10);\n", + "\n", + "# daily anomalies\n", + "im1 = ax[1].imshow(r2[1, :], origin='lower', cmap=red2green, vmin=vmin, vmax=vmax)\n", + "ax[1].text(x=r2.shape[1] - 2, y=2, s='Average R$^2$: {:.2f}'.format(np.nanmean(r2[1, :])), fontsize=14, ha='right');\n", + "ax[1].set_axis_off()\n", + "ax[1].set_title('Daily {} anomaly (mm / day)'.format(NAMES[PREDICTAND]), fontsize=14, pad=10);\n", + "\n", + "# add colorbar \n", + "cbar_ax_bias = fig.add_axes([ax[1].get_position().x1 + 0.05, ax[1].get_position().y0,\n", + " 0.03, ax[1].get_position().y1 - ax[1].get_position().y0])\n", + "cbar_bias = fig.colorbar(im0, cax=cbar_ax_bias)\n", + "cbar_bias.set_label(label='Coefficient of determination R$^2$', fontsize=14)\n", + "cbar_bias.ax.tick_params(labelsize=14, pad=10)\n", + "\n", + "# save figure\n", + "fig.savefig('../Notebooks/Figures/{}_r2_spatial_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')" ] }, { @@ -440,7 +562,7 @@ "metadata": {}, "outputs": [], "source": [ - "# yearly average bias over reference period\n", + "# average bias of daily precipitation over reference period\n", "y_pred_yearly_avg = y_pred_pr.groupby('time.year').mean(dim='time')\n", "y_refe_yearly_avg = y_refe.groupby('time.year').mean(dim='time')\n", "y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time')\n", @@ -472,8 +594,8 @@ "outputs": [], "source": [ "# root mean squared error over reference period\n", - 
"rmse_avg = ((y_pred_yearly_avg - y_true_yearly_avg) ** 2).mean()\n", - "rmse_avg_ref = ((y_refe_yearly_avg - y_true_yearly_avg) **2).mean()\n", + "rmse_avg = np.sqrt(((y_pred_yearly_avg - y_true_yearly_avg) ** 2).mean())\n", + "rmse_avg_ref = np.sqrt(((y_refe_yearly_avg - y_true_yearly_avg) **2).mean())\n", "print('(Model) Yearly average RMSE: {:.2f} mm / day'.format(rmse_avg.item()))\n", "print('(ERA-5) Yearly average RMSE: {:.2f} mm / day'.format(rmse_avg_ref.item()))" ] @@ -490,7 +612,10 @@ " y_p = y_pred_yearly_avg.sel(year=year).values \n", " y_t = y_true_yearly_avg.sel(year=year).values\n", " r, _ = stats.pearsonr(y_p[~np.isnan(y_p)], y_t[~np.isnan(y_t)])\n", - " print('({:0d}) Pearson correlation: {:.2f}'.format(year.item(), np.asarray(r).mean()))" + " print('({:0d}) Pearson correlation: {:.2f}'.format(year.item(), np.asarray(r).mean()))\n", + "r, _ = stats.pearsonr(y_pred_yearly_avg.values[~np.isnan(y_pred_yearly_avg.values)],\n", + " y_true_yearly_avg.values[~np.isnan(y_true_yearly_avg.values)])\n", + "print('Total: {:.2f}'.format(r))" ] }, { @@ -500,18 +625,18 @@ "metadata": {}, "outputs": [], "source": [ - "# plot yearly average MAE of reference vs. prediction\n", - "vmin, vmax = 0, 5\n", + "# plot yearly average bias of reference vs. prediction\n", + "vmin, vmax = -40, 40\n", "fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)\n", "\n", "# plot bias of ERA-5 reference\n", "reference = bias_yearly_avg_ref.mean(dim='year')\n", - "im1 = axes[0].imshow(reference.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n", + "im1 = axes[0].imshow(reference.values, origin='lower', cmap='RdBu_r', vmin=vmin, vmax=vmax)\n", "axes[0].text(x=reference.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(reference.mean().item()), fontsize=14, ha='right')\n", "\n", "# plot MAE of model\n", "prediction = bias_yearly_avg.mean(dim='year')\n", - "im2 = axes[1].imshow(prediction.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n", + "im2 = axes[1].imshow(prediction.values, origin='lower', cmap='RdBu_r', vmin=vmin, vmax=vmax)\n", "axes[1].text(x=reference.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(prediction.mean().item()), fontsize=14, ha='right')\n", "\n", "# plot topography\n", @@ -558,7 +683,7 @@ "#axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.1f}'.format(rmse_avg[NAMES[PREDICTAND]].item()) + 'mm$^2$ day$^{-2}$', fontsize=14, ha='right')\n", "\n", "# save figure\n", - "fig.savefig('../Notebooks/Figures/{}_rbias_ERA_vs_model.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" + "fig.savefig('../Notebooks/Figures/{}_rbias_ERA_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')" ] }, { @@ -620,7 +745,7 @@ "axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.1f}'.format(rmse_avg.mean().item()) + 'mm day$^{-1}$', fontsize=14, ha='right')\n", "\n", "# save figure\n", - "fig.savefig('../Notebooks/Figures/{}_average_bias.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" + "fig.savefig('../Notebooks/Figures/{}_bias_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')" ] }, { @@ -736,7 +861,7 @@ "cbar.ax.tick_params(labelsize=14)\n", "\n", "# save figure\n", - "fig.savefig('../Notebooks/Figures/{}_average_bias_seasonal.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" + "fig.savefig('../Notebooks/Figures/{}_bias_seasonal_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')" ] }, { @@ -849,8 +974,8 @@ "outputs": [], "source": [ "# root mean squared error in extreme quantile\n", - "rmse_ex = ((y_pred_ex - y_true_ex) ** 
2).mean()\n", - "rmse_ex_ref = ((y_refe_ex - y_true_ex) ** 2).mean()" + "rmse_ex = np.sqrt(((y_pred_ex - y_true_ex) ** 2).mean())\n", + "rmse_ex_ref = np.sqrt(((y_refe_ex - y_true_ex) ** 2).mean())" ] }, { @@ -884,9 +1009,9 @@ "source": [ "# plot extremes of observation, prediction, and bias\n", "vmin, vmax = 10, 40\n", - "fig, axes = plt.subplots(1, 3, figsize=(24, 6), sharex=True, sharey=True)\n", - "axes = axes.reshape(1, -1)\n", - "for ds, ax in zip([y_true_ex, y_pred_ex, bias_ex], axes[i, ...]):\n", + "fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)\n", + "axes = axes.flatten()\n", + "for ds, ax in zip([y_true_ex, y_pred_ex, bias_ex], axes):\n", " if ds is bias_ex:\n", " ds = ds.mean(dim='year')\n", " im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n", @@ -895,9 +1020,9 @@ " im1 = ax.imshow(ds.mean(dim='year').values, origin='lower', cmap='BuPu', vmin=vmin, vmax=vmax)\n", " \n", "# set titles\n", - "axes[0, 0].set_title('Observed', fontsize=16, pad=10);\n", - "axes[0, 1].set_title('Predicted', fontsize=16, pad=10);\n", - "axes[0, 2].set_title('Bias', fontsize=16, pad=10);\n", + "axes[0].set_title('Observed', fontsize=16, pad=10);\n", + "axes[1].set_title('Predicted', fontsize=16, pad=10);\n", + "axes[2].set_title('Bias', fontsize=16, pad=10);\n", "\n", "# adjust axes\n", "for ax in axes.flat:\n", @@ -934,7 +1059,7 @@ "axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.1f}'.format(rmse_ex.item()) + 'mm$^2$ day$^{-2}$', fontsize=14, ha='right')\n", "\n", "# save figure\n", - "fig.savefig('../Notebooks/Figures/{}_average_bias_p{:.0f}.png'.format(PREDICTAND, quantile * 100), dpi=300, bbox_inches='tight')" + "fig.savefig('../Notebooks/Figures/{}_bias_p{:.0f}_{}.png'.format(PREDICTAND, quantile * 100, LOSS), dpi=300, bbox_inches='tight')" ] }, { @@ -1044,7 +1169,7 @@ "cbar.ax.tick_params(labelsize=14)\n", "\n", "# save figure\n", - "fig.savefig('../Notebooks/Figures/{}_average_bias_seasonal_ex.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" + "fig.savefig('../Notebooks/Figures/{}_bias_p{:.0f}_seasonal_{}.png'.format(PREDICTAND, quantile * 100, LOSS), dpi=300, bbox_inches='tight')" ] }, { @@ -1137,12 +1262,12 @@ "outputs": [], "source": [ "# plot average of observation, prediction, and bias\n", - "fig, axes = plt.subplots(2, 3, figsize=(24, 12), sharex=True, sharey=True)\n", + "fig, axes = plt.subplots(2, 3, figsize=(24, 16), sharex=True, sharey=True)\n", "axes = axes.flatten()\n", "\n", "# plot annual average bias of extreme\n", "ds = bias_wet\n", - "im = axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n", + "im = axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-100, vmax=100)\n", "axes[0].set_title('Annual', fontsize=16);\n", "axes[0].text(x=ds.shape[0] - 2, y=2, s='Average: {:.1f}%'.format(ds.mean().item()), fontsize=14, ha='right')\n", "\n", @@ -1178,7 +1303,7 @@ "cbar_predictand.ax.tick_params(labelsize=14)\n", "\n", "# save figure\n", - "fig.savefig('../Notebooks/Figures/{}_bias_wet_days.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" + "fig.savefig('../Notebooks/Figures/{}_bias_wd_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')" ] }, { @@ -1220,7 +1345,7 @@ "outputs": [], "source": [ "# plot average of observation, prediction, and bias\n", - "fig, axes = plt.subplots(1, 3, figsize=(24, 6), sharex=True, sharey=True)\n", + "fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)\n", "for ds, ax in zip([dii_true, dii_pred, bias_dii], 
axes):\n", " if ds is bias_dii:\n", " im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-40, vmax=40)\n", @@ -1264,7 +1389,7 @@ "cbar_predictand.ax.tick_params(labelsize=14)\n", "\n", "# save figure\n", - "fig.savefig('../Notebooks/Figures/{}_bias_wet_days_p.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" + "fig.savefig('../Notebooks/Figures/{}_bias_wdp_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')" ] }, { @@ -1291,7 +1416,7 @@ "outputs": [], "source": [ "# true and predicted probability of precipitation\n", - "p_true = (y_true > float(WET_DAY_THRESHOLD)).values.flatten()\n", + "p_true = (y_true >= float(WET_DAY_THRESHOLD)).values.flatten()\n", "p_pred = y_pred_prob.values.flatten()" ] }, @@ -1360,13 +1485,13 @@ "ax.legend(frameon=False, loc='lower right', fontsize=14);\n", "\n", "# save figure\n", - "fig.savefig('../Notebooks/Figures/{}_ROC.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" + "fig.savefig('../Notebooks/Figures/{}_ROC_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/Notebooks/eval_sensitivity.ipynb b/Notebooks/eval_sensitivity.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..6ab07d1003938f6c1ea152669d57de2387d28fc2 --- /dev/null +++ b/Notebooks/eval_sensitivity.ipynb @@ -0,0 +1,558 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f2bfabd5-c8d6-46aa-b340-c169cc03bee7", + "metadata": {}, + "source": [ + "# Precipitation: Sensitivity analysis of hyperparameters" + ] + }, + { + "cell_type": "markdown", + "id": "9f024696-3cb3-4637-ab30-c57df733aeed", + "metadata": {}, + "source": [ + "We used **1981-1991 as training** period and **1991-2010 as reference** period. The results shown in this notebook are based on the model predictions on the reference period." 
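
The wet-day ROC evaluation above and the ROC skill score used in the sensitivity analysis below reduce to a handful of scikit-learn calls; a self-contained sketch on synthetic data (all values here are made up):

```python
# Sketch of the wet-day ROC evaluation on synthetic data: observed wet-day
# flags vs. predicted wet-day probabilities.
import numpy as np
from sklearn.metrics import roc_curve, auc

rng = np.random.default_rng(42)
p_true = (rng.random(10000) > 0.6).astype(float)   # 1 = observed wet day
noise = 0.4 * rng.random(10000)
p_pred = np.clip(0.6 * p_true + noise, 0.0, 1.0)   # skilful but noisy forecast

# ROC curve: false positive rate vs. true positive rate over all thresholds
fpr, tpr, _ = roc_curve(p_true, p_pred)
area = auc(fpr, tpr)

# ROC skill score: rescales the area under the curve so 0 = no skill, 1 = perfect
rocss = 2 * area - 1
print('AUC: {:.2f}, ROCSS: {:.2f}'.format(area, rocss))
```
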
+ ] + }, + { + "cell_type": "markdown", + "id": "17ca3513-5757-470c-9d3c-7494b5bade56", + "metadata": {}, + "source": [ + "**Predictors on pressure levels (500, 850)**:\n", + "- Geopotential (z)\n", + "- Temperature (t)\n", + "- Zonal wind (u)\n", + "- Meridional wind (v)\n", + "- Specific humidity (q)\n", + "\n", + "**Predictors on surface**:\n", + "- Surface pressure (p)\n", + "- Totat precipitation (pr)\n", + "\n", + "**Auxiliary predictors**:\n", + "- Elevation from Copernicus EU-DEM v1.1 (dem)\n", + "- Day of the year (doy)" + ] + }, + { + "cell_type": "markdown", + "id": "a7e732f3-a219-4317-aa42-638cae86b4af", + "metadata": {}, + "source": [ + "Define the predictand and the model to evaluate:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d75e3627-13c4-46b8-b749-d582e4b802ac", + "metadata": {}, + "outputs": [], + "source": [ + "# define the model parameters\n", + "PREDICTAND = 'pr' # precipitation, tasmin, tasmax\n", + "MODEL = 'USegNet'\n", + "# PPREDICTORS = 'ztuvq'\n", + "PPREDICTORS = ''\n", + "PLEVELS = ['500', '850']\n", + "# SPREDICTORS = 'p'\n", + "SPREDICTORS = 'pr'\n", + "# DEM = 'dem'\n", + "DEM = ''\n", + "DEM_FEATURES = ''\n", + "DOY = ''\n", + "# DOY = 'doy'\n", + "OPTIMS = ['Adam', 'SGD', 'SGDCLR']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29602f30-6b7f-4be8-aefa-eff01997bd1e", + "metadata": {}, + "outputs": [], + "source": [ + "# parameters for sensivity analysis\n", + "WET_DAY_THRESHOLD = [0, 0.5, 1, 2, 3, 5]\n", + "LAMBDA = [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1]\n", + "LOSS = ['BernoulliGammaLoss', 'L1Loss' ,'MSELoss'] if PREDICTAND == 'pr' else ['L1Loss', 'MSELoss']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5351191a-7cd6-48e6-a6c2-c7cd695db277", + "metadata": {}, + "outputs": [], + "source": [ + "# maximum learning rates for optimizers and loss functions\n", + "LR = {'pr': {'L1Loss': {'SGD': 0.001, 'Adam': 0.001},\n", + " 'MSELoss': {'SGD': 0.0004, 'Adam': 0.0004},\n", + " 'BernoulliGammaLoss': {'SGD': 0.001, 'Adam': 0.0005}},\n", + " 'tasmin': {'L1Loss': {'SGD': 0.004, 'Adam': 0.001},\n", + " 'MSELoss': {'SGD': 0.002, 'Adam': 0.001}},\n", + " 'tasmax': {'L1Loss': {'SGD': 0.001, 'Adam': 0.001},\n", + " 'MSELoss': {'SGD': 0.004, 'Adam': 0.001}}\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bda0b985-f5f2-40b3-8b36-94bba633ea51", + "metadata": {}, + "outputs": [], + "source": [ + "# constant wet day threshold not evaluating wet days\n", + "CWDT = 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a47524c-4d70-4556-88a5-79d25038c9c0", + "metadata": {}, + "outputs": [], + "source": [ + "# parameter to investigate\n", + "PARAMETER = {'regularization': LAMBDA}\n", + "# PARAMETER = {'Wet day threshold': WET_DAY_THRESHOLD}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7a61547-9887-4068-8fbf-9a799922ff8d", + "metadata": {}, + "outputs": [], + "source": [ + "# extreme quantile of interest\n", + "quantile = 0.98" + ] + }, + { + "cell_type": "markdown", + "id": "cda0cf8b-adce-4513-af69-31df6eb5542f", + "metadata": {}, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01ed9fbc-49c2-4028-a3f3-692a014dcf89", + "metadata": {}, + "outputs": [], + "source": [ + "# builtins\n", + "import datetime\n", + "import warnings\n", + "import calendar\n", + "from itertools import product\n", + "\n", + "# externals\n", + "import xarray as xr\n", + "import numpy as np\n", 
+ "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.ticker as mticker\n", + "import seaborn as sns\n", + "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n", + "from sklearn.metrics import r2_score, roc_curve, auc, classification_report\n", + "\n", + "# locals\n", + "from climax.core.dataset import ERA5Dataset\n", + "from climax.main.config import VALID_PERIOD\n", + "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b6f5703-d247-40f5-a859-2d362592b1d8", + "metadata": {}, + "outputs": [], + "source": [ + "# path of predictions for hyperparameter sensitivity analysis\n", + "TARGET_PATH = TARGET_PATH.joinpath('sensitivity')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "073dfc6a-8b33-4aca-ab71-c06ef92bbb80", + "metadata": {}, + "outputs": [], + "source": [ + "# mapping from predictands to variable names\n", + "NAMES = {'tasmin': 'minimum temperature', 'tasmax': 'maximum temperature', 'pr': 'precipitation'}" + ] + }, + { + "cell_type": "markdown", + "id": "9410ed09-50a9-4d19-9e90-8551ab701509", + "metadata": {}, + "source": [ + "### Load datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01509828-e94c-4ad9-b2a0-67d8f77b4f84", + "metadata": {}, + "outputs": [], + "source": [ + "# load observations\n", + "y_true = xr.open_dataset(OBS_PATH.joinpath(PREDICTAND, 'OBS_{}_1980_2018.nc'.format(PREDICTAND)), chunks={'time': 365})\n", + "y_true = y_true.sel(time=VALID_PERIOD)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ca10fb8-bc89-4dea-9346-339e14db9c98", + "metadata": {}, + "outputs": [], + "source": [ + "# mask of missing values\n", + "missing = np.isnan(y_true[NAMES[PREDICTAND]]) if PREDICTAND == 'pr' else np.isnan(y_true[PREDICTAND])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "030c221d-e0ad-4e67-b742-db34de549983", + "metadata": {}, + "outputs": [], + "source": [ + "# construct file pattern to match\n", + "PATTERN = '_'.join([MODEL, PREDICTAND])\n", + "PATTERN = '_'.join([PATTERN, PPREDICTORS]) if PPREDICTORS else PATTERN\n", + "PATTERN = '_'.join([PATTERN, *PLEVELS]) if PPREDICTORS else PATTERN\n", + "PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7106fed2-9f63-453b-a9af-b5d490efbbb4", + "metadata": {}, + "outputs": [], + "source": [ + "# file pattern for different parametersa\n", + "if list(*PARAMETER.values()) == WET_DAY_THRESHOLD:\n", + " PATTERNS = ['_'.join([PATTERN, '{}mm_{}'.format(str(t).replace('.', ''), LOSS), OPTIM]) for t in WET_DAY_THRESHOLD]\n", + " \n", + "if list(*PARAMETER.values()) == LAMBDA:\n", + " p = '_'.join([PATTERN, 'loss'])\n", + " p = '_'.join([p, 'optim'])\n", + " PATTERNS = ['_'.join([p, 'd{:.0e}'.format(decay)]) for decay in LAMBDA]\n", + "\n", + "# iterate over loss functions and parameter space\n", + "STATE_FILES = {l: {o: {p: '_'.join([v.replace('loss', '{}mm_{}'.format(CWDT, l) if l in ['BernoulliGammaLoss', 'BernoulliWeibullLoss'] else l).replace(\n", + " 'optim', o if o != 'SGDCLR' else 'SGD'), 'lr{:.0e}'.format(LR[PREDICTAND][l][o if o != 'SGDCLR' else 'SGD'] / 4), 'CyclicLR' if o == 'SGDCLR' else ''])\n", + " 
for p, v in zip(*PARAMETER.values(), PATTERNS)} for o in OPTIMS} for l in LOSS}\n", + "STATE_FILES" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58177982-7f86-4c9e-9980-563cf5d48852", + "metadata": {}, + "outputs": [], + "source": [ + "# align datasets and mask missing values\n", + "y_pred = {}\n", + "y_prob = {}\n", + "for l in LOSS:\n", + " y_pred[l], y_prob[l] = {}, {}\n", + " for o in OPTIMS:\n", + " y_pred[l][o], y_prob[l][o] = {}, {}\n", + " for k, v in STATE_FILES[l][o].items():\n", + " ds = xr.open_dataset(TARGET_PATH.joinpath(PREDICTAND, '.'.join([v.rstrip('_'), 'nc'])), chunks={'time': 365})\n", + " if PREDICTAND == 'pr':\n", + " # align datasets and mask missing values\n", + " if l == 'BernoulliGammaLoss':\n", + " _, y, p = xr.align(y_true[NAMES[PREDICTAND]], ds[NAMES[PREDICTAND]], ds.prob, join='override')\n", + " y = y.where(~missing, other=np.nan)\n", + " p = p.where(~missing, other=np.nan)\n", + " else:\n", + " _, y = xr.align(y_true, ds[PREDICTAND], join='override')\n", + " y = y.where(~missing, other=np.nan)\n", + " p = None\n", + " try: \n", + " y = y.rename({PREDICTAND: NAMES[PREDICTAND]})\n", + " except ValueError:\n", + " pass \n", + " y_pred[l][o][k] = y\n", + " y_prob[l][o][k] = p \n", + " else: \n", + " # align datasets and mask missing values\n", + " _, y = xr.align(y_true[PREDICTAND], ds[PREDICTAND], join='override')\n", + " y = y.where(~missing, other=np.nan)\n", + " y_pred[l][o][k] = y\n", + "\n", + "# rename predictand\n", + "if PREDICTAND == 'pr':\n", + " PREDICTAND = NAMES[PREDICTAND]" + ] + }, + { + "cell_type": "markdown", + "id": "5a72117c-2e1c-462a-968e-91db18b4aaaf", + "metadata": {}, + "source": [ + "## Model sensitivity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0164fd81-aff4-4fa0-b156-5ba0d817ed49", + "metadata": {}, + "outputs": [], + "source": [ + "# whether to overwrite calculated metrics\n", + "OVERWRITE = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "419006dd-9cfa-4a49-a5d2-9c250054aa50", + "metadata": {}, + "outputs": [], + "source": [ + "# initialize DataFrame filename\n", + "filename = '_'.join([PREDICTAND, PPREDICTORS]) if PPREDICTORS else PREDICTAND\n", + "filename = '_'.join([filename, SPREDICTORS]) if SPREDICTORS else filename\n", + "filename = '_'.join([filename, DEM]) if DEM else filename\n", + "filename = '_'.join([filename, DOY]) if DOY else filename\n", + "filename = '_'.join([filename, 'sensitivity_{}.csv'.format(*PARAMETER.keys())])\n", + "df_name = TARGET_PATH.joinpath(PREDICTAND if PREDICTAND != 'precipitation' else 'pr', filename)\n", + "df_name" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9758d58b-82e9-47b2-9cf7-1c93d5ed04a7", + "metadata": {}, + "outputs": [], + "source": [ + "# yearly average of observations\n", + "y_true_yearly_avg = y_true[PREDICTAND].groupby('time.year').mean(dim='time')\n", + "\n", + "# daily anomalies\n", + "y_true_anom = ERA5Dataset.anomalies(y_true[PREDICTAND], timescale='time.month').values\n", + "mask = np.isnan(y_true_anom)\n", + "y_true_values = y_true_anom[~mask]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57ca80c6-2c02-44d1-87eb-35c8a78ec71f", + "metadata": {}, + "outputs": [], + "source": [ + "if df_name.exists() and not OVERWRITE:\n", + " df = pd.read_csv(df_name, sep=',')\n", + "else:\n", + " # initialize model hyperparameter DataFrame\n", + " df = pd.DataFrame(columns=['R2', 'Bias', 'MAE', 'RMSE', 'AUC', 'ROCSS', 'Loss', 'Optim', 
*PARAMETER.keys()])\n", + " df\n", + " \n", + " # define string to log\n", + " if PREDICTAND == 'precipitation':\n", + " logstring = '({}), ({}), ({}={:.0e}): r2: {:.2f}, bias: {:.2f}%, MAE: {:.2f}mm, RMSE: {:.2f}mm, AUC: {:.2f}, ROCSS: {:.2f}'\n", + " else:\n", + " logstring = '({}), ({}), ({}={:.0e}): r2: {:.2f}, bias: {:.2f}°C, MAE: {:.2f}°C, RMSE: {:.2f}°C, AUC: {:.2f}, ROCSS: {:.2f}'\n", + "\n", + " # calculate metrics for each hyperparameter combination\n", + " for loss, prediction in y_pred.items():\n", + " for optim, ds in prediction.items():\n", + " for param, value in ds.items(): \n", + " with warnings.catch_warnings():\n", + " warnings.simplefilter('ignore', category=RuntimeWarning)\n", + "\n", + " # coefficient of determination\n", + " try:\n", + " # compute predicted daily anomalies\n", + " y_pred_anom = ERA5Dataset.anomalies(value, timescale='time.month').values\n", + "\n", + " # apply mask of valid pixels\n", + " y_pred_values = y_pred_anom[~mask]\n", + "\n", + " # coefficient of determination\n", + " r2 = r2_score(y_true_values, y_pred_values)\n", + " except ValueError:\n", + " r2 = np.nan\n", + "\n", + " # calculate metrics on yearly average of daily values\n", + " y_pred_yearly_avg = value.groupby('time.year').mean(dim='time')\n", + "\n", + " # relative or absolute bias\n", + " if PREDICTAND == 'precipitation':\n", + " # calculate relative bias for precipitation\n", + " bias = ((y_pred_yearly_avg - y_true_yearly_avg) / y_true_yearly_avg).mean().values.item() * 100\n", + " else:\n", + " # calculate absolute bias for temperatures\n", + " bias = (y_pred_yearly_avg - y_true_yearly_avg).mean().values.item()\n", + "\n", + " # mean absolute error and root mean squared error\n", + " mae = np.abs(y_pred_yearly_avg - y_true_yearly_avg).mean().values.item()\n", + " rmse = np.sqrt(((y_pred_yearly_avg - y_true_yearly_avg) ** 2).mean()).values.item()\n", + "\n", + " # receiver operating characteristics\n", + " if loss in ['BernoulliGammaLoss', 'BernoulliWeibullLoss']:\n", + " p_true = (y_true[PREDICTAND] >= param if list(*PARAMETER.values()) == WET_DAY_THRESHOLD\n", + " else y_true[PREDICTAND] >= CWDT).values\n", + " p_pred = y_prob[loss][optim][param].values\n", + "\n", + " # apply mask of valid pixels\n", + " p_pred = p_pred[~mask]\n", + " p_t = p_true[~mask].astype(float)\n", + "\n", + " # calculate ROC: false positive rate vs. true positive rate\n", + " fpr, tpr, _ = roc_curve(p_t, p_pred)\n", + " area = auc(fpr, tpr) # area under ROC curve\n", + " rocss = 2 * area - 1 # ROC skill score (cf. 
https://journals.ametsoc.org/view/journals/clim/16/24/1520-0442_2003_016_4145_otrsop_2.0.co_2.xml)\n", + " else:\n", + " area, rocss = np.nan, np.nan\n", + "\n", + " # log metrics to console\n", + " print(logstring.format(loss, optim, *PARAMETER.keys(), param, r2, bias, mae, rmse, area, rocss))\n", + "\n", + " # add simulation to dataframe\n", + " df = df.append(pd.DataFrame([[r2, bias, mae, rmse, area, rocss, loss, optim, param]], columns=df.columns))\n", + "\n", + " # save DataFrame of calculated metrics to csv\n", + " df.to_csv(df_name, header=True, index=False, sep=',')\n", + " df" + ] + }, + { + "cell_type": "markdown", + "id": "975d77a7-e4b1-46ac-877d-f1a4464e4050", + "metadata": {}, + "source": [ + "## Visualize results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "026167b6-1737-495c-b33c-dd70779b351c", + "metadata": {}, + "outputs": [], + "source": [ + "# color palette\n", + "PALETTE = 'mako_r'\n", + "sns.color_palette(PALETTE, n_colors=len(LOSS))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c66b1cb-318c-48b5-a27c-3e6394455e69", + "metadata": {}, + "outputs": [], + "source": [ + "# metrics to plot\n", + "METRICS = ['Bias', 'MAE', 'ROCSS' if PREDICTAND == 'precipitation' else 'RMSE', 'R2']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6a8d8b1-6c66-488a-b961-86e60593e7c7", + "metadata": {}, + "outputs": [], + "source": [ + "# instantiate figure\n", + "fig, axes = plt.subplots(4, len(OPTIMS), sharex=True, figsize=(16, 10))\n", + "for axis, m in zip(axes, METRICS):\n", + " for ax, optim in zip(axis, OPTIMS):\n", + " # subset to current optimizer\n", + " data = df[df['Optim'] == optim]\n", + " sns.barplot(x=list(PARAMETER.keys()).pop(), y=m, hue='Loss', data=data, ax=ax, palette=PALETTE)\n", + " \n", + " # axis properties\n", + " ax.set_ylabel('')\n", + " ax.set_xlabel('')\n", + " ax.get_legend().remove()\n", + " ax.tick_params(axis='both', which='major', labelsize=14)\n", + " if m == 'Bias': \n", + " ax.set_ylim((-20, 20) if PREDICTAND == 'precipitation' else (-1, 1))\n", + " else:\n", + " ax.set_ylim(0, 1)\n", + " \n", + " # adjust y-ticklabels\n", + " for ax in axis[1:]:\n", + " ax.set_yticklabels('')\n", + "\n", + " # adjust y-label\n", + " if m == 'Bias':\n", + " axis[0].set_ylabel('Relative Bias (%)' if PREDICTAND == 'precipitation' else 'Bias (°C)', fontsize=16)\n", + " elif m in ['MAE', 'RMSE']:\n", + " axis[0].set_ylabel('{} ({})'.format(m, 'mm' if PREDICTAND == 'precipitation' else '°C'), fontsize=16)\n", + " else:\n", + " axis[0].set_ylabel(m, fontsize=16)\n", + "\n", + "# adjust x-tick labels\n", + "for ax in axes[-1, :]:\n", + " ax.set_xticklabels(['$10^{{ {:.0f} }}$'.format(np.log10(v)) if v > 0 else '{:.0f}'.format(v) for v in list(*PARAMETER.values())])\n", + " \n", + "# add optimizer to axes\n", + "for ax, optim in zip(axes[0, :], OPTIMS):\n", + " ax.set_title('SGD + CLR' if optim == 'SGDCLR' else optim, fontsize=16)\n", + "\n", + "# add legend\n", + "axes[-1, 0].legend(loc='upper left', bbox_to_anchor=(-0.05, -0.15), frameon=False, ncol=len(LOSS), fontsize=16);\n", + "\n", + "# adjust subplots\n", + "fig.subplots_adjust(wspace=0.05, hspace=0.15)\n", + "fig.savefig('./Figures/{}'.format(df_name.name.replace('csv', 'png')), bbox_inches='tight', dpi=300)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18b21e6e-711c-4afd-91de-9d7e4403c6ce", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 
(ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Notebooks/eval_temperature.ipynb b/Notebooks/eval_temperature.ipynb index f315da9240883ed208b9144b88e58c8865caca2e..6422cb5a47f3d8a82346cec838131bb6793e681b 100644 --- a/Notebooks/eval_temperature.ipynb +++ b/Notebooks/eval_temperature.ipynb @@ -29,7 +29,7 @@ "- Specific humidity (q)\n", "\n", "**Predictors on surface**:\n", - "- Mean sea level pressure (msl)\n", + "- Surface pressure (p)\n", "\n", "**Auxiliary predictors**:\n", "- Elevation from Copernicus EU-DEM v1.1 (dem)\n", @@ -47,19 +47,50 @@ { "cell_type": "code", "execution_count": null, - "id": "bb24b6ed-2d0a-44e0-b9a9-abdcb2a8294d", + "id": "bd50716f-eeba-431a-97ce-24ea5d625090", "metadata": {}, "outputs": [], "source": [ "# define the model parameters\n", - "PREDICTAND = 'tasmax'\n", + "PREDICTAND = 'tasmin'\n", "MODEL = 'USegNet'\n", "PPREDICTORS = 'ztuvq'\n", + "# PPREDICTORS = ''\n", "PLEVELS = ['500', '850']\n", + "# PLEVELS = []\n", "SPREDICTORS = 'p'\n", + "# SPREDICTORS = 'ppr'\n", "DEM = 'dem'\n", "DEM_FEATURES = ''\n", - "DOY = 'doy'" + "DOY = 'doy'\n", + "LOSS = 'L1Loss'\n", + "# LOSS = 'MSELoss'\n", + "OPTIM = 'SGD'\n", + "# OPTIM = 'Adam'\n", + "SEASON = ''\n", + "DECAY = '1e-06'\n", + "LR = '1e-03'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d43149c7-ed56-4c91-9ce9-e8b226b48426", + "metadata": {}, + "outputs": [], + "source": [ + "# construct file pattern to match\n", + "PATTERN = '_'.join([MODEL, PREDICTAND, PPREDICTORS, *PLEVELS])\n", + "PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN\n", + "PATTERN = '_'.join([PATTERN, LOSS])\n", + "PATTERN = '_'.join([PATTERN, OPTIM])\n", + "PATTERN = '_'.join([PATTERN, SEASON]) if SEASON else PATTERN\n", + "PATTERN = '_'.join([PATTERN, 'd{}'.format(DECAY)])\n", + "PATTERN = '_'.join([PATTERN, 'lr{}'.format(LR)])\n", + "PATTERN" ] }, { @@ -73,6 +104,17 @@ "NAMES = {'tasmin': 'minimum temperature', 'tasmax': 'maximum temperature', 'pr': 'precipitation'}" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1ae6427-67e3-4df3-b84a-5aad32504d79", + "metadata": {}, + "outputs": [], + "source": [ + "# whether to search for the defined pattern\n", + "SEARCH = False " + ] + }, { "cell_type": "markdown", "id": "4d5d12c5-50fd-4c5c-9240-c3df78e49b44", @@ -97,13 +139,17 @@ "import xarray as xr\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "from matplotlib import cm\n", + "from matplotlib.colors import ListedColormap\n", "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n", "import scipy.stats as stats\n", "from IPython.display import Image\n", "from sklearn.metrics import r2_score\n", "\n", "# locals\n", - "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH\n", + "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH, MODEL_PATH\n", + "from climax.core.utils import plot_loss\n", + "from climax.core.dataset import ERA5Dataset\n", "from pysegcnn.core.utils import search_files" ] }, @@ -160,16 
+206,13 @@ { "cell_type": "code", "execution_count": null, - "id": "e8cb1011-32dd-4e32-9692-4a4aa009869f", + "id": "ad84738f-3600-47e4-8178-4ee4ad5e3fc3", "metadata": {}, "outputs": [], "source": [ - "# construct file pattern to match\n", - "PATTERN = '_'.join([MODEL, PREDICTAND, PPREDICTORS, *PLEVELS])\n", - "PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN\n", - "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n", - "PATTERN = '_'.join([PATTERN, DEM_FEATURES]) if DEM_FEATURES else PATTERN\n", - "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN" + "# digital elevation model\n", + "dem = xr.open_dataset(search_files(DEM_PATH, 'eu_dem_v11_stt.nc').pop())\n", + "dem = dem.Band1.to_dataset().rename({'Band1': 'elevation'})" ] }, { @@ -179,8 +222,23 @@ "metadata": {}, "outputs": [], "source": [ - "# model predictions and observations NetCDF\n", - "y_pred = xr.open_dataset(search_files(TARGET_PATH.joinpath(PREDICTAND), '.'.join([PATTERN, 'nc$'])).pop())\n", + "# model predictions\n", + "if SEARCH:\n", + " y_pred = xr.open_dataset(search_files(TARGET_PATH.joinpath(PREDICTAND), '.'.join([PATTERN, 'nc$'])).pop())\n", + "else:\n", + " filename = 'USegNet_tasmin_ztuvq_500_850_p_dem_doy_L1Loss_Adam_d1e-03_lr3e-04.nc'\n", + " file = search_files(TARGET_PATH, filename).pop()\n", + " y_pred = xr.open_dataset(file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bb55ddf-1d19-4f1a-8b31-9da778b287fd", + "metadata": {}, + "outputs": [], + "source": [ + "# observations\n", "if PREDICTAND == 'tas':\n", " # read both tasmax and tasmin\n", " tasmax = xr.open_dataset(search_files(OBS_PATH.joinpath('tasmax'), '.nc$').pop())\n", @@ -251,76 +309,130 @@ "outputs": [], "source": [ "# align datasets and mask missing values in model predictions\n", - "y_true, y_pred, y_refe = xr.align(y_true, y_pred, y_refe, join='override')\n", + "y_true, y_refe, y_pred = xr.align(y_true[PREDICTAND], y_refe[PREDICTAND], y_pred[PREDICTAND], join='override')\n", "y_pred = y_pred.where(~np.isnan(y_true), other=np.nan)\n", "y_refe = y_refe.where(~np.isnan(y_true), other=np.nan)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "dee1b549-c5a6-44cd-9fe2-563ccfea658d", + "metadata": {}, + "outputs": [], + "source": [ + "# align digital elevation model\n", + "_, dem = xr.align(y_true.isel(time=0), dem, join='override')\n", + "dem = dem.where(~np.isnan(y_true.isel(time=0)), other=np.nan)" + ] + }, { "cell_type": "markdown", - "id": "ddebdf9f-862c-461e-aa57-cd344d54eee9", + "id": "9a696a5c-1de8-4027-b3f8-64045b7333fb", "metadata": {}, "source": [ - "## Model validation: temperature" + "### Model convergence" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "98349528-6a4d-469f-86f4-60ce4d18e719", + "metadata": {}, + "outputs": [], + "source": [ + "# plot model state\n", + "model_state = MODEL_PATH.joinpath(PREDICTAND, '.'.join([PATTERN, 'pt']))\n", + "try:\n", + " fig = plot_loss(model_state, step=5)\n", + "except FileNotFoundError:\n", + " pass" ] }, { "cell_type": "markdown", - "id": "ab15d557-c7ea-40c0-9977-a3d410fea784", + "id": "ddebdf9f-862c-461e-aa57-cd344d54eee9", "metadata": {}, "source": [ - "### Coefficient of determination" + "## Model validation: temperature" ] }, { "cell_type": "code", "execution_count": null, - "id": "bb8223e0-c5cb-477c-8c31-3dfb94b20ce1", + "id": "dda74f76-2b3e-484c-b56d-d1d3d23784c2", "metadata": {}, "outputs": [], "source": [ - "# group timeseries by month and calculate mean over time and space\n", - "y_pred_ac = 
y_pred.groupby('time.month').mean(dim=('time', 'y', 'x'))\n", - "y_true_ac = y_true.groupby('time.month').mean(dim=('time', 'y', 'x'))" + "# calculate monthly means\n", + "y_pred_mm = y_pred.groupby('time.month').mean(dim=('time'))\n", + "y_true_mm = y_true.groupby('time.month').mean(dim=('time'))" ] }, { "cell_type": "code", "execution_count": null, - "id": "619d5dc9-4d36-43a3-b23c-a4ea51229c78", + "id": "0993676b-19de-4426-9c8c-c3897b4f9bf7", "metadata": {}, "outputs": [], "source": [ - "# get predicted and observed values over entire time series and grid points\n", - "y_pred_values = y_pred[PREDICTAND].groupby('time.month').mean(dim='time').values.flatten()\n", - "y_true_values = y_true[PREDICTAND].groupby('time.month').mean(dim='time').values.flatten()" + "# calculate mean annual cycle\n", + "y_pred_ac = y_pred_mm.mean(dim=('x', 'y'))\n", + "y_true_ac = y_true_mm.mean(dim=('x', 'y'))" ] }, { "cell_type": "code", "execution_count": null, - "id": "49dff6ce-a629-460b-a43b-d1a0ef447351", - "metadata": { - "tags": [] - }, + "id": "c3e5f725-9d6a-4837-88a6-7ad1baa06027", + "metadata": {}, "outputs": [], "source": [ - "# apply mask of valid pixels\n", - "mask = (~np.isnan(y_pred_values) & ~np.isnan(y_true_values))\n", - "y_pred_values = y_pred_values[mask]\n", - "y_true_values = y_true_values[mask]" + "# compute daily anomalies\n", + "y_pred_anom = ERA5Dataset.anomalies(y_pred, timescale='time.month')\n", + "y_true_anom = ERA5Dataset.anomalies(y_true, timescale='time.month')" + ] + }, + { + "cell_type": "markdown", + "id": "ab15d557-c7ea-40c0-9977-a3d410fea784", + "metadata": {}, + "source": [ + "### Coefficient of determination" ] }, { "cell_type": "code", "execution_count": null, - "id": "e9e46770-4176-4257-8ea2-7050d3325e98", + "id": "619d5dc9-4d36-43a3-b23c-a4ea51229c78", "metadata": {}, "outputs": [], "source": [ - "# calculate coefficient of determination\n", - "r2 = r2_score(y_true_values, y_pred_values)\n", - "r2" + "# get predicted and observed daily anomalies\n", + "y_pred_av = y_pred_anom.values.flatten()\n", + "y_true_av = y_true_anom.values.flatten()\n", + "\n", + "# apply mask of valid pixels\n", + "mask = (~np.isnan(y_pred_av) & ~np.isnan(y_true_av))\n", + "y_pred_av = y_pred_av[mask]\n", + "y_true_av = y_true_av[mask]\n", + "\n", + "# get predicted and observed monthly means\n", + "y_pred_mv = y_pred_mm.values.flatten()\n", + "y_true_mv = y_true_mm.values.flatten()\n", + "\n", + "# apply mask of valid pixels\n", + "mask = (~np.isnan(y_pred_mv) & ~np.isnan(y_true_mv))\n", + "y_pred_mv = y_pred_mv[mask]\n", + "y_true_mv = y_true_mv[mask]\n", + "\n", + "# calculate coefficient of determination on monthly means\n", + "r2_mm = r2_score(y_true_mv, y_pred_mv)\n", + "print('R2 on monthly means: {:.2f}'.format(r2_mm))\n", + "\n", + "# calculate coefficient of determination on daily anomalies\n", + "r2_anom = r2_score(y_true_av, y_pred_av)\n", + "print('R2 on daily anomalies: {:.2f}'.format(r2_anom))" ] }, { @@ -338,15 +450,18 @@ "# ax.plot(y_true_values[subset], y_pred_values[subset], 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n", "\n", "# plot entire dataset\n", - "ax.plot(y_true_values, y_pred_values, 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n", + "ax.plot(y_true_mv, y_pred_mv, 'o', alpha=.5, markeredgecolor='grey', markerfacecolor='none', markersize=3);\n", "\n", "# plot 1:1 mapping line\n", - "interval = np.arange(-15, 45, 5)\n", - "#interval = np.arange(-30, 35, 5)\n", + "if PREDICTAND == 'tasmin':\n", + " 
interval = np.arange(-25, 30, 5)\n", + "else:\n", + " interval = np.arange(-15, 45, 5)\n", "ax.plot(interval, interval, color='k', lw=2, ls='--')\n", "\n", - "# add coefficient of determination: calculated on entire dataset!\n", - "ax.text(interval[-1] - 0.5, interval[0] + 0.5, s='Coefficient of determination R$^2$ = {:.2f}'.format(r2), ha='right', fontsize=18)\n", + "# add coefficients of determination\n", + "ax.text(interval[-1] - 0.5, interval[0] + 0.5, s='R$^2$ (monthly means) = {:.2f}'.format(r2_mm), ha='right', fontsize=18)\n", + "ax.text(interval[-1] - 0.5, interval[0] + 2.5, s='R$^2$ (daily anomalies) = {:.2f}'.format(r2_anom), ha='right', fontsize=18)\n", "\n", "# format axes\n", "ax.set_ylim(interval[0], interval[-1])\n", @@ -361,21 +476,102 @@ "\n", "# add axis for annual cycle\n", "axins = inset_axes(ax, width=\"30%\", height=\"40%\", loc=2, borderpad=1)\n", - "axins.plot(y_pred_ac[PREDICTAND].values, ls='--', color='k', label='Predicted')\n", - "axins.plot(y_true_ac[PREDICTAND].values, ls='-', color='k', label='Observed')\n", + "axins.plot(y_pred_ac.values, ls='--', color='k', label='Predicted')\n", + "axins.plot(y_true_ac.values, ls='-', color='k', label='Observed')\n", "axins.legend(frameon=False, fontsize=12, loc='lower center');\n", "axins.yaxis.tick_right()\n", - "#axins.set_yticks(np.arange(-10, 11, 2))\n", - "#axins.set_yticklabels(np.arange(-10, 11, 2), fontsize=12)\n", - "axins.set_yticks(np.arange(-0, 22, 2))\n", - "axins.set_yticklabels(np.arange(0, 22, 2), fontsize=12)\n", + "axins.set_yticks(np.arange(-10, 11, 2) if PREDICTAND == 'tasmin' else np.arange(0, 20, 2))\n", + "axins.set_yticklabels(np.arange(-10, 11, 2) if PREDICTAND == 'tasmin' else np.arange(0, 20, 2), fontsize=12)\n", + "axins.set_xticks(np.arange(0, 12))\n", "axins.set_xticklabels([calendar.month_name[i + 1] for i in np.arange(0, 12)], rotation=90, fontsize=12)\n", - "# axins.set_ylabel('{} / (°C)'.format(NAMES[PREDICTAND].capitalize()), fontsize=14)\n", - "# axins.set_xlabel('Month', fontsize=14);\n", + "axins.set_title('Mean annual cycle', fontsize=14, pad=5)\n", + "\n", + "# save figure\n", + "fig.savefig('../Notebooks/Figures/{}_r2_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "id": "4e50ff44-5598-4694-9ea3-b06ee50ee21f", + "metadata": {}, + "source": [ + "### Coefficient of determination: Spatially" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "816fad18-b365-4c62-b669-e13e7dc8d322", + "metadata": {}, + "outputs": [], + "source": [ + "# iterate over the grid points\n", + "# r2 has shape (2, y, x), i.e. rows are y and columns are x, to match imshow\n", + "r2 = np.ones((2, len(y_pred.y), len(y_pred.x)), dtype=np.float32) * np.nan\n", + "for i, _ in enumerate(y_pred.x):\n", + " for j, _ in enumerate(y_pred.y):\n", + " # get observed and predicted monthly mean temperature for current grid point\n", + " point_true = y_true_mm.isel(x=i, y=j)\n", + " point_pred = y_pred_mm.isel(x=i, y=j)\n", + "\n", + " # remove missing values\n", + " mask = ((~np.isnan(point_true)) & (~np.isnan(point_pred)))\n", + " point_true = point_true[mask].values\n", + " point_pred = point_pred[mask].values\n", + " if point_true.size < 1:\n", + " continue\n", + " \n", + " # get anomalies for current grid point\n", + " point_anom_true = y_true_anom.isel(x=i, y=j)\n", + " point_anom_pred = y_pred_anom.isel(x=i, y=j)\n", + " \n", + " # remove missing values\n", + " mask_anom = ((~np.isnan(point_anom_true)) & (~np.isnan(point_anom_pred)))\n", + " point_anom_true = point_anom_true[mask_anom].values\n", + " point_anom_pred = 
point_anom_pred[mask_anom].values\n", + "\n", + " # compute coefficient of determination\n", + " r2[0, j, i] = r2_score(point_true, point_pred)\n", + " r2[1, j, i] = r2_score(point_anom_true, point_anom_pred)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0f0e83c-3dfd-46dd-bd02-7ff1ba8f8573", + "metadata": {}, + "outputs": [], + "source": [ + "# define color map: red to green\n", + "grn = cm.get_cmap('Greens', 128)\n", + "red = cm.get_cmap('Reds_r', 128)\n", + "red2green = ListedColormap(np.vstack((red(np.linspace(0, 1, 128)),\n", + " grn(np.linspace(0, 1, 128)))))\n", + "\n", + "# plot coefficient of determination\n", + "vmin, vmax = -1, 1\n", + "fig, ax = plt.subplots(1, 2, figsize=(10, 10))\n", + "\n", + "# monthly means\n", + "im0 = ax[0].imshow(r2[0, :], origin='lower', cmap=red2green, vmin=vmin, vmax=vmax)\n", + "ax[0].text(x=r2.shape[2] - 2, y=2, s='Average R$^2$: {:.2f}'.format(np.nanmean(r2[0, :])), fontsize=14, ha='right');\n", + "ax[0].set_axis_off()\n", + "ax[0].set_title('Monthly mean {} (°C)'.format(NAMES[PREDICTAND]), fontsize=14, pad=10);\n", + "\n", + "# daily anomalies\n", + "im1 = ax[1].imshow(r2[1, :], origin='lower', cmap=red2green, vmin=vmin, vmax=vmax)\n", + "ax[1].text(x=r2.shape[2] - 2, y=2, s='Average R$^2$: {:.2f}'.format(np.nanmean(r2[1, :])), fontsize=14, ha='right');\n", + "ax[1].set_axis_off()\n", + "ax[1].set_title('Daily {} anomaly (°C)'.format(NAMES[PREDICTAND]), fontsize=14, pad=10);\n", + "\n", + "# add colorbar \n", + "cbar_ax_bias = fig.add_axes([ax[1].get_position().x1 + 0.05, ax[1].get_position().y0,\n", + " 0.03, ax[1].get_position().y1 - ax[1].get_position().y0])\n", + "cbar_bias = fig.colorbar(im0, cax=cbar_ax_bias)\n", + "cbar_bias.set_label(label='Coefficient of determination R$^2$', fontsize=14)\n", + "cbar_bias.ax.tick_params(labelsize=14, pad=10)\n", "\n", "# save figure\n", - "fig.savefig('../Notebooks/Figures/{}_r2.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" + "fig.savefig('../Notebooks/Figures/{}_r2_spatial_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')" + ] + }, { @@ -407,9 +603,8 @@ "y_true_yearly_avg = y_true.groupby('time.year').mean(dim='time')\n", "bias_yearly_avg = y_pred_yearly_avg - y_true_yearly_avg\n", "bias_yearly_avg_ref = y_refe_yearly_avg - y_true_yearly_avg\n", - "for var in bias_yearly_avg:\n", - " print('(Model) Yearly average bias of {}: {:.2f}°C'.format(var, bias_yearly_avg[var].mean().item()))\n", - " print('(ERA-5) Yearly average bias of {}: {:.2f}°C'.format(var, bias_yearly_avg_ref.mean().to_array().values.item()))" + "print('(Model) Yearly average bias of {}: {:.2f}°C'.format(PREDICTAND, bias_yearly_avg.mean().item()))\n", + "print('(ERA-5) Yearly average bias of {}: {:.2f}°C'.format(PREDICTAND, bias_yearly_avg_ref.mean().item()))" ] }, { @@ -422,9 +617,8 @@ "outputs": [], "source": [ "# mean absolute error over reference period\n", "mae_avg = np.abs(y_pred_yearly_avg - y_true_yearly_avg).mean()\n", "mae_avg_ref = np.abs(y_refe_yearly_avg - y_true_yearly_avg).mean()\n", - "for var in mae_avg:\n", - " print('(Model) Yearly average MAE of {}: {:.2f}°C'.format(var, mae_avg[var].item()))\n", - " print('(ERA-5) Yearly average MAE of {}: {:.2f}°C'.format(var, mae_avg_ref.mean().to_array().values.item()))" + "print('(Model) Yearly average MAE of {}: {:.2f}°C'.format(PREDICTAND, mae_avg.mean().item()))\n", + "print('(ERA-5) Yearly average MAE of {}: {:.2f}°C'.format(PREDICTAND, mae_avg_ref.mean().item()))" ] }, { @@ -435,11 +629,10 @@ "outputs": [], "source": [ "# root mean squared error over 
reference period\n", - "rmse_avg = ((y_pred_yearly_avg - y_true_yearly_avg) ** 2).mean()\n", - "rmse_avg_ref = ((y_refe_yearly_avg - y_true_yearly_avg) **2).mean()\n", - "for var in rmse_avg:\n", - " print('(Model) Yearly average RMSE of {}: {:.2f}°C'.format(var, rmse_avg[var].item()))\n", - " print('(ERA-5) Yearly average RMSE of {}: {:.2f}°C'.format(var, rmse_avg_ref.mean().to_array().values.item()))" + "rmse_avg = np.sqrt(((y_pred_yearly_avg - y_true_yearly_avg) ** 2)).mean()\n", + "rmse_avg_ref = np.sqrt(((y_refe_yearly_avg - y_true_yearly_avg) **2)).mean()\n", + "print('(Model) Yearly average RMSE of {}: {:.2f}°C'.format(PREDICTAND, rmse_avg.mean().item()))\n", + "print('(ERA-5) Yearly average RMSE of {}: {:.2f}°C'.format(PREDICTAND, rmse_avg_ref.mean().item()))" ] }, { @@ -450,14 +643,82 @@ "outputs": [], "source": [ "# Pearson's correlation coefficient over reference period\n", - "for var in y_pred_yearly_avg:\n", - " correlations = []\n", - " for year in y_pred_yearly_avg.year:\n", - " y_p = y_pred_yearly_avg[var].sel(year=year).values \n", - " y_t = y_true_yearly_avg[var].sel(year=year).values\n", - " r, _ = stats.pearsonr(y_p[~np.isnan(y_p)], y_t[~np.isnan(y_t)])\n", - " correlations.append(r)\n", - "print('Yearly average Pearson correlation coefficient for {}: {:.2f}'.format(var, np.asarray(r).mean()))" + "correlations = []\n", + "for year in y_pred_yearly_avg.year:\n", + " y_p = y_pred_yearly_avg.sel(year=year).values \n", + " y_t = y_true_yearly_avg.sel(year=year).values\n", + " r, _ = stats.pearsonr(y_p[~np.isnan(y_p)], y_t[~np.isnan(y_t)])\n", + " print('({:d}): {:.2f}'.format(year.item(), r))\n", + " correlations.append(r)\n", + "print('Yearly average Pearson correlation coefficient for {}: {:.2f}'.format(PREDICTAND, np.asarray(r).mean()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ec9b998-bc6c-4ef4-a602-ed87ae3260ae", + "metadata": {}, + "outputs": [], + "source": [ + "# plot yearly average bias of reference vs. 
prediction\n", + "vmin, vmax = -2, 2\n", + "fig, axes = plt.subplots(1, 2, figsize=(16, 8), sharex=True, sharey=True)\n", + "\n", + "# plot bias of ERA-5 reference\n", + "reference = bias_yearly_avg_ref.mean(dim='year')\n", + "im1 = axes[0].imshow(reference.values, origin='lower', cmap='RdBu_r', vmin=vmin, vmax=vmax)\n", + "axes[0].text(x=reference.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(reference.mean().item()), fontsize=14, ha='right')\n", + "\n", + "# plot MAE of model\n", + "prediction = bias_yearly_avg.mean(dim='year')\n", + "im2 = axes[1].imshow(prediction.values, origin='lower', cmap='RdBu_r', vmin=vmin, vmax=vmax)\n", + "axes[1].text(x=reference.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(prediction.mean().item()), fontsize=14, ha='right')\n", + "\n", + "# plot topography\n", + "# im_dem = axes[2].imshow(dem['elevation'].values, origin='lower', cmap='terrain', vmin=0, vmax=4000)\n", + "\n", + "# set titles\n", + "axes[0].set_title('ERA-5', fontsize=14, pad=10);\n", + "axes[1].set_title('DCEDN', fontsize=14, pad=10);\n", + "# axes[2].set_title('Copernicus EU-DEM v1.1', fontsize=14, pad=10)\n", + "\n", + "# adjust axes\n", + "for ax in axes.flat:\n", + " ax.axes.get_xaxis().set_ticklabels([])\n", + " ax.axes.get_xaxis().set_ticks([])\n", + " ax.axes.get_yaxis().set_ticklabels([])\n", + " ax.axes.get_yaxis().set_ticks([])\n", + " ax.axes.axis('tight')\n", + " ax.set_xlabel('')\n", + " ax.set_ylabel('')\n", + " ax.set_axis_off()\n", + "\n", + "# adjust figure\n", + "# fig.suptitle('Average yearly mean absolute error: 1991 - 2010', fontsize=20);\n", + "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)\n", + "\n", + "# add colorbar for dem\n", + "axes = axes.flatten()\n", + "# cbar_ax_bias = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n", + "# 0.01, axes[-1].get_position().y1 - axes[-1].get_position().y0])\n", + "# cbar_bias = fig.colorbar(im_dem, cax=cbar_ax_bias)\n", + "# cbar_bias.set_label(label='Elevation (m)', fontsize=14)\n", + "# cbar_bias.ax.tick_params(labelsize=14, pad=10)\n", + "\n", + "# add colorbar for predictand\n", + "cbar_ax_predictand = fig.add_axes([axes[0].get_position().x0, axes[0].get_position().y0 - 0.1,\n", + " axes[-1].get_position().x0 - axes[0].get_position().x0,\n", + " 0.03])\n", + "cbar_predictand = fig.colorbar(im1, cax=cbar_ax_predictand, orientation='horizontal')\n", + "cbar_predictand.set_label(label='Mean error (°C)', fontsize=14)\n", + "cbar_predictand.ax.tick_params(labelsize=14, pad=10)\n", + "\n", + "# add metrics: MAE and RMSE\n", + "#axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.1f}'.format(mae_avg[NAMES[PREDICTAND]].item()) + 'mm day$^{-1}$', fontsize=14, ha='right')\n", + "#axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.1f}'.format(rmse_avg[NAMES[PREDICTAND]].item()) + 'mm$^2$ day$^{-2}$', fontsize=14, ha='right')\n", + "\n", + "# save figure\n", + "fig.savefig('../Notebooks/Figures/{}_rbias_ERA_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')" ] }, { @@ -469,22 +730,20 @@ "source": [ "# plot average of observation, prediction, and bias\n", "vmin, vmax = (-15, 15) if PREDICTAND == 'tasmin' else (0, 25)\n", - "fig, axes = plt.subplots(len(y_pred_yearly_avg.data_vars), 3, figsize=(24, len(y_pred_yearly_avg.data_vars) * 6),\n", - " sharex=True, sharey=True)\n", - "axes = axes.reshape(len(y_pred_yearly_avg.data_vars), -1)\n", - "for i, var in enumerate(y_pred_yearly_avg):\n", - " for ds, ax in zip([y_true_yearly_avg, y_pred_yearly_avg, bias_yearly_avg], axes[i, ...]):\n", - " if ds 
is bias_yearly_avg:\n", - " ds = ds[var].mean(dim='year')\n", - " im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n", - " ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n", - " else:\n", - " im1 = ax.imshow(ds[var].mean(dim='year').values, origin='lower', cmap='RdYlBu_r', vmin=vmin, vmax=vmax)\n", + "fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)\n", + "axes = axes.flatten()\n", + "for ds, ax in zip([y_true_yearly_avg, y_pred_yearly_avg, bias_yearly_avg], axes):\n", + " if ds is bias_yearly_avg:\n", + " ds = ds.mean(dim='year')\n", + " im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n", + " ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n", + " else:\n", + " im1 = ax.imshow(ds.mean(dim='year').values, origin='lower', cmap='RdYlBu_r', vmin=vmin, vmax=vmax)\n", " \n", "# set titles\n", - "axes[0, 0].set_title('Observed', fontsize=16, pad=10);\n", - "axes[0, 1].set_title('Predicted', fontsize=16, pad=10);\n", - "axes[0, 2].set_title('Bias', fontsize=16, pad=10);\n", + "axes[0].set_title('Observed', fontsize=16, pad=10);\n", + "axes[1].set_title('Predicted', fontsize=16, pad=10);\n", + "axes[2].set_title('Bias', fontsize=16, pad=10);\n", "\n", "# adjust axes\n", "for ax in axes.flat:\n", @@ -498,7 +757,7 @@ "\n", "# adjust figure\n", "fig.suptitle('Average {}: 1991 - 2010'.format(NAMES[PREDICTAND]), fontsize=20);\n", - "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)c\n", + "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)\n", "\n", "# add colorbar for bias\n", "axes = axes.flatten()\n", @@ -517,8 +776,8 @@ "cbar_predictand.ax.tick_params(labelsize=14)\n", "\n", "# add metrics: MAE and RMSE\n", - "axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.2f}°C'.format(mae_avg[PREDICTAND].item()), fontsize=14, ha='right')\n", - "axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.2f}°C$^2$'.format(rmse_avg[PREDICTAND].item()), fontsize=14, ha='right')\n", + "axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.2f}°C'.format(mae_avg.mean().item()), fontsize=14, ha='right')\n", + "axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.2f}°C'.format(rmse_avg.mean().item()), fontsize=14, ha='right')\n", "\n", "# save figure\n", "fig.savefig('../Notebooks/Figures/{}_average_bias.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" @@ -563,9 +822,8 @@ "outputs": [], "source": [ "# print average bias per season: ERA-5\n", - "for var in bias_snl_ref.data_vars:\n", - " for season in bias_snl_ref[PREDICTAND].season:\n", - " print('(ERA-5) Average bias of mean {} for season {}: {:.1f}°C'.format(var, season.values.item(), bias_snl_ref[var].sel(season=season).mean().item()))" + "for season in bias_snl_ref.season:\n", + " print('(ERA-5) Average bias of mean {} for season {}: {:.1f}°C'.format(PREDICTAND, season.item(), bias_snl_ref.sel(season=season).mean().item()))" ] }, { @@ -576,9 +834,8 @@ "outputs": [], "source": [ "# print average bias per season: model\n", - "for var in bias_snl.data_vars:\n", - " for season in bias_snl[PREDICTAND].season:\n", - " print('(Model) Average bias of mean {} for season {}: {:.1f}°C'.format(var, season.values.item(), bias_snl[var].sel(season=season).mean().item()))" + "for season in bias_snl.season:\n", + " print('(Model) Average bias of mean {} for season {}: {:.1f}°C'.format(PREDICTAND, season.item(), bias_snl.sel(season=season).mean().item()))" ] }, { @@ -597,20 +854,21 @@ "outputs": [], 
"source": [ "# plot seasonal differences\n", - "fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(24, 12), sharex=True, sharey=True)\n", + "seasons = ('DJF', 'JJA')\n", + "fig, axes = plt.subplots(nrows=1, ncols=len(seasons) + 1, figsize=(24, 8), sharex=True, sharey=True)\n", "axes = axes.flatten()\n", "\n", "# plot annual average bias\n", - "ds = bias_yearly_avg[PREDICTAND].mean(dim='year')\n", + "ds = bias_yearly_avg.mean(dim='year')\n", "im = axes[0].imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n", "axes[0].set_title('Annual', fontsize=16);\n", "axes[0].text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n", "\n", "# plot seasonal average bias\n", - "for ax, season in zip(axes[1:], bias_snl.season):\n", - " ds = bias_snl[PREDICTAND].sel(season=season)\n", + "for ax, season in zip(axes[1:], seasons):\n", + " ds = bias_snl.sel(season=season)\n", " ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n", - " ax.set_title(season.item(), fontsize=16);\n", + " ax.set_title(season, fontsize=16);\n", " ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n", "\n", "# adjust axes\n", @@ -622,13 +880,10 @@ " ax.axes.axis('tight')\n", " ax.set_xlabel('')\n", " ax.set_ylabel('')\n", - " \n", - "# turn off last axis\n", - "axes[-1].set_visible(False)\n", "\n", "# adjust figure\n", "fig.suptitle('Average bias of {}: 1991 - 2010'.format(NAMES[PREDICTAND]), fontsize=20);\n", - "fig.subplots_adjust(hspace=0.1, wspace=0, top=0.925)\n", + "fig.subplots_adjust(hspace=0, wspace=0, top=0.85)\n", "\n", "# add colorbar\n", "cbar_ax_predictand = fig.add_axes([axes[-1].get_position().x1 + 0.01, axes[-1].get_position().y0,\n", @@ -638,51 +893,7 @@ "cbar_predictand.ax.tick_params(labelsize=14)\n", "\n", "# save figure\n", - "fig.savefig('../Notebooks/Figures/{}_average_bias_seasonal.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" - ] - }, - { - "cell_type": "markdown", - "id": "41600269-2f8c-4717-8f74-b3dfaef60359", - "metadata": {}, - "source": [ - "Calculate the mean annual cycle:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c9f27c01-4dfc-4d16-8d29-00e69b7794cd", - "metadata": {}, - "outputs": [], - "source": [ - "# group timeseries by month and calculate mean over time and space\n", - "y_pred_ac = y_pred.groupby('time.month').mean(dim=('time', 'y', 'x'))\n", - "y_true_ac = y_true.groupby('time.month').mean(dim=('time', 'y', 'x'))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bcc63e73-2636-49c1-a66b-241eb5407e2e", - "metadata": {}, - "outputs": [], - "source": [ - "# plot mean annual cycle\n", - "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n", - "ax.plot(y_pred_ac[PREDICTAND].values, ls='--', color='k', label='Predicted')\n", - "ax.plot(y_true_ac[PREDICTAND].values, ls='-', color='k', label='Observed')\n", - "ax.legend(frameon=False, fontsize=14);\n", - "ax.set_yticks(np.arange(np.floor(y_true_ac[PREDICTAND].min().item()), np.ceil(y_true_ac[PREDICTAND].max().item()) + 1, 1))\n", - "ax.set_yticklabels(np.arange(np.floor(y_true_ac[PREDICTAND].min().item()), np.ceil(y_true_ac[PREDICTAND].max().item()) + 1, 1), fontsize=12)\n", - "ax.set_xticks(np.arange(0, 12))\n", - "ax.set_xticklabels([calendar.month_name[i + 1] for i in np.arange(0, 12)], rotation=90, fontsize=12)\n", - "ax.set_title('Mean annual cycle of {}: 1991 - 2010'.format(NAMES[PREDICTAND]), pad=20, fontsize=16);\n", - "ax.set_ylabel('{} / 
(°C)'.format(NAMES[PREDICTAND].capitalize()), fontsize=14)\n", - "ax.set_xlabel('Month', fontsize=14);\n", - "\n", - "# save figure\n", - "fig.savefig('../Notebooks/Figures/{}_mean_annual_cycle.png'.format(PREDICTAND), dpi=300, bbox_inches='tight')" + "fig.savefig('../Notebooks/Figures/{}_bias_seasonal_{}.png'.format(PREDICTAND, LOSS), dpi=300, bbox_inches='tight')" ] }, { @@ -739,8 +950,7 @@ "outputs": [], "source": [ "# bias of extreme quantile: ERA-5\n", - "for var in bias_ex_ref:\n", - " print('(ERA-5) Yearly average bias for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, var, bias_ex_ref[var].mean().item()))" + "print('(ERA-5) Yearly average bias for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, PREDICTAND, bias_ex_ref.mean().item()))" ] }, { @@ -751,8 +961,7 @@ "outputs": [], "source": [ "# bias of extreme quantile: Model\n", - "for var in bias_ex:\n", - " print('(Model) Yearly average bias for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, var, bias_ex[var].mean().item()))" + "print('(Model) Yearly average bias for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, PREDICTAND, bias_ex.mean().item()))" ] }, { @@ -775,8 +984,7 @@ "outputs": [], "source": [ "# mae of extreme quantile: ERA-5\n", - "for var in mae_ex_ref:\n", - " print('(ERA-5) Yearly average MAE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, var, mae_ex_ref[var].item()))" + "print('(ERA-5) Yearly average MAE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, PREDICTAND, mae_ex_ref.item()))" ] }, { @@ -787,8 +995,7 @@ "outputs": [], "source": [ "# mae of extreme quantile: Model\n", - "for var in mae_ex:\n", - " print('(Model) Yearly average MAE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, var, mae_ex[var].item()))" + "print('(Model) Yearly average MAE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, PREDICTAND, mae_ex.item()))" ] }, { @@ -799,8 +1006,8 @@ "outputs": [], "source": [ "# root mean squared error in extreme quantile\n", - "rmse_ex = ((y_pred_ex - y_true_ex) ** 2).mean()\n", - "rmse_ex_ref = ((y_refe_ex - y_true_ex) ** 2).mean()" + "rmse_ex = np.sqrt(((y_pred_ex - y_true_ex) ** 2).mean())\n", + "rmse_ex_ref = np.sqrt(((y_refe_ex - y_true_ex) ** 2).mean())" ] }, { @@ -811,8 +1018,7 @@ "outputs": [], "source": [ "# rmse of extreme quantile: ERA-5\n", - "for var in rmse_ex_ref:\n", - " print('(ERA-5) Yearly average RMSE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, var, rmse_ex_ref[var].item()))" + "print('(ERA-5) Yearly average RMSE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, PREDICTAND, rmse_ex_ref.item()))" ] }, { @@ -823,8 +1029,7 @@ "outputs": [], "source": [ "# rmse of extreme quantile: Model\n", - "for var in rmse_ex:\n", - " print('(Model) Yearly average RMSE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, var, rmse_ex[var].item()))" + "print('(Model) Yearly average RMSE for P{:.0f} of {}: {:.1f}°C'.format(quantile * 100, PREDICTAND, rmse_ex.item()))" ] }, { @@ -838,23 +1043,21 @@ "source": [ "# plot extremes of observation, prediction, and bias\n", "vmin, vmax = (-20, 0) if PREDICTAND == 'tasmin' else (10, 40)\n", - "fig, axes = plt.subplots(len(y_pred_ex.data_vars), 3, figsize=(24, len(y_pred_ex.data_vars) * 6),\n", - " sharex=True, sharey=True)\n", - "axes = axes.reshape(len(y_pred_ex.data_vars), -1)\n", - "for i, var in enumerate(y_pred_ex):\n", - " for ds, ax in zip([y_true_ex, y_pred_ex, bias_ex], axes[i, ...]):\n", - " if ds is bias_ex:\n", - " ds = ds[var].mean(dim='year')\n", - " im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n", - " 
ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n", - " else:\n", - " im1 = ax.imshow(ds[var].mean(dim='year').values, origin='lower', cmap='Blues_r' if PREDICTAND == 'tasmin' else 'Reds',\n", - " vmin=vmin, vmax=vmax)\n", - " \n", + "fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True)\n", + "axes = axes.flatten()\n", + "for ds, ax in zip([y_true_ex, y_pred_ex, bias_ex], axes):\n", + " if ds is bias_ex:\n", + " ds = ds.mean(dim='year')\n", + " im2 = ax.imshow(ds.values, origin='lower', cmap='RdBu_r', vmin=-2, vmax=2)\n", + " ax.text(x=ds.shape[0] - 2, y=2, s='Average: {:.2f}°C'.format(ds.mean().item()), fontsize=14, ha='right')\n", + " else:\n", + " im1 = ax.imshow(ds.mean(dim='year').values, origin='lower', cmap='Blues_r' if PREDICTAND == 'tasmin' else 'Reds',\n", + " vmin=vmin, vmax=vmax)\n", + "\n", "# set titles\n", - "axes[0, 0].set_title('Observed', fontsize=16, pad=10);\n", - "axes[0, 1].set_title('Predicted', fontsize=16, pad=10);\n", - "axes[0, 2].set_title('Bias', fontsize=16, pad=10);\n", + "axes[0].set_title('Observed', fontsize=16, pad=10);\n", + "axes[1].set_title('Predicted', fontsize=16, pad=10);\n", + "axes[2].set_title('Bias', fontsize=16, pad=10);\n", "\n", "# adjust axes\n", "for ax in axes.flat:\n", @@ -887,11 +1090,11 @@ "cbar_predictand.ax.tick_params(labelsize=14)\n", "\n", "# add metrics: MAE and RMSE\n", - "axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.2f}°C'.format(mae_ex[PREDICTAND].item()), fontsize=14, ha='right')\n", - "axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.2f}°C$^2$'.format(rmse_ex[PREDICTAND].item()), fontsize=14, ha='right')\n", + "axes[1].text(x=ds.shape[0] - 2, y=2, s='MAE = {:.2f}°C'.format(mae_ex.item()), fontsize=14, ha='right')\n", + "axes[1].text(x=ds.shape[0] - 2, y=12, s='RMSE = {:.2f}°C'.format(rmse_ex.item()), fontsize=14, ha='right')\n", "\n", "# save figure\n", - "fig.savefig('../Notebooks/Figures/{}_average_bias_p{:.0f}.png'.format(PREDICTAND, quantile * 100), dpi=300, bbox_inches='tight')" + "fig.savefig('../Notebooks/Figures/{}_bias_p{:.0f}_{}.png'.format(PREDICTAND, quantile * 100, LOSS), dpi=300, bbox_inches='tight')" ] }, { @@ -1002,7 +1205,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/Notebooks/lr_range_test.ipynb b/Notebooks/lr_range_test.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..f12faf6d50687fee13302ae02360e18564b3fb4c --- /dev/null +++ b/Notebooks/lr_range_test.ipynb @@ -0,0 +1,290 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "4d010ae8-4172-417f-a457-c14e46adbf85", + "metadata": {}, + "source": [ + "# Learning rate range test" + ] + }, + { + "cell_type": "markdown", + "id": "c3200572-47da-48a6-a143-0ab26a3a3e40", + "metadata": {}, + "source": [ + "Define the predictand and the model to evaluate:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd249096-9e57-425f-a0e9-29f31ccf4f6c", + "metadata": {}, + "outputs": [], + "source": [ + "# define the model parameters\n", + "PREDICTANDS = ['pr', 'tasmax', 'tasmin']\n", + "MODEL = 'USegNet'\n", + "PPREDICTORS = 'ztuvq'\n", + "# PPREDICTORS = ''\n", + "PLEVELS = ['500', '850']\n", + "# PLEVELS = []\n", + "SPREDICTORS = 'p'\n", + "# SPREDICTORS = 'pr'\n", + "DEM = 'dem'\n", + "DEM_FEATURES = ''\n", + "DOY = 'doy'\n", + "LOSS = ['BernoulliGammaLoss', 'L1Loss', 'MSELoss']\n", + "OPTIM = ['SGD', 'Adam']\n", + 
"WET_DAY_THRESHOLD = '1'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5c4f6f06-6bf2-453c-a2a8-88386358c36f", + "metadata": {}, + "outputs": [], + "source": [ + "# mapping from predictands to variable names\n", + "NAMES = {'tasmin': 'minimum temperature', 'tasmax': 'maximum temperature', 'pr': 'precipitation'}" + ] + }, + { + "cell_type": "markdown", + "id": "0f47fdee-484b-457b-bdbf-a3b2c67132d9", + "metadata": {}, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21b9927a-ebbe-48ff-9c64-fbb27a069f4e", + "metadata": {}, + "outputs": [], + "source": [ + "# builtins\n", + "import datetime\n", + "import warnings\n", + "import calendar\n", + "\n", + "# externals\n", + "import torch\n", + "import xarray as xr\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.lines as mlines\n", + "import seaborn as sns\n", + "\n", + "# locals\n", + "from climax.main.io import ERA5_PATH, OBS_PATH, TARGET_PATH, DEM_PATH, MODEL_PATH\n", + "from climax.core.utils import plot_loss\n", + "from pysegcnn.core.utils import search_files\n", + "from pysegcnn.core.graphics import running_mean" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff1b059d-71b6-46bb-b2c3-66b56120546f", + "metadata": {}, + "outputs": [], + "source": [ + "# path to models for learning rate range test\n", + "MODEL_PATH = MODEL_PATH.joinpath('lr-range')" + ] + }, + { + "cell_type": "markdown", + "id": "40c94ce9-ed05-4434-908a-91683c1400cb", + "metadata": {}, + "source": [ + "### Load datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49cf0a29-9564-4c78-9fcd-57af945855df", + "metadata": {}, + "outputs": [], + "source": [ + "# construct file pattern to match\n", + "PATTERN = '_'.join([MODEL, 'predictand',])\n", + "PATTERN = '_'.join([PATTERN, PPREDICTORS]) if PPREDICTORS else PATTERN\n", + "PATTERN = '_'.join([PATTERN, *PLEVELS]) if PPREDICTORS else PATTERN\n", + "PATTERN = '_'.join([PATTERN, SPREDICTORS]) if SPREDICTORS else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DEM]) if DEM else PATTERN\n", + "PATTERN = '_'.join([PATTERN, DOY]) if DOY else PATTERN\n", + "PATTERN = '_'.join([PATTERN, 'loss'])\n", + "PATTERN = '_'.join([PATTERN, 'optim'])\n", + "PATTERN = '_'.join([PATTERN, 'lr_test.pt'])\n", + "PATTERNS = {k: PATTERN.replace('predictand', k) for k in PREDICTANDS}\n", + "PATTERNS = {k1: {v2: v1.replace('loss', '{}mm_{}'.format(WET_DAY_THRESHOLD, v2)) if v2 in\n", + " ['BernoulliGammaLoss', 'BernoulliWeibullLoss'] else v1.replace('loss', v2) for\n", + " v2 in LOSS} for k1, v1 in PATTERNS.items()}\n", + "PATTERNS = {optim: {k: {l: v.replace('optim', optim) for l, v in PATTERNS[k].items()} for k in PREDICTANDS} for optim in OPTIM}\n", + "PATTERNS" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fee242d6-cd1a-48b4-ac3d-6d35d428247d", + "metadata": {}, + "outputs": [], + "source": [ + "# load datasets based on loss function\n", + "y_pred = {}\n", + "for o in OPTIM:\n", + " y_pred[o] = {}\n", + " for k in PREDICTANDS:\n", + " y_pred[o][k] = {}\n", + " for l, v in PATTERNS[o][k].items():\n", + " try:\n", + " y_pred[o][k][l] = torch.load(MODEL_PATH.joinpath(k, v), map_location=torch.device('cpu'))['state']\n", + " except FileNotFoundError:\n", + " y_pred[o][k][l] = np.nan" + ] + }, + { + "cell_type": "markdown", + "id": "14493ac2-e2bd-40f8-a52b-57d2d9df5942", + "metadata": {}, + "source": [ + "## Loss as function of learning rate" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "id": "b19e74f7-50c8-4445-8190-520ec7f9c2c2", + "metadata": {}, + "outputs": [], + "source": [ + "# minimum learning rate\n", + "MIN_LR = 1e-4\n", + "\n", + "# factor of learning rate increase\n", + "GAMMA = 1.6" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3b806dc-f729-4f4f-ba63-a9e1595d5666", + "metadata": {}, + "outputs": [], + "source": [ + "# color palette\n", + "PALETTE = 'mako'\n", + "COLORS = sns.color_palette(PALETTE, n_colors=len(LOSS))\n", + "COLORS" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7916e76e-806c-43c6-8dd6-c5d0296d982a", + "metadata": {}, + "outputs": [], + "source": [ + "# plot loss as function of learning rate\n", + "fig, axes = plt.subplots(1, len(PREDICTANDS), figsize=(16, 6), sharex=True, sharey=True)\n", + "axes = axes.flatten()\n", + "max_lrs = {}\n", + "for ax, p in zip(axes, PREDICTANDS):\n", + " max_lrs[p] = {}\n", + " for l, c in zip(LOSS, COLORS):\n", + " max_lrs[p][l] = {}\n", + " # plot observed loss as function of learning rate\n", + " for o, ls, fs in zip(OPTIM, ['-', ':'], ['full', 'none']):\n", + " try:\n", + " loss = y_pred[o][p][l]['train_loss']\n", + " batches, epochs = loss.shape \n", + " lr = running_mean(np.repeat(np.asarray([MIN_LR * (GAMMA ** e) for e in range(0, epochs)]), batches), batches)\n", + " loss = running_mean(loss.flatten('F').clip(max=100), batches)\n", + "\n", + " # learning rate at minimum loss\n", + " mask = ~np.isnan(loss)\n", + " try:\n", + " max_lr = lr[mask][np.argmin(loss[mask])]\n", + " min_loss = np.min(loss[mask])\n", + " except ValueError:\n", + " max_lr = np.nan\n", + " max_lrs[p][l][o] = max_lr\n", + "\n", + " # plot running mean\n", + " ax.plot(lr, loss, color=c, ls=ls)\n", + "\n", + " # plot minimum loss\n", + " ax.plot(max_lr, min_loss, 'o', color=c, fillstyle=fs)\n", + " # ax.legend(frameon=False, fontsize=14);\n", + " except TypeError:\n", + " ax.plot(np.nan, np.nan)\n", + "\n", + " # axes properties\n", + " ax.set_title('{}'.format(NAMES[p].capitalize()), fontsize=16, pad=10)\n", + " ax.set_xscale('log')\n", + " ax.set_yscale('log')\n", + " ax.set_ylim(5e-1, 2e2);\n", + " ax.set_xlim(5e-5, 3e0);\n", + " ax.set_xlabel('Learning rate $\\\\alpha$', fontsize=16)\n", + " ax.tick_params(axis='both', which='major', labelsize=16)\n", + " \n", + "# add legends\n", + "\n", + "# loss functions\n", + "patches = [mlines.Line2D([], [], color=c, label=l) for c, l in\n", + " zip(COLORS, LOSS)]\n", + "axes[0].legend(handles=patches, loc='upper left', frameon=False, bbox_to_anchor=(-0.05, -0.15), ncol=len(LOSS), fontsize=14)\n", + "\n", + "# optimizers\n", + "patches = [mlines.Line2D([], [], color='black', ls=ls, label=l) for ls, l in\n", + " zip(['-', ':'], OPTIM)]\n", + "axes[1].legend(handles=patches, loc='upper left', frameon=False, bbox_to_anchor=(-1.1, -0.25), ncol=len(OPTIM), fontsize=14)\n", + " \n", + "# figure properties\n", + "axes[0].set_ylabel('Training loss', fontsize=16) \n", + "fig.subplots_adjust(wspace=0.05)\n", + "fig.savefig('./Figures/lr_range_test.png', dpi=300, bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5588c08-1fec-42ba-9463-7e714e0daf03", + "metadata": {}, + "outputs": [], + "source": [ + "# print learning rates at minimum of loss, i.e. 
max learning rate\n", + "for p in PREDICTANDS:\n", + " [[print('({}), {}: {}: Max LR = {:.5f}'.format(p, l, o, lr)) for (o, lr) in max_lrs[p][l].items()] for l in max_lrs[p].keys()]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Notebooks/pr_distribution.ipynb b/Notebooks/pr_distribution.ipynb index 1a41e3d539ccb0dfd5b3310cc214242655de63f8..f8f697499276ade57ca415fd074b7ed57ddbd59b 100644 --- a/Notebooks/pr_distribution.ipynb +++ b/Notebooks/pr_distribution.ipynb @@ -62,6 +62,17 @@ "quantiles = np.arange(0.01, 1, 0.005)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "0fd28eaf-2939-40e1-841a-49dabe6778d7", + "metadata": {}, + "outputs": [], + "source": [ + "# wet day threshold\n", + "WET_DAY_THRESHOLD = 1" + ] + }, { "cell_type": "markdown", "id": "12382efb-1a3a-4ede-a904-7f762bfe56c7", @@ -129,10 +140,10 @@ "outputs": [], "source": [ "# helper function retrieving only valid observations\n", - "def valid(ds):\n", + "def valid(ds, min_amount=WET_DAY_THRESHOLD):\n", " valid = ds.precipitation.values\n", " valid = valid[~np.isnan(valid)] # mask missing values\n", - " valid = valid[valid > 0] # only consider pr > 0\n", + " valid = valid[valid > min_amount] # only consider pr > 0\n", " return valid" ] }, @@ -179,8 +190,8 @@ "outputs": [], "source": [ "# fit generalized pareto distribution to data\n", - "alpha, loc, beta = stats.genpareto.fit(y_valid, floc=0)\n", - "genpareto = stats.genpareto(alpha, loc=loc, scale=beta)" + "# alpha, loc, beta = stats.genpareto.fit(y_valid, floc=0)\n", + "# genpareto = stats.genpareto(alpha, loc=loc, scale=beta)" ] }, { @@ -207,6 +218,16 @@ "weibull = stats.weibull_min(alpha, loc=loc, scale=beta)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7daee53-d996-42ed-aa84-a4770191cafe", + "metadata": {}, + "outputs": [], + "source": [ + "alpha, loc, beta" + ] + }, { "cell_type": "code", "execution_count": null, @@ -217,7 +238,7 @@ "# empirical quantiles and theoretical quantiles\n", "eq = np.quantile(y_valid, quantiles)\n", "tq_gamma = gamma.ppf(quantiles)\n", - "tq_genpareto = genpareto.ppf(quantiles)\n", + "#tq_genpareto = genpareto.ppf(quantiles)\n", "tq_expon = expon.ppf(quantiles)\n", "tq_lognorm = lognorm.ppf(quantiles)\n", "tq_weibull = weibull.ppf(quantiles)\n", @@ -226,7 +247,7 @@ "RANGE = 40\n", "fig, ax = plt.subplots(1, 1, figsize=(6, 6))\n", "ax.scatter(eq, tq_gamma, marker='*', color='k', label='Gamma')\n", - "ax.scatter(eq, tq_genpareto, marker='x', color='k', label='GenPareto')\n", + "# ax.scatter(eq, tq_genpareto, marker='x', color='k', label='GenPareto')\n", "ax.scatter(eq, tq_expon, marker='o', color='k', label='Expon')\n", "ax.scatter(eq, tq_lognorm, marker='+', color='k', label='LogNorm')\n", "ax.scatter(eq, tq_weibull, marker='^', color='k', label='Weibull')\n", diff --git a/Notebooks/pr_loss.ipynb b/Notebooks/pr_loss.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..ffbd160b3c7a6c5854bc69e8189d56d555420677 --- /dev/null +++ b/Notebooks/pr_loss.ipynb @@ -0,0 +1,193 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b21026e9-2bd2-4b52-a902-57a3b21875b1", + "metadata": {}, + 
"source": [ + "# Loss functions" + ] + }, + { + "cell_type": "markdown", + "id": "6b81f5f2-4f49-4680-b6d2-32e7668ca53c", + "metadata": {}, + "source": [ + "For precipitation, the network is optimizing the negative log-likelihood of a mixed discrete Bernoulli and a continuous Gamma or Weibull distribution. The loss function is defined as the Log-likelihood $l(p, \\alpha, \\beta \\mid y)$ function of the mixed discrete-continuous distribution $P(y \\mid p, \\alpha, \\beta)$ for $i = 1 \\dots N$ i.i.d observations $y_i$:" + ] + }, + { + "cell_type": "markdown", + "id": "ac97e5ef-1881-44d7-9704-076d26347e15", + "metadata": {}, + "source": [ + "$$l(p, \\alpha, \\beta \\mid y) = \\log\\left(\\prod_{i=1}^{N} P(y_i \\mid p, \\alpha, \\beta)\\right) = \\sum_{i=1}^{N} \\log\\left(P(y_i \\mid p, \\alpha, \\beta)\\right)$$" + ] + }, + { + "cell_type": "markdown", + "id": "b0697679-ec7b-4ad4-a676-1c362f3affff", + "metadata": {}, + "source": [ + "## Bernoulli-Gamma loss function" + ] + }, + { + "cell_type": "markdown", + "id": "209fc0e8-fff2-44fe-bf70-470c558068d9", + "metadata": {}, + "source": [ + "The Bernoulli-Gamma distribution for precipitation was introduced by [Cannon (2008)](http://journals.ametsoc.org/doi/10.1175/2008JHM960.1) and is defined as:" + ] + }, + { + "cell_type": "markdown", + "id": "9289faf7-b006-4a25-8afe-c09a07b710ae", + "metadata": {}, + "source": [ + "$$P(y \\mid, p, \\alpha, \\beta) = \\begin{cases} 1 - p, & \\text{for } y = 0\\\\ p \\cdot \\frac{y^{\\alpha -1} \\exp\\left(-\\frac{y}{\\beta}\\right)}{\\beta^{\\alpha} \\tau(\\alpha)}, & \\text{for } y > 0\\end{cases}$$" + ] + }, + { + "cell_type": "markdown", + "id": "4d40fbc3-0180-4ed3-8fdf-97326f9cb5c8", + "metadata": {}, + "source": [ + "with $\\alpha, \\beta > 0$ as the *shape* and *scale* parameters, $p \\in [0, 1]$ the *predicted probability* of precipitation, $y$ the observed precipitation amount, and $\\tau(\\alpha)$ the *gamma function*." + ] + }, + { + "cell_type": "markdown", + "id": "9f8e2eb6-02b3-40a4-be3f-daea5a623235", + "metadata": {}, + "source": [ + "Log-likelihood function of the Bernoulli-Gamma distribution," + ] + }, + { + "cell_type": "markdown", + "id": "8c563a68-a154-47e1-bd4c-d0fb1a3bd16f", + "metadata": {}, + "source": [ + "$$\\mathcal{l}(p, \\alpha, \\beta \\mid y) = \\underbrace{(1 - P(y > 0)) \\log(1 - p)}_{\\text{Bernoulli}} + \\underbrace{P(y > 0) \\cdot \\left(\\log(p) + (\\alpha - 1) \\log(y) - \\frac{y}{\\beta} - \\alpha \\log(\\beta) - \\log(\\tau(\\alpha))\\right)}_{\\text{Gamma}}$$" + ] + }, + { + "cell_type": "markdown", + "id": "8d8ed32a-e7bf-491a-a59f-34cc9be82584", + "metadata": {}, + "source": [ + "where $P(y > 0) \\in \\{0, 1\\}$ is the *true probability* of precipitation." 
+ ] + }, + { + "cell_type": "markdown", + "id": "8e398417-8e20-4684-bf30-fbd3c186f6c1", + "metadata": {}, + "source": [ + "## Bernoulli-Weibull loss function" + ] + }, + { + "cell_type": "markdown", + "id": "e39f72f9-58a8-4093-8fce-4870e2364de1", + "metadata": {}, + "source": [ + "The Bernoulli-Weibull distribution is defined as:" + ] + }, + { + "cell_type": "markdown", + "id": "41067b17-745e-4f8e-9530-c2c72ec5d263", + "metadata": {}, + "source": [ + "$$P(y \\mid p, \\alpha, \\beta) = \\begin{cases} 1 - p, & \\text{for } y = 0\\\\ p \\cdot \\frac{\\alpha}{\\beta}\\left(\\frac{y}{\\beta}\\right)^{\\alpha - 1} \\exp\\left(-\\left(\\frac{y}{\\beta}\\right)^{\\alpha}\\right), & \\text{for } y > 0\\end{cases}$$" + ] + }, + { + "cell_type": "markdown", + "id": "b8d14bcc-3c4c-4cef-bd07-3cd96cb3f79f", + "metadata": {}, + "source": [ + "with $\\alpha, \\beta > 0$ as the *shape* and *scale* parameters, $p \\in [0, 1]$ the *predicted probability* of precipitation, and $y$ the observed precipitation amount." + ] + }, + { + "cell_type": "markdown", + "id": "e87fe6bc-a666-421e-b04a-941827342244", + "metadata": {}, + "source": [ + "The log-likelihood function of the Bernoulli-Weibull distribution is:" + ] + }, + { + "cell_type": "markdown", + "id": "db6518d7-a329-432b-b750-8a6cba54d909", + "metadata": {}, + "source": [ + "$$\\mathcal{l}(p, \\alpha, \\beta \\mid y) = \\underbrace{(1 - P(y > 0)) \\log(1 - p)}_{\\text{Bernoulli}} + \\underbrace{P(y > 0) \\cdot \\left(\\log(\\alpha) - \\alpha \\log(\\beta) + (\\alpha - 1) \\log(y) - \\left(\\frac{y}{\\beta}\\right)^{\\alpha}\\right)}_{\\text{Weibull}}$$" + ] + }, + { + "cell_type": "markdown", + "id": "c955910c-49c6-4e63-b06f-761fc3d7cb1d", + "metadata": {}, + "source": [ + "where $P(y > 0) \\in \\{0, 1\\}$ is the *true probability* of precipitation." + ] + }, + { + "cell_type": "markdown", + "id": "f9847ca2-16f2-4286-8eb0-54773a560234", + "metadata": {}, + "source": [ + "## Loss function with regularization" + ] + }, + { + "cell_type": "markdown", + "id": "e95ff9df-a44b-4d03-bbee-5406835df453", + "metadata": {}, + "source": [ + "The loss function with a **regularization (shrinkage)** term is defined as:" + ] + }, + { + "cell_type": "markdown", + "id": "fffb81c0-3358-40fa-b3f2-e9a19014cde8", + "metadata": {}, + "source": [ + "$$\\mathcal{l}(p, \\alpha, \\beta \\mid y, \\theta)_{\\lambda} = \\underbrace{\\mathcal{l}(p, \\alpha, \\beta \\mid y, \\theta)}_{\\text{loss function}} + \\underbrace{\\lambda \\|\\theta\\|_2^2}_{L_2-\\text{penalty}},$$" + ] + }, + { + "cell_type": "markdown", + "id": "5a62bc1a-c469-4114-812d-588f0916918a", + "metadata": {}, + "source": [ + "where $\\lambda \\in [0, 1]$ is the weight decay factor." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}