{ "cells": [ { "cell_type": "markdown", "id": "0ade2427-0020-414a-957b-4ed71281b61f", "metadata": {}, "source": [ "## Generalized linear regression" ] }, { "cell_type": "code", "execution_count": null, "id": "cfd6644a-9af1-4b71-8dac-1e423786e0a8", "metadata": {}, "outputs": [], "source": [ "# builtins\n", "import pathlib\n", "\n", "# externals\n", "import numpy as np\n", "import pandas as pd\n", "import xarray as xr\n", "from sklearn.linear_model import TweedieRegressor\n", "\n", "# locals\n", "from climax.core.dataset import ERA5Dataset\n", "from climax.core.constants import ERA5_VARIABLES\n", "from climax.core.utils import search_files" ] }, { "cell_type": "code", "execution_count": null, "id": "646406dd-8c0d-4cdc-b7c3-b0b55b2ec4bd", "metadata": {}, "outputs": [], "source": [ "# path to ERA5 reanalysis data\n", "ERA5_PATH = pathlib.Path('/mnt/CEPH_PROJECTS/FACT_CLIMAX/REANALYSIS/ERA5')" ] }, { "cell_type": "code", "execution_count": null, "id": "18d222d5-fe39-4858-af39-04a14171cad6", "metadata": {}, "outputs": [], "source": [ "# list of valid predictor variable names\n", "ERA5_VARIABLES" ] }, { "cell_type": "code", "execution_count": null, "id": "1cdca239-c1c9-41c6-bce5-7929ff3b19de", "metadata": {}, "outputs": [], "source": [ "# define the predictor variables you want to use\n", "ERA5_PREDICTORS = ['geopotential', 'temperature', 'mean_sea_level_pressure'] # use geopotential, temperature and pressure\n", "\n", "# you can change this list as you wish, e.g.:\n", "# ERA5_PREDICTORS = ['geopotential', 'temperature'] # use only geopotential and temperature\n", "# ERA5_PREDICTORS = ERA5_VARIABLES # use all ERA5 variables as predictors\n", "\n", "# this checks if the variable names are correct\n", "assert all([p in ERA5_VARIABLES for p in ERA5_PREDICTORS]) " ] }, { "cell_type": "markdown", "id": "2fd49c6d-a7dc-49cb-87ba-c1f539ff2127", "metadata": {}, "source": [ "### Use the climax package to load ERA5 predictors" ] }, { "cell_type": "code", "execution_count": null, 
"id": "72e4c65c-add4-47a8-aae0-e4ae3301301a", "metadata": {}, "outputs": [], "source": [ "# define which pressure levels you want to use: currently only 500 and 850 are available\n", "PLEVELS = [500, 850]" ] }, { "cell_type": "code", "execution_count": null, "id": "35979cb6-0a18-41bd-9eea-081b187be5ff", "metadata": {}, "outputs": [], "source": [ "# create the xarray.Dataset of the specified predictor variables\n", "predictors = ERA5Dataset(ERA5_PATH, ERA5_PREDICTORS, plevels=PLEVELS)\n", "predictors = predictors.merge()" ] }, { "cell_type": "code", "execution_count": null, "id": "155c0954-cb8a-40d1-99e9-6c46281ae899", "metadata": {}, "outputs": [], "source": [ "# check out the xarray.Dataset: you will see all the variables you specified\n", "predictors" ] }, { "cell_type": "markdown", "id": "727fc0f2-c8ef-480d-a273-a07d3dff0141", "metadata": {}, "source": [ "### Load target data: observations" ] }, { "cell_type": "code", "execution_count": null, "id": "629d12b0-e4d1-4cc8-8c02-bd62bbb67cdc", "metadata": {}, "outputs": [], "source": [ "# path to observation data\n", "OBS_PATH = pathlib.Path('/mnt/CEPH_PROJECTS/FACT_CLIMAX/OBSERVATION')" ] }, { "cell_type": "code", "execution_count": null, "id": "c589cdea-09ed-415e-8613-a5ec311fdb2b", "metadata": {}, "outputs": [], "source": [ "# define the predictand, i.e. 
tasmax, tasmin or pr\n", "PREDICTAND = 'tasmax'" ] }, { "cell_type": "code", "execution_count": null, "id": "c8301c21-c32b-4123-a405-bde4b4d80f48", "metadata": {}, "outputs": [], "source": [ "# load the observation data\n", "predictand = xr.open_dataset(search_files(OBS_PATH.joinpath(PREDICTAND), '.nc$').pop())" ] }, { "cell_type": "code", "execution_count": null, "id": "4760558e-b1ed-4b16-9609-cf301aa39f55", "metadata": {}, "outputs": [], "source": [ "# check out the xarray.Dataset: you will see a single variable\n", "predictand" ] }, { "cell_type": "markdown", "id": "756e22e2-0b53-4d8f-877c-b9aee2659e17", "metadata": {}, "source": [ "### Prepare training data" ] }, { "cell_type": "code", "execution_count": null, "id": "8aec082a-1d6a-4932-8600-efc48f886cb3", "metadata": {}, "outputs": [], "source": [ "# define the training period and the validation period\n", "TRAIN_PERIOD = slice('1981-01-01', '1991-01-01')\n", "VALID_PERIOD = slice('1991-01-01', '2010-01-01')" ] }, { "cell_type": "code", "execution_count": null, "id": "0812673a-5a9d-48bf-ba87-3f96cb26ff49", "metadata": {}, "outputs": [], "source": [ "# select training and validation data: predictors\n", "predictors_train = predictors.sel(time=TRAIN_PERIOD)\n", "predictors_valid = predictors.sel(time=VALID_PERIOD)" ] }, { "cell_type": "code", "execution_count": null, "id": "c326326c-5621-4ed4-af2b-a942bc678e71", "metadata": {}, "outputs": [], "source": [ "# select training and validation data: predictand\n", "predictand_train = predictand.sel(time=TRAIN_PERIOD)\n", "predictand_valid = predictand.sel(time=VALID_PERIOD)" ] }, { "cell_type": "markdown", "id": "9a410594-714e-4124-b155-d2985ed5b6cb", "metadata": {}, "source": [ "### Train the generalized linear regression model" ] }, { "cell_type": "code", "execution_count": null, "id": "57da5a56-bcd7-44eb-9304-62697d779f7f", "metadata": {}, "outputs": [], "source": [ "# instanciate the GLM\n", "model = TweedieRegressor(power=0 if PREDICTAND in ['tasmax', 'tasmin'] 
else 2)\n", "model\n", "# power = 0: Normal distribution (tasmax, tasmin)\n", "# power = 1: Poisson distribution\n", "# power = (1, 2): Compound Poisson Gamma distribution\n", "# power = 2: Gamma distribution (pr)\n", "# power = 3: Inverse Gaussian" ] }, { "cell_type": "code", "execution_count": null, "id": "802169db-8a22-42ad-8d4f-e29f676a49e9", "metadata": {}, "outputs": [], "source": [ "# function to normalize each predictor variable to [0, 1]\n", "def normalize(predictors):\n", "    # work on a copy so the caller's array is not modified in place\n", "    predictors = predictors.copy()\n", "    # reduce over the time axis (axis=0) so that each predictor variable, i.e.\n", "    # each column of the (time, predictors) array, is scaled independently;\n", "    # axis=1 would instead normalize across the variables at each time step\n", "    predictors -= predictors.min(axis=0, keepdims=True)\n", "    # after subtracting the minimum, the column maximum equals (max - min)\n", "    # NOTE(review): a constant predictor (zero range) would divide by zero here\n", "    predictors /= predictors.max(axis=0, keepdims=True)\n", "    return predictors" ] }, { "cell_type": "code", "execution_count": null, "id": "5382f429-d8e7-4636-8f47-922ae20a91cc", "metadata": {}, "outputs": [], "source": [ "# iterate over the grid points\n", "prediction = np.ones(shape=(len(predictors_valid.time), len(predictors_valid.y), len(predictors_valid.x))) * np.nan\n", "for i, _ in enumerate(predictors_train.x):\n", "    for j, _ in enumerate(predictors_train.y):\n", "\n", "        # current grid point: xarray.Dataset, dimensions=(time)\n", "        point_predictors = predictors_train.isel(x=i, y=j)\n", "        point_predictand = predictand_train.isel(x=i, y=j)\n", "\n", "        # convert xarray.Dataset to numpy.array: shape=(time, predictors)\n", "        point_predictors = point_predictors.to_array().values.swapaxes(0, 1)\n", "        point_predictand = point_predictand.to_array().values.squeeze()\n", "\n", "        # check if the grid point is valid\n", "        if np.isnan(point_predictors).any() or np.isnan(point_predictand).any():\n", "            # move on to next grid point\n", "            continue\n", "\n", "        # normalize each predictor variable to [0, 1]\n", "        point_predictors = normalize(point_predictors)\n", "\n", "        # instantiate the model for the current grid point\n", "        model = TweedieRegressor(power=0 if PREDICTAND in ['tasmax', 'tasmin'] else 2)\n", "\n", "        # train model on training data\n", "        model.fit(point_predictors, point_predictand)\n", "        print('Processing grid point: ({:d}, {:d}), 
score: {:.2f}'.format(j, i, model.score(point_predictors, point_predictand)))\n", "\n", "        # prepare predictors of validation period\n", "        # NOTE(review): scaling uses the validation period's own min/max; for\n", "        # strictly consistent model inputs, reuse the training-period\n", "        # statistics instead -- confirm intended behaviour\n", "        point_validation = predictors_valid.isel(x=i, y=j).to_array().values.swapaxes(0, 1)\n", "        point_validation = normalize(point_validation)\n", "\n", "        # predict validation period\n", "        pred = model.predict(point_validation)\n", "\n", "        # store predictions for current grid point\n", "        prediction[:, j, i] = pred\n", "\n", "# store predictions in xarray.Dataset: take the time axis from the actual\n", "# validation predictors so its length always matches the prediction array,\n", "# which was allocated with len(predictors_valid.time); a pd.date_range over\n", "# VALID_PERIOD can differ from the calendar of the data\n", "predictions = xr.DataArray(data=prediction, dims=['time', 'y', 'x'],\n", "                           coords=dict(time=predictors_valid.time,\n", "                                       lat=predictand_valid.y, lon=predictand_valid.x))\n", "predictions = predictions.to_dataset(name=PREDICTAND)" ] }, { "cell_type": "markdown", "id": "595ee11d-570d-42ed-8c40-226d724aa9fb", "metadata": {}, "source": [ "### Save predictions as NetCDF" ] }, { "cell_type": "code", "execution_count": null, "id": "fb708c40-c0e3-4c62-bc4e-9d0b6bfc52d4", "metadata": {}, "outputs": [], "source": [ "# specify the output path, filename: PREDICTAND.nc\n", "# expanduser() resolves the leading '~' to the home directory; pathlib does\n", "# not expand it automatically, so the bare path would point to a literal\n", "# relative directory named '~'\n", "OUTPUT_PATH = pathlib.Path('~/{}'.format(PREDICTAND + '.nc')).expanduser()\n", "\n", "# save to NetCDF\n", "predictions.to_netcdf(OUTPUT_PATH, engine='h5netcdf')" ] }, { "cell_type": "code", "execution_count": null, "id": "fc9aba1a-fde1-48c6-85e7-02cb5ebdcfa7", "metadata": {}, "outputs": [], "source": [ "# Enjoy and have fun!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 5 }