From f3f1c003c312bdbba3ebffc6e57353da7bc54e06 Mon Sep 17 00:00:00 2001
From: "Daniel.Frisinghelli" <daniel.frisinghelli@eurac.edu>
Date: Tue, 26 Oct 2021 15:22:25 +0200
Subject: [PATCH] Implemented generic computation of anomalies on arbitrary
 time-scale.

---
 climax/core/dataset.py | 47 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/climax/core/dataset.py b/climax/core/dataset.py
index fce18fe..332442b 100644
--- a/climax/core/dataset.py
+++ b/climax/core/dataset.py
@@ -6,6 +6,7 @@
 # builtins
 import logging
 import pathlib
+import warnings
 from datetime import date
 
 # externals
@@ -232,6 +233,52 @@ class EoDataset(torch.utils.data.Dataset):
 
         return dem_features
 
+    @staticmethod
+    def anomalies(ds, timescale='time.dayofyear', standard=False):
+        # group dataset by day of the year
+        LOGGER.info('Computing standardized anomalies ...')
+        groups = ds.groupby('time.dayofyear').groups
+
+        # compute standardized anomalies for each day of the year over time
+        anomalies = {}
+        for time, time_scale in groups.items():
+            with warnings.catch_warnings():
+                warnings.simplefilter('ignore', category=RuntimeWarning)
+                # anomaly = (x(t) - mean(x, t))
+                anomalies[time] = (ds.isel(time=time_scale) -
+                                   ds.isel(time=time_scale).mean(dim='time'))
+
+                # standardized anomaly = (x(t) - mean(x, t)) / std(x, t)
+                if standard:
+                    anomalies[time] = (anomalies[time] /
+                                       ds.isel(time=time_scale).std(dim='time')
+                                       )
+
+        # concatenate anomalies and sort chronologically
+        anomalies = xr.concat(anomalies.values(), dim='time')
+        anomalies = anomalies.sortby(anomalies.time)
+
+        return anomalies
+
+    @staticmethod
+    def normalize(ds, dim=('time', 'y', 'x'), period=None):
+        # normalize predictors to [0, 1]
+        LOGGER.info('Normalizing data to [0, 1] ...')
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore', category=RuntimeWarning)
+
+            # whether to normalize using statistics for a specific period
+            # NOTE: this can result in values that are not in [0, 1]
+            if period is not None:
+                ds -= ds.sel(time=period).min(dim=dim)
+                ds /= ds.sel(time=period).max(dim=dim)
+            # normalize using entire period: [0, 1]
+            else:
+                ds -= ds.min(dim=dim)
+                ds /= ds.max(dim=dim)
+
+        return ds
+
 
 class NetCDFDataset(EoDataset):
 
-- 
GitLab