From fb00c822f0330b451d6f54661750e0be1e132c5e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Morales?=
Date: Wed, 12 Jun 2024 12:08:52 -0600
Subject: [PATCH] support timezones for pandas in fill_gaps (#85)

---
 nbs/preprocessing.ipynb        | 20 +++++++++++++++-----
 settings.ini                   |  2 +-
 utilsforecast/__init__.py      |  2 +-
 utilsforecast/preprocessing.py |  9 +++++++++
 4 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/nbs/preprocessing.ipynb b/nbs/preprocessing.ipynb
index 91af9d5..83ec72c 100644
--- a/nbs/preprocessing.ipynb
+++ b/nbs/preprocessing.ipynb
@@ -72,6 +72,8 @@
     "    else:\n",
     "        if isinstance(freq, str):\n",
     "            # this raises a nice error message if it isn't a valid datetime\n",
+    "            if isinstance(bound, pd.Timestamp) and bound.tz is not None:\n",
+    "                bound = bound.tz_localize(None)\n",
     "            val = np.datetime64(bound)\n",
     "        else:\n",
     "            val = bound\n",
@@ -203,8 +205,13 @@
     "        # such as MS = 'Month Start' -> 'M', YS = 'Year Start' -> 'Y'\n",
     "        freq = freq[0]\n",
     "        delta: Union[np.timedelta64, int] = np.timedelta64(n, freq)\n",
+    "        tz = df[time_col].dt.tz\n",
+    "        if tz is not None:\n",
+    "            df = df.copy(deep=False)\n",
+    "            df[time_col] = df[time_col].dt.tz_localize(None)\n",
     "    else:\n",
     "        delta = freq\n",
+    "        tz = None\n",
     "    times_by_id = df.groupby(id_col, observed=True)[time_col].agg(['min', 'max'])\n",
     "    starts = _determine_bound(start, freq, times_by_id, 'min')\n",
     "    ends = _determine_bound(end, freq, times_by_id, 'max') + delta\n",
@@ -228,6 +235,8 @@
     "        times += offset.base\n",
     "    idx = pd.MultiIndex.from_arrays([uids, times], names=[id_col, time_col])\n",
     "    res = df.set_index([id_col, time_col]).reindex(idx).reset_index()\n",
+    "    if tz is not None:\n",
+    "        res[time_col] = res[time_col].dt.tz_localize(tz, ambiguous='infer')\n",
     "    extra_cols = df.columns.drop([id_col, time_col]).tolist()\n",
     "    if extra_cols:\n",
     "        check_col = extra_cols[0]\n",
@@ -252,7 +261,7 @@
       "text/markdown": [
        "---\n",
        "\n",
-       "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/preprocessing.py#L56){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+       "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/preprocessing.py#L58){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
        "\n",
        "### fill_gaps\n",
        "\n",
@@ -278,7 +287,7 @@
       "text/plain": [
        "---\n",
        "\n",
-       "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/preprocessing.py#L56){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+       "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/preprocessing.py#L58){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
        "\n",
        "### fill_gaps\n",
        "\n",
@@ -1598,7 +1607,7 @@
     "    # inferred frequency is the expected\n",
     "    first_serie = filled[filled['unique_id'] == 1]\n",
     "    if isinstance(freq, str):\n",
-    "        inferred_freq = pd.infer_freq(first_serie['ds'])\n",
+    "        inferred_freq = pd.infer_freq(first_serie['ds'].dt.tz_localize(None))\n",
     "        assert inferred_freq == pd.tseries.frequencies.to_offset(freq)\n",
     "    else:\n",
     "        assert all(first_serie['ds'].diff().value_counts().index == [freq])\n",
@@ -1628,7 +1637,7 @@
     "    assert max_dates[0] == expected_end\n",
     "\n",
     "n_periods = 100\n",
-    "freqs = ['YE', 'YS', 'ME', 'MS', 'W', 'W-TUE', 'D', 's', 'ms', 1, 2, '20D', '30s', '2YE', '3YS', '30min', 'B', '1h', 'QS-OCT', 'QE']\n",
+    "freqs = ['YE', 'YS', 'ME', 'MS', 'W', 'W-TUE', 'D', 's', 'ms', 1, 2, '20D', '30s', '2YE', '3YS', '30min', 'B', '1h', 'QS-NOV', 'QE']\n",
     "try:\n",
     "    pd.tseries.frequencies.to_offset('YE')\n",
     "except ValueError:\n",
@@ -1640,8 +1649,9 @@
     "        for f in freqs if isinstance(f, str)\n",
     "    ]\n",
     "for freq in freqs:\n",
-    "    if isinstance(freq, (pd.offsets.BaseOffset, str)): \n",
+    "    if isinstance(freq, (pd.offsets.BaseOffset, str)):\n",
     "        dates = pd.date_range('1900-01-01', periods=n_periods, freq=freq)\n",
+    "        dates = dates.tz_localize('Europe/Berlin')\n",
     "        offset = pd.tseries.frequencies.to_offset(freq)\n",
     "    else:\n",
     "        dates = np.arange(0, freq * n_periods, freq, dtype=np.int64)\n",
diff --git a/settings.ini b/settings.ini
index 42f4c0f..0df391f 100644
--- a/settings.ini
+++ b/settings.ini
@@ -1,7 +1,7 @@
 [DEFAULT]
 repo = utilsforecast
 lib_name = utilsforecast
-version = 0.1.10
+version = 0.1.11
 min_python = 3.8
 license = apache2
 black_formatting = True
diff --git a/utilsforecast/__init__.py b/utilsforecast/__init__.py
index 569b121..0c5c300 100644
--- a/utilsforecast/__init__.py
+++ b/utilsforecast/__init__.py
@@ -1 +1 @@
-__version__ = "0.1.10"
+__version__ = "0.1.11"
diff --git a/utilsforecast/preprocessing.py b/utilsforecast/preprocessing.py
index 0ff1d2c..e6f4637 100644
--- a/utilsforecast/preprocessing.py
+++ b/utilsforecast/preprocessing.py
@@ -28,6 +28,8 @@ def _determine_bound(bound, freq, times_by_id, agg) -> np.ndarray:
     else:
         if isinstance(freq, str):
             # this raises a nice error message if it isn't a valid datetime
+            if isinstance(bound, pd.Timestamp) and bound.tz is not None:
+                bound = bound.tz_localize(None)
             val = np.datetime64(bound)
         else:
             val = bound
@@ -149,8 +151,13 @@ def fill_gaps(
         # such as MS = 'Month Start' -> 'M', YS = 'Year Start' -> 'Y'
         freq = freq[0]
         delta: Union[np.timedelta64, int] = np.timedelta64(n, freq)
+        tz = df[time_col].dt.tz
+        if tz is not None:
+            df = df.copy(deep=False)
+            df[time_col] = df[time_col].dt.tz_localize(None)
     else:
         delta = freq
+        tz = None
     times_by_id = df.groupby(id_col, observed=True)[time_col].agg(["min", "max"])
     starts = _determine_bound(start, freq, times_by_id, "min")
     ends = _determine_bound(end, freq, times_by_id, "max") + delta
@@ -172,6 +179,8 @@ def fill_gaps(
         times += offset.base
     idx = pd.MultiIndex.from_arrays([uids, times], names=[id_col, time_col])
     res = df.set_index([id_col, time_col]).reindex(idx).reset_index()
+    if tz is not None:
+        res[time_col] = res[time_col].dt.tz_localize(tz, ambiguous="infer")
     extra_cols = df.columns.drop([id_col, time_col]).tolist()
     if extra_cols:
         check_col = extra_cols[0]
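
Example (not part of the patch): a minimal sketch of the behavior this change enables,
namely filling gaps in a pandas series whose timestamps are timezone-aware. The toy
frame and the 'Europe/Berlin' zone below are illustrative (the notebook tests above
localize to the same zone); only fill_gaps and its freq/id_col/time_col arguments come
from the library itself.

    import pandas as pd

    from utilsforecast.preprocessing import fill_gaps

    # three hourly observations with the 02:00 timestamp missing, timezone-aware
    times = pd.to_datetime(
        ['2024-06-12 00:00', '2024-06-12 01:00', '2024-06-12 03:00']
    ).tz_localize('Europe/Berlin')
    df = pd.DataFrame({'unique_id': 1, 'ds': times, 'y': [1.0, 2.0, 3.0]})

    # per the diff, the timezone is stripped internally, the gap is filled on naive
    # timestamps, and the result is re-localized to the original zone (ambiguous='infer')
    filled = fill_gaps(df, freq='1h', id_col='unique_id', time_col='ds')
    print(filled)
    # expected: four rows, with 2024-06-12 02:00:00+02:00 inserted and y == NaN
    # for the new row; the 'ds' column stays tz-aware (Europe/Berlin)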