Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Major] Dataloader: Just-In-Time tabularization #1529

Merged
merged 141 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
141 commits
Select commit Hold shift + click to select a range
34909cb
minimal pytest
SimonWittner Dec 13, 2023
687c085
move_func_getitem
SimonWittner Dec 13, 2023
5215340
slicing
SimonWittner Dec 15, 2023
c70fae2
predict_mode
SimonWittner Dec 15, 2023
b78d5e0
typos
SimonWittner Dec 18, 2023
beae5bb
lr-finder
SimonWittner Dec 19, 2023
8427ffc
drop_missing
SimonWittner Dec 19, 2023
ff05b2a
predict_v2
SimonWittner Dec 19, 2023
c408e95
predict_v3
SimonWittner Dec 19, 2023
df29f33
samples
SimonWittner Dec 20, 2023
29fe999
lagged regressor n_lags
SimonWittner Dec 21, 2023
2f584c2
preliminary: events, holidays
SimonWittner Dec 21, 2023
fca7adf
adjustes pytests
SimonWittner Dec 21, 2023
139a97f
selective forecasting
SimonWittner Dec 22, 2023
30aa303
black
SimonWittner Dec 22, 2023
381c912
ruff
SimonWittner Dec 22, 2023
660934c
lagged_regressors
SimonWittner Jan 4, 2024
51fa0a6
Note down df path to TimeDataset
ourownstory Jan 12, 2024
da74f87
complete notes on TimeDataset, move meta
ourownstory Jan 12, 2024
97fbe07
Big rewrite with real and pseudocode
ourownstory Jan 16, 2024
3cfa2be
Merge branch 'main' into dataloader-fly-TimeDataset
ourownstory Jan 17, 2024
b1f084b
Merge branch 'main' into dataloader-fly-TimeDataset
ourownstory Jan 17, 2024
bdf529c
create_target_start_end_mask
ourownstory Jan 17, 2024
c814115
boolean mask
ourownstory Jan 18, 2024
7119419
combine masks into map
ourownstory Jan 18, 2024
66bb911
notes for nan check
ourownstory Jan 18, 2024
fe382c1
bypass NAN filter
ourownstory Jan 18, 2024
8ec4f9f
rework index to point at prediction origin, not first forecast.
ourownstory Jan 19, 2024
23d6100
tabularize: converted time and lags to single sample extraction
ourownstory Jan 19, 2024
49af45b
convert lagged regressors
ourownstory Jan 23, 2024
a35a1b8
consolidate seasonality computation in one script
ourownstory Jan 23, 2024
c1c9b1b
finish Seasonlity conversion
ourownstory Jan 24, 2024
8271a5e
update todos
ourownstory Jan 24, 2024
1aad054
complete targets and future regressors
ourownstory Jan 24, 2024
a41138e
convert events
ourownstory Jan 24, 2024
dfc6006
finish events and holidays conversion
ourownstory Jan 24, 2024
62c4818
debug timedataset
ourownstory Jan 25, 2024
e7b8f0c
debugging
ourownstory Jan 26, 2024
235eea8
make_country_specific_holidays_df
ourownstory Jan 26, 2024
02ff9bb
remove uses of df.loc[...].values
ourownstory Jan 26, 2024
7fda18d
debug time
ourownstory Jan 26, 2024
621e701
debugging types
ourownstory Jan 26, 2024
c62f332
debug timedata
ourownstory Jan 26, 2024
54edbf4
debugging time_dataset variable shapes
ourownstory Jan 26, 2024
4629bf4
address indexing and slicing issues, .loc
ourownstory Jan 26, 2024
b2f89ed
fix dimensions except nonstationary components
ourownstory Jan 27, 2024
c65a107
integrate torch formatting into tabularize
ourownstory Jan 27, 2024
af5524a
check shapes
ourownstory Jan 27, 2024
404e307
AirPassengers test working!
ourownstory Jan 27, 2024
6075074
fix dataset generator
ourownstory Jan 27, 2024
d6242a2
fixed all performance tests but Energy due to nonstationary components
ourownstory Jan 30, 2024
a5ebff9
fixed nonstationary issue. all performance tests running
ourownstory Jan 30, 2024
a4152e6
refactor tabularize function
ourownstory Jan 30, 2024
fba0d0d
fix bug
ourownstory Jan 30, 2024
3493d8a
initial build of GlobalTimeDataset
ourownstory Jan 30, 2024
dbec862
refactor TimeDataset not to use kwargs passthrough
ourownstory Jan 30, 2024
254cb23
debugged seasonal components call of TimeDataset
ourownstory Jan 30, 2024
1b6940a
fix numpy object type error
ourownstory Jan 31, 2024
edec344
fix seasonality condition bugs
ourownstory Jan 31, 2024
5eef5f9
fix events and future regressor cases
ourownstory Feb 1, 2024
f88e550
fixing prediction frequency filter
ourownstory Feb 1, 2024
61aad2a
performance_test_energy
SimonWittner Feb 2, 2024
661b5b7
debug events
ourownstory Feb 2, 2024
3e5dd34
convert new energytest to daily data
ourownstory Feb 2, 2024
b78477b
fix events util reference
ourownstory Feb 2, 2024
190e3b7
fix test_get_country_holidays
ourownstory Feb 2, 2024
767ca02
fix test_timedataset_minima
ourownstory Feb 2, 2024
7e9b29d
fix selective forecasting
ourownstory Feb 2, 2024
32d2cc6
cleanup timedataset
ourownstory Feb 2, 2024
b709f2d
refactor tabularize_univariate
ourownstory Feb 2, 2024
9fe44c4
daily_data
SimonWittner Feb 2, 2024
7d84b37
start nan check for smaple mask
ourownstory Feb 7, 2024
79ad0e7
working on time nan2
ourownstory Feb 7, 2024
469b11c
fix tests
ourownstory Feb 7, 2024
38f70fa
finish nan-check
ourownstory Feb 8, 2024
cfb2562
fix dims
ourownstory Feb 8, 2024
e320b22
pass self.df to indexing
ourownstory Feb 8, 2024
7f7be5f
fix zero dim lagged regressors
ourownstory Feb 8, 2024
d00d5f9
close figures in tests
ourownstory Feb 8, 2024
df5051d
fix typings
ourownstory Feb 8, 2024
d3bce01
black
ourownstory Feb 8, 2024
dce2f73
ruff
ourownstory Feb 8, 2024
bedce94
linting
ourownstory Feb 8, 2024
051e1ad
linting
ourownstory Feb 8, 2024
0c9cd87
modify logs
ourownstory Feb 9, 2024
f44231a
add benchmarking script for computational time
ourownstory Feb 9, 2024
2039212
speed up uncertainty tests
ourownstory Feb 9, 2024
d34700f
fix unit test multiple country
ourownstory Feb 9, 2024
485f5a8
reduce tests log level to ERROR
ourownstory Feb 9, 2024
8b863da
reduce log level to ERROR and fix adding multiple countries
ourownstory Feb 9, 2024
3226884
bypass intentional glocal test error log
ourownstory Feb 9, 2024
a6eceb2
fix prev
ourownstory Feb 9, 2024
6cbf17b
benchmark dataloader time
ourownstory Feb 9, 2024
fbcccc3
Merge branch 'main' into dataloader-jit
ourownstory Feb 14, 2024
0c16eb1
remove hourly energy test
ourownstory Feb 15, 2024
b5845fd
add debug notebook for energy hourly
ourownstory Feb 15, 2024
712dcf0
set to log model performance INFO
ourownstory Feb 15, 2024
eb2ccc4
Merge branch 'main' into dataloader-jit
ourownstory Feb 15, 2024
c0b3cdd
address config_regressors.regressors
ourownstory Feb 15, 2024
88264fc
clean up create_nan_mask
ourownstory Feb 15, 2024
a0b0247
clean up create_nan_mask params
ourownstory Feb 15, 2024
93f0067
clean TimeDataframe
ourownstory Feb 15, 2024
d769a8d
update prediction frequency documentation
ourownstory Feb 15, 2024
576ed14
improve prediction frequency documentation
ourownstory Feb 15, 2024
865645c
further improve prediction frequency documentation
ourownstory Feb 15, 2024
4c4d640
fix test errors
ourownstory Feb 15, 2024
d63ea98
fix df_names call
ourownstory Feb 15, 2024
6dfaffa
fix selective prediction assertion
ourownstory Feb 15, 2024
2f38531
merge main
ourownstory Feb 19, 2024
eca2dbb
Merge branch 'main' into dataloader-jit
MaiBe-ctrl Jun 21, 2024
0845d62
normalize holiday naes
MaiBe-ctrl Jun 21, 2024
0982084
fix linting
MaiBe-ctrl Jun 21, 2024
e89057b
fix tests
MaiBe-ctrl Jun 21, 2024
7d938bd
update to use new holiday functions in event_utils.py
ourownstory Jun 21, 2024
f3ca8f3
fix seasonality_local_reg test
ourownstory Jun 21, 2024
08038bd
limit holidays to less than 1.0
ourownstory Jun 21, 2024
5054ab9
Merge branch 'main' into dataloader-jit
MaiBe-ctrl Jun 21, 2024
1da552a
changed holidays
MaiBe-ctrl Jun 21, 2024
d6c8210
Merged
MaiBe-ctrl Jun 21, 2024
adcd8de
update lock
ourownstory Jun 21, 2024
f08e8c4
fix tests
MaiBe-ctrl Jun 21, 2024
241a407
changed tests
MaiBe-ctrl Jun 21, 2024
c1abbea
adjsuted tests
MaiBe-ctrl Jun 21, 2024
40ad298
fix reserved names
MaiBe-ctrl Jun 22, 2024
f7b5eb7
fixed ruff lintint
MaiBe-ctrl Jun 22, 2024
a04ab12
Merge branch 'main' into dataloader-jit
ourownstory Jun 22, 2024
a436231
changed test
MaiBe-ctrl Jun 22, 2024
60260bd
translate holidays to english is possible
MaiBe-ctrl Jun 22, 2024
c54d4b7
exclude py3.13
ourownstory Jun 22, 2024
0508454
update lock
ourownstory Jun 22, 2024
cde3f45
Merge all holidays related tests in one file
MaiBe-ctrl Jun 25, 2024
9ae4f3c
add deterministic flag
MaiBe-ctrl Jun 26, 2024
aac70de
fixed ruff linting issues
MaiBe-ctrl Jun 26, 2024
ec76aae
fixed glocal test
MaiBe-ctrl Jun 26, 2024
939401f
Merge branch 'bug/make_tests_deterministic' into dataloader-jit
MaiBe-ctrl Jun 26, 2024
19d8e7a
fix lock file
MaiBe-ctrl Jun 26, 2024
c533f01
update poetry
MaiBe-ctrl Jun 26, 2024
ad449c2
moved the deterministic flag to the train method
MaiBe-ctrl Jun 26, 2024
7c98dc9
Merge branch 'bug/make_tests_deterministic' into dataloader-jit
MaiBe-ctrl Jun 26, 2024
444195a
Merge branch 'main' into dataloader-jit
MaiBe-ctrl Jun 26, 2024
39b6913
update lock file
MaiBe-ctrl Jun 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions docs/source/code/forecaster.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
NeuralProphet Class
-----------------------
Core Module Documentation
==========================

.. toctree::
:hidden:
:maxdepth: 1

configure.py <configure>
df_utils.py <df_utils>
event_utils.py <event_utils>
plot_forecast_plotly.py <plot_forecast_plotly>
plot_forecast_matplotlib.py <plot_forecast_matplotlib>
plot_model_parameters_plotly.py <plot_model_parameters_plotly>
plot_model_parameters_matplotlib.py <plot_model_parameters_matplotlib>
time_dataset.py <time_dataset>
time_net.py <time_net>
utils.py <utils>

.. automodule:: neuralprophet.forecaster
:members:
5 changes: 0 additions & 5 deletions docs/source/code/hdays_utils.rst

This file was deleted.

3 changes: 0 additions & 3 deletions docs/source/how-to-guides/feature-guides/mlflow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@
"# Start a new MLflow run\n",
"if local:\n",
" with mlflow.start_run():\n",
"\n",
" # Create a new MLflow experiment\n",
" mlflow.set_experiment(\"NP-MLflow Quickstart_v1\")\n",
"\n",
Expand Down Expand Up @@ -259,7 +258,6 @@
"from mlflow.data.pandas_dataset import PandasDataset\n",
"\n",
"if local:\n",
"\n",
" mlflow.pytorch.autolog(\n",
" log_every_n_epoch=1,\n",
" log_every_n_step=None,\n",
Expand All @@ -279,7 +277,6 @@
" model_name = \"NeuralProphet\"\n",
"\n",
" with mlflow.start_run() as run:\n",
"\n",
" dataset: PandasDataset = mlflow.data.from_pandas(df, source=\"AirPassengersDataset\")\n",
"\n",
" # Log the dataset to the MLflow Run. Specify the \"training\" context to indicate that the\n",
Expand Down
6 changes: 4 additions & 2 deletions neuralprophet/components/future_regressors/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ def scalar_features_effects(self, features, params, indices=None):
if indices is not None:
features = features[:, :, indices]
params = params[:, indices]

return torch.sum(features.unsqueeze(dim=2) * params.unsqueeze(dim=0).unsqueeze(dim=0), dim=-1)
# features dims: (batch, n_forecasts, n_features) -> (batch, n_forecasts, 1, n_features)
# params dims: (n_quantiles, n_features) -> (batch, 1, n_quantiles, n_features)
out = torch.sum(features.unsqueeze(dim=2) * params.unsqueeze(dim=0).unsqueeze(dim=0), dim=-1)
return out # dims (batch, n_forecasts, n_quantiles)

def get_reg_weights(self, name):
"""
Expand Down
11 changes: 5 additions & 6 deletions neuralprophet/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from neuralprophet import df_utils, np_types, utils_torch
from neuralprophet.custom_loss_metrics import PinballLoss
from neuralprophet.hdays_utils import get_holidays_from_country
from neuralprophet.event_utils import get_holiday_names

log = logging.getLogger("NP.config")

Expand All @@ -42,10 +42,9 @@
config_events: Optional[ConfigEvents] = None,
config_seasonality: Optional[ConfigSeasonality] = None,
):
if len(df["ID"].unique()) == 1:
if not self.global_normalization:
log.info("Setting normalization to global as only one dataframe provided for training.")
self.global_normalization = True
if len(df["ID"].unique()) == 1 and not self.global_normalization:
log.info("Setting normalization to global as only one dataframe provided for training.")
self.global_normalization = True
self.local_data_params, self.global_data_params = df_utils.init_data_params(
df=df,
normalize=self.normalize,
Expand Down Expand Up @@ -305,7 +304,7 @@
log.error("Invalid growth for global_local mode '{}'. Set to 'global'".format(self.trend_global_local))
self.trend_global_local = "global"

if self.trend_local_reg < 0:

Check failure on line 307 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Operator "<" not supported for "None" (reportOptionalOperand)
log.error("Invalid negative trend_local_reg '{}'. Set to False".format(self.trend_local_reg))
self.trend_local_reg = False

Expand Down Expand Up @@ -354,13 +353,13 @@
log.error("Invalid global_local mode '{}'. Set to 'global'".format(self.global_local))
self.global_local = "global"

self.periods = OrderedDict(

Check failure on line 356 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

No overloads for "__init__" match the provided arguments (reportCallIssue)
{

Check failure on line 357 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "dict[str, Season]" cannot be assigned to parameter "iterable" of type "Iterable[list[bytes]]" in function "__init__" (reportArgumentType)
"yearly": Season(
resolution=6,
period=365.25,
arg=self.yearly_arg,
global_local=(

Check failure on line 362 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.yearly_global_local
if self.yearly_global_local in ["global", "local"]
else self.global_local
Expand All @@ -371,7 +370,7 @@
resolution=3,
period=7,
arg=self.weekly_arg,
global_local=(

Check failure on line 373 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.weekly_global_local
if self.weekly_global_local in ["global", "local"]
else self.global_local
Expand All @@ -382,7 +381,7 @@
resolution=6,
period=1,
arg=self.daily_arg,
global_local=(

Check failure on line 384 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.daily_global_local if self.daily_global_local in ["global", "local"] else self.global_local
),
condition_name=None,
Expand All @@ -390,7 +389,7 @@
}
)

assert self.seasonality_local_reg >= 0, "Invalid seasonality_local_reg '{}'.".format(self.seasonality_local_reg)

Check failure on line 392 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Operator ">=" not supported for "None" (reportOptionalOperand)

if self.seasonality_local_reg is True:
log.warning("seasonality_local_reg = True. Default seasonality_local_reg value set to 1")
Expand All @@ -408,7 +407,7 @@
resolution=resolution,
period=period,
arg=arg,
global_local=global_local if global_local in ["global", "local"] else self.global_local,

Check failure on line 410 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "str" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__"   Type "str" is incompatible with type "SeasonGlobalLocalMode"     "str" is incompatible with type "Literal['global']"     "str" is incompatible with type "Literal['local']"     "str" is incompatible with type "Literal['glocal']" (reportArgumentType)
condition_name=condition_name,
)

Expand Down Expand Up @@ -484,7 +483,7 @@
regressors: OrderedDict = field(init=False) # contains RegressorConfig objects

def __post_init__(self):
self.regressors = None

Check failure on line 486 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot assign to attribute "regressors" for class "ConfigFutureRegressors*"   "None" is incompatible with "OrderedDict[Unknown, Unknown]" (reportAttributeAccessIssue)


@dataclass
Expand All @@ -508,7 +507,7 @@
holiday_names: set = field(init=False)

def init_holidays(self, df=None):
self.holiday_names = get_holidays_from_country(self.country, df)
self.holiday_names = get_holiday_names(self.country, df)


ConfigCountryHolidays = Holidays
25 changes: 13 additions & 12 deletions neuralprophet/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,18 +333,18 @@ def _validate_column_name(
"""
reserved_names = [
"trend",
"additive_terms",
"daily",
"weekly",
"yearly",
"events",
"holidays",
"zeros",
"extra_regressors_additive",
"yhat",
"extra_regressors_multiplicative",
"multiplicative_terms",
"ID",
"y_scaled",
"ds",
"t",
"y",
"index",
]
rn_l = [n + "_lower" for n in reserved_names]
rn_u = [n + "_upper" for n in reserved_names]
Expand Down Expand Up @@ -434,14 +434,14 @@ def _check_dataframe(

def _handle_missing_data(
df: pd.DataFrame,
freq: Optional[str],
freq: str,
n_lags: int,
n_forecasts: int,
config_missing,
config_regressors: Optional[ConfigFutureRegressors],
config_lagged_regressors: Optional[ConfigLaggedRegressors],
config_events: Optional[ConfigEvents],
config_seasonality: Optional[ConfigSeasonality],
config_regressors: Optional[ConfigFutureRegressors] = None,
config_lagged_regressors: Optional[ConfigLaggedRegressors] = None,
config_events: Optional[ConfigEvents] = None,
config_seasonality: Optional[ConfigSeasonality] = None,
predicting: bool = False,
) -> pd.DataFrame:
"""
Expand Down Expand Up @@ -618,12 +618,13 @@ def _create_dataset(model, df, predict_mode, prediction_frequency=None):
predict_mode=predict_mode,
n_lags=model.n_lags,
n_forecasts=model.n_forecasts,
prediction_frequency=prediction_frequency,
predict_steps=model.predict_steps,
config_seasonality=model.config_seasonality,
config_events=model.config_events,
config_country_holidays=model.config_country_holidays,
config_lagged_regressors=model.config_lagged_regressors,
config_regressors=model.config_regressors,
config_lagged_regressors=model.config_lagged_regressors,
config_missing=model.config_missing,
prediction_frequency=prediction_frequency,
# config_train=model.config_train, # no longer needed since JIT tabularization.
)
23 changes: 10 additions & 13 deletions neuralprophet/df_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,28 +88,27 @@ def return_df_in_original_format(df, received_ID_col=False, received_single_time
return new_df


def get_max_num_lags(config_lagged_regressors: Optional[ConfigLaggedRegressors], n_lags: int) -> int:
def get_max_num_lags(n_lags: int, config_lagged_regressors: Optional[ConfigLaggedRegressors]) -> int:
"""Get the greatest number of lags between the autoregression lags and the covariates lags.

Parameters
----------
config_lagged_regressors : configure.ConfigLaggedRegressors
Configurations for lagged regressors
n_lags : int
number of lagged values of series to include as model inputs
config_lagged_regressors : configure.ConfigLaggedRegressors
Configurations for lagged regressors

Returns
-------
int
Maximum number of lags between the autoregression lags and the covariates lags.
"""
if config_lagged_regressors is not None:
log.debug("config_lagged_regressors exists")
max_n_lags = max([n_lags] + [val.n_lags for key, val in config_lagged_regressors.items()])
# log.debug("config_lagged_regressors exists")
return max([n_lags] + [val.n_lags for key, val in config_lagged_regressors.items()])
else:
log.debug("config_lagged_regressors does not exist")
max_n_lags = n_lags
return max_n_lags
# log.debug("config_lagged_regressors does not exist")
return n_lags


def merge_dataframes(df: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -508,14 +507,12 @@ def check_dataframe(
for name in columns:
if name not in df:
raise ValueError(f"Column {name!r} missing from dataframe")
if df.loc[df.loc[:, name].notnull()].shape[0] < 1:
if sum(df.loc[:, name].notnull().values) < 1:
raise ValueError(f"Dataframe column {name!r} only has NaN rows.")
if not np.issubdtype(df[name].dtype, np.number):
df[name] = pd.to_numeric(df[name])
if np.isinf(df.loc[:, name].values).any():
df.loc[:, name] = df[name].replace([np.inf, -np.inf], np.nan)
if df.loc[df.loc[:, name].notnull()].shape[0] < 1:
raise ValueError(f"Dataframe column {name!r} only has NaN rows.")

if future:
return df, regressors_to_remove, lag_regressors_to_remove
Expand Down Expand Up @@ -1541,10 +1538,10 @@ def drop_missing_from_df(df, drop_missing, predict_steps, n_lags):
if all_nan_idx[i + 1] - all_nan_idx[i] > 1:
break
# drop NaN window
df = df.drop(df.index[window[0] : window[-1] + 1]).reset_index().drop("index", axis=1)
df = df.drop(df.index[window[0] : window[-1] + 1]).reset_index(drop=True)
# drop lagged values if window does not occur at the beginning of df
if window[0] - (n_lags - 1) >= 0:
df = df.drop(df.index[(window[0] - (n_lags - 1)) : window[0]]).reset_index().drop("index", axis=1)
df = df.drop(df.index[(window[0] - (n_lags - 1)) : window[0]]).reset_index(drop=True)
return df


Expand Down
71 changes: 71 additions & 0 deletions neuralprophet/event_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from collections import defaultdict
from typing import Iterable, Union

import numpy as np
import pandas as pd
from holidays import country_holidays


def get_holiday_names(country: Union[str, Iterable[str]], df=None):
"""
Return all possible holiday names for a list of countries over time period in df

Parameters
----------
country : str, list
List of country names to retrieve country specific holidays
df : pd.Dataframe
Dataframe from which datestamps will be retrieved from

Returns
-------
set
All possible holiday names of given country
"""
if df is None:
years = np.arange(1995, 2045)
else:
dates = df["ds"].copy(deep=True)
years = pd.unique(dates.apply(lambda x: x.year))
# years = list({x.year for x in dates})
# support multiple countries, convert to list if not already
if isinstance(country, str):
country = [country]

all_holidays = get_all_holidays(years=years, country=country)
return set(all_holidays.keys())


def get_all_holidays(years, country):
"""
Make dataframe of country specific holidays for given years and countries
Parameters
----------
year_list : list
List of years
country : str, list, dict
List of country names and optional subdivisions
Returns
-------
pd.DataFrame
Containing country specific holidays df with columns 'ds' and 'holiday'
"""
# convert to list if not already
if isinstance(country, str):
country = {country: None}
elif isinstance(country, list):
country = dict(zip(country, [None] * len(country)))

all_holidays = defaultdict(list)
# iterate over countries and get holidays for each country
for single_country, subdivision in country.items():
# For compatibility with Turkey as "TU" cases.
single_country = "TUR" if single_country == "TU" else single_country
# get dict of dates and their holiday name
single_country_specific_holidays = country_holidays(
country=single_country, subdiv=subdivision, years=years, expand=True, observed=False, language="en"
)
# invert order - for given holiday, store list of dates
for date, name in single_country_specific_holidays.items():
all_holidays[name].append(pd.to_datetime(date))
return all_holidays
Loading
Loading