Skip to content

Commit

Permalink
support timezones for pandas in fill_gaps (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Jun 12, 2024
1 parent 6d99b22 commit fb00c82
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 7 deletions.
20 changes: 15 additions & 5 deletions nbs/preprocessing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@
" else:\n",
" if isinstance(freq, str):\n",
" # drop any timezone first; np.datetime64 then raises a nice error message if bound isn't a valid datetime\n",
" if isinstance(bound, pd.Timestamp) and bound.tz is not None:\n",
" bound = bound.tz_localize(None)\n",
" val = np.datetime64(bound)\n",
" else:\n",
" val = bound\n",
Expand Down Expand Up @@ -203,8 +205,13 @@
" # such as MS = 'Month Start' -> 'M', YS = 'Year Start' -> 'Y'\n",
" freq = freq[0]\n",
" delta: Union[np.timedelta64, int] = np.timedelta64(n, freq)\n",
" tz = df[time_col].dt.tz\n",
" if tz is not None:\n",
" df = df.copy(deep=False)\n",
" df[time_col] = df[time_col].dt.tz_localize(None)\n",
" else:\n",
" delta = freq\n",
" tz = None\n",
" times_by_id = df.groupby(id_col, observed=True)[time_col].agg(['min', 'max'])\n",
" starts = _determine_bound(start, freq, times_by_id, 'min')\n",
" ends = _determine_bound(end, freq, times_by_id, 'max') + delta\n",
Expand All @@ -228,6 +235,8 @@
" times += offset.base\n",
" idx = pd.MultiIndex.from_arrays([uids, times], names=[id_col, time_col])\n",
" res = df.set_index([id_col, time_col]).reindex(idx).reset_index()\n",
" if tz is not None:\n",
" res[time_col] = res[time_col].dt.tz_localize(tz, ambiguous='infer')\n",
" extra_cols = df.columns.drop([id_col, time_col]).tolist()\n",
" if extra_cols:\n",
" check_col = extra_cols[0]\n",
Expand All @@ -252,7 +261,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/preprocessing.py#L56){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/preprocessing.py#L58){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### fill_gaps\n",
"\n",
Expand All @@ -278,7 +287,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/preprocessing.py#L56){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/preprocessing.py#L58){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### fill_gaps\n",
"\n",
Expand Down Expand Up @@ -1598,7 +1607,7 @@
" # inferred frequency is the expected\n",
" first_serie = filled[filled['unique_id'] == 1]\n",
" if isinstance(freq, str):\n",
" inferred_freq = pd.infer_freq(first_serie['ds'])\n",
" inferred_freq = pd.infer_freq(first_serie['ds'].dt.tz_localize(None))\n",
" assert inferred_freq == pd.tseries.frequencies.to_offset(freq)\n",
" else:\n",
" assert all(first_serie['ds'].diff().value_counts().index == [freq])\n",
Expand Down Expand Up @@ -1628,7 +1637,7 @@
" assert max_dates[0] == expected_end\n",
"\n",
"n_periods = 100\n",
"freqs = ['YE', 'YS', 'ME', 'MS', 'W', 'W-TUE', 'D', 's', 'ms', 1, 2, '20D', '30s', '2YE', '3YS', '30min', 'B', '1h', 'QS-OCT', 'QE']\n",
"freqs = ['YE', 'YS', 'ME', 'MS', 'W', 'W-TUE', 'D', 's', 'ms', 1, 2, '20D', '30s', '2YE', '3YS', '30min', 'B', '1h', 'QS-NOV', 'QE']\n",
"try:\n",
" pd.tseries.frequencies.to_offset('YE')\n",
"except ValueError:\n",
Expand All @@ -1640,8 +1649,9 @@
" for f in freqs if isinstance(f, str)\n",
" ]\n",
"for freq in freqs:\n",
" if isinstance(freq, (pd.offsets.BaseOffset, str)): \n",
" if isinstance(freq, (pd.offsets.BaseOffset, str)):\n",
" dates = pd.date_range('1900-01-01', periods=n_periods, freq=freq)\n",
" dates = dates.tz_localize('Europe/Berlin')\n",
" offset = pd.tseries.frequencies.to_offset(freq)\n",
" else:\n",
" dates = np.arange(0, freq * n_periods, freq, dtype=np.int64)\n",
Expand Down
2 changes: 1 addition & 1 deletion settings.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[DEFAULT]
repo = utilsforecast
lib_name = utilsforecast
version = 0.1.10
version = 0.1.11
min_python = 3.8
license = apache2
black_formatting = True
Expand Down
2 changes: 1 addition & 1 deletion utilsforecast/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.10"
__version__ = "0.1.11"
9 changes: 9 additions & 0 deletions utilsforecast/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def _determine_bound(bound, freq, times_by_id, agg) -> np.ndarray:
else:
if isinstance(freq, str):
# drop any timezone first; np.datetime64 then raises a nice error message if bound isn't a valid datetime
if isinstance(bound, pd.Timestamp) and bound.tz is not None:
bound = bound.tz_localize(None)
val = np.datetime64(bound)
else:
val = bound
Expand Down Expand Up @@ -149,8 +151,13 @@ def fill_gaps(
# such as MS = 'Month Start' -> 'M', YS = 'Year Start' -> 'Y'
freq = freq[0]
delta: Union[np.timedelta64, int] = np.timedelta64(n, freq)
tz = df[time_col].dt.tz
if tz is not None:
df = df.copy(deep=False)
df[time_col] = df[time_col].dt.tz_localize(None)
else:
delta = freq
tz = None
times_by_id = df.groupby(id_col, observed=True)[time_col].agg(["min", "max"])
starts = _determine_bound(start, freq, times_by_id, "min")
ends = _determine_bound(end, freq, times_by_id, "max") + delta
Expand All @@ -172,6 +179,8 @@ def fill_gaps(
times += offset.base
idx = pd.MultiIndex.from_arrays([uids, times], names=[id_col, time_col])
res = df.set_index([id_col, time_col]).reindex(idx).reset_index()
if tz is not None:
res[time_col] = res[time_col].dt.tz_localize(tz, ambiguous="infer")
extra_cols = df.columns.drop([id_col, time_col]).tolist()
if extra_cols:
check_col = extra_cols[0]
Expand Down

0 comments on commit fb00c82

Please sign in to comment.