Skip to content

Commit

Permalink
[Minor] Vectorize timedataset (#1617)
Browse files Browse the repository at this point in the history
* Convert to tensors

* clarify ID drop

* fixed tests

* added vectorization

* fixed linters

---------

Co-authored-by: ourownstory <ourownstory@users.noreply.github.com>
  • Loading branch information
MaiBe-ctrl and ourownstory authored Jul 29, 2024
1 parent 18e3e4d commit 4cf7444
Showing 1 changed file with 70 additions and 78 deletions.
148 changes: 70 additions & 78 deletions neuralprophet/time_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,27 +333,22 @@ def get_sample_lagged_regressors(df_tensors, origin_index, config_lagged_regress


def get_sample_seasonalities(df_tensors, origin_index, n_forecasts, max_lags, n_lags, config_seasonality):

seasonalities = OrderedDict({})
if max_lags == 0:
dates = df_tensors["ds"][origin_index].unsqueeze(0)
else:
dates = df_tensors["ds"][origin_index - n_lags + 1 : origin_index + n_forecasts + 1]

t = (dates - torch.tensor(datetime(1900, 1, 1).timestamp())).float() / (3600 * 24.0)

for name, period in config_seasonality.periods.items():
if period.resolution > 0:
if config_seasonality.computation == "fourier":
t = (dates - datetime(1900, 1, 1).timestamp()).float() / (3600 * 24.0)
features = torch.cat(
[
torch.sin(2.0 * (i + 1) * np.pi * t / period.period).unsqueeze(1)
for i in range(period.resolution)
]
+ [
torch.cos(2.0 * (i + 1) * np.pi * t / period.period).unsqueeze(1)
for i in range(period.resolution)
],
dim=1,
)
factor = 2.0 * np.pi * t[:, None] / period.period
sin_terms = torch.sin(factor * torch.arange(1, period.resolution + 1))
cos_terms = torch.cos(factor * torch.arange(1, period.resolution + 1))
features = torch.cat((sin_terms, cos_terms), dim=1)
else:
raise NotImplementedError

Expand Down Expand Up @@ -645,14 +640,15 @@ def get_event_offset_features(event, config, feature):
tuple
Tuple of additive_events and multiplicative_events
"""
events = pd.DataFrame({})
lw = config.lower_window
uw = config.upper_window
for offset in range(lw, uw + 1):
key = utils.create_event_names_for_offsets(event, offset)
offset_feature = feature.shift(periods=offset, fill_value=0.0)
events[key] = offset_feature
return events
offsets = range(config.lower_window, config.upper_window + 1)
offset_features = pd.concat(
{
utils.create_event_names_for_offsets(event, offset): feature.shift(periods=offset, fill_value=0.0)
for offset in offsets
},
axis=1,
)
return offset_features


def add_event_features_to_df(
Expand Down Expand Up @@ -680,52 +676,58 @@ def add_event_features_to_df(

def normalize_holiday_name(name):
# Handle cases like "Independence Day (observed)" -> "Independence Day"
if "(observed)" in name:
return name.replace(" (observed)", "")
return name
return name.replace(" (observed)", "") if "(observed)" in name else name

def add_offset_features(feature, event_name, config):
additive_names = []
multiplicative_names = []
for offset in range(config.lower_window, config.upper_window + 1):
event_offset_name = utils.create_event_names_for_offsets(event_name, offset)
df[event_offset_name] = feature.shift(periods=offset, fill_value=0.0)
if config.mode == "additive":
additive_names.append(event_offset_name)
else:
multiplicative_names.append(event_offset_name)
return additive_names, multiplicative_names

# create all additional user specified offest events
# Create all additional user-specified offset events
additive_events_names = []
multiplicative_events_names = []

if config_events is not None:
for event in sorted(list(config_events.keys())):
for event in sorted(config_events.keys()):
feature = df[event]
config = config_events[event]
mode = config.mode
for offset in range(config.lower_window, config.upper_window + 1):
event_offset_name = utils.create_event_names_for_offsets(event, offset)
df[event_offset_name] = feature.shift(periods=offset, fill_value=0.0)
if mode == "additive":
additive_events_names.append(event_offset_name)
else:
multiplicative_events_names.append(event_offset_name)
additive_names, multiplicative_names = add_offset_features(feature, event, config)
additive_events_names.extend(additive_names)
multiplicative_events_names.extend(multiplicative_names)

# create all country specific holidays and their offsets.
# Create all country-specific holidays and their offsets
additive_holiday_names = []
multiplicative_holiday_names = []

if config_country_holidays is not None:
year_list = list({x.year for x in df.ds})
year_list = df["ds"].dt.year.unique()
country_holidays_dict = get_all_holidays(year_list, config_country_holidays.country)
config = config_country_holidays
mode = config.mode

for holiday in config_country_holidays.holiday_names:
feature = pd.Series(np.zeros(df.shape[0], dtype=np.float32))
holiday = normalize_holiday_name(holiday)
if holiday in country_holidays_dict.keys():
dates = country_holidays_dict[holiday]
feature[df.ds.isin(dates)] = 1.0
feature = pd.Series(np.zeros(len(df)), index=df.index, dtype=np.float32)
normalized_holiday = normalize_holiday_name(holiday)

if normalized_holiday in country_holidays_dict:
dates = country_holidays_dict[normalized_holiday]
feature.loc[df["ds"].isin(dates)] = 1.0
else:
raise ValueError(f"Holiday {holiday} not found in {config_country_holidays.country} holidays")
for offset in range(config.lower_window, config.upper_window + 1):
holiday_offset_name = utils.create_event_names_for_offsets(holiday, offset)
df[holiday_offset_name] = feature.shift(periods=offset, fill_value=0.0)
if mode == "additive":
additive_holiday_names.append(holiday_offset_name)
else:
multiplicative_holiday_names.append(holiday_offset_name)
# Future TODO: possibly undo merge of events and holidays.

additive_names, multiplicative_names = add_offset_features(feature, normalized_holiday, config)
additive_holiday_names.extend(additive_names)
multiplicative_holiday_names.extend(multiplicative_names)

additive_event_and_holiday_names = sorted(additive_events_names + additive_holiday_names)
multiplicative_event_and_holiday_names = sorted(multiplicative_events_names + multiplicative_holiday_names)

return df, additive_event_and_holiday_names, multiplicative_event_and_holiday_names


Expand Down Expand Up @@ -763,35 +765,26 @@ def create_prediction_frequency_filter_mask(timestamps, prediction_frequency=Non
Returns boolean mask where prediction origin indexes to be included are True, and the rest False.
"""
mask = torch.ones(len(timestamps), dtype=torch.bool)

# Basic case: no filter
if prediction_frequency is None:
return mask
else:
assert isinstance(prediction_frequency, dict)
return torch.ones(len(timestamps), dtype=torch.bool)

timestamps = pd.to_datetime(timestamps.numpy(), unit="s")
filter_masks = []
mask = torch.ones(len(timestamps), dtype=torch.bool)

filters = {
"hourly-minute": timestamps.minute,
"daily-hour": timestamps.hour,
"weekly-day": timestamps.dayofweek,
"monthly-day": timestamps.day,
"yearly-month": timestamps.month,
}

for key, value in prediction_frequency.items():
if key == "hourly-minute":
filter_mask = timestamps.minute == value
elif key == "daily-hour":
filter_mask = timestamps.hour == value
elif key == "weekly-day":
filter_mask = timestamps.dayofweek == value
elif key == "monthly-day":
filter_mask = timestamps.day == value
elif key == "yearly-month":
filter_mask = timestamps.month == value
else:
if key not in filters:
raise ValueError(f"Invalid prediction frequency: {key}")
filter_masks.append(filter_mask)
mask &= filters[key] == value

combined_mask = filter_masks[0]
for m in filter_masks[1:]:
combined_mask = combined_mask & m
return torch.tensor(combined_mask, dtype=torch.bool)
return torch.tensor(mask, dtype=torch.bool)


def create_nan_mask(
Expand Down Expand Up @@ -831,7 +824,7 @@ def create_nan_mask(
targets_nan = torch.cat([targets_nan, torch.ones(n_forecasts, dtype=torch.bool)])
targets_valid = ~targets_nan

valid_origins = valid_origins & targets_valid
valid_origins &= targets_valid

# AR LAGS
if n_lags > 0:
Expand All @@ -841,7 +834,7 @@ def create_nan_mask(
# as there are missing lags for the corresponding origin_indexes
y_lags_nan = torch.cat([torch.ones(n_lags - 1, dtype=torch.bool), y_lags_nan])
y_lags_valid = ~y_lags_nan
valid_origins = valid_origins & y_lags_valid
valid_origins &= y_lags_valid

# LAGGED REGRESSORS
if config_lagged_regressors is not None: # and max_lags > 0:
Expand All @@ -856,16 +849,15 @@ def create_nan_mask(
# fill first n_reg_lags -1 positions with True,
# as there are missing lags for the corresponding origin_indexes
reg_lags_nan = torch.cat([torch.ones(n_reg_lags - 1, dtype=torch.bool), reg_lags_nan])
reg_lags_valid_i = ~reg_lags_nan
reg_lags_valid = reg_lags_valid & reg_lags_valid_i
valid_origins = valid_origins & reg_lags_valid
reg_lags_valid &= ~reg_lags_nan
valid_origins &= reg_lags_valid

# TIME: TREND & SEASONALITY: the time at each sample's lags and forecasts
# FUTURE REGRESSORS
# EVENTS
names = ["t"] + future_regressor_names + event_names
valid_columns = mask_origin_without_nan_for_columns(tensor_isna, names, max_lags, n_lags, n_forecasts)
valid_origins = valid_origins & valid_columns
valid_origins &= valid_columns

return valid_origins

Expand Down

0 comments on commit 4cf7444

Please sign in to comment.