From 4cf74441aab123cca5542df609e88ea5ea94d25c Mon Sep 17 00:00:00 2001 From: Maisa Ben Salah <76703998+MaiBe-ctrl@users.noreply.github.com> Date: Mon, 29 Jul 2024 16:49:52 -0700 Subject: [PATCH] [Minor] Vectorize timedataset (#1617) * Convert to tensors * clarify ID drop * fixed tests * added vectorization * fixed linters --------- Co-authored-by: ourownstory --- neuralprophet/time_dataset.py | 148 ++++++++++++++++------------------ 1 file changed, 70 insertions(+), 78 deletions(-) diff --git a/neuralprophet/time_dataset.py b/neuralprophet/time_dataset.py index a5a70e0b3..c8fb0769e 100644 --- a/neuralprophet/time_dataset.py +++ b/neuralprophet/time_dataset.py @@ -333,27 +333,22 @@ def get_sample_lagged_regressors(df_tensors, origin_index, config_lagged_regress def get_sample_seasonalities(df_tensors, origin_index, n_forecasts, max_lags, n_lags, config_seasonality): + seasonalities = OrderedDict({}) if max_lags == 0: dates = df_tensors["ds"][origin_index].unsqueeze(0) else: dates = df_tensors["ds"][origin_index - n_lags + 1 : origin_index + n_forecasts + 1] + t = (dates - torch.tensor(datetime(1900, 1, 1).timestamp())).float() / (3600 * 24.0) + for name, period in config_seasonality.periods.items(): if period.resolution > 0: if config_seasonality.computation == "fourier": - t = (dates - datetime(1900, 1, 1).timestamp()).float() / (3600 * 24.0) - features = torch.cat( - [ - torch.sin(2.0 * (i + 1) * np.pi * t / period.period).unsqueeze(1) - for i in range(period.resolution) - ] - + [ - torch.cos(2.0 * (i + 1) * np.pi * t / period.period).unsqueeze(1) - for i in range(period.resolution) - ], - dim=1, - ) + factor = 2.0 * np.pi * t[:, None] / period.period + sin_terms = torch.sin(factor * torch.arange(1, period.resolution + 1)) + cos_terms = torch.cos(factor * torch.arange(1, period.resolution + 1)) + features = torch.cat((sin_terms, cos_terms), dim=1) else: raise NotImplementedError @@ -645,14 +640,15 @@ def get_event_offset_features(event, config, feature): tuple Tuple of additive_events and multiplicative_events """ - events = pd.DataFrame({}) - lw = config.lower_window - uw = config.upper_window - for offset in range(lw, uw + 1): - key = utils.create_event_names_for_offsets(event, offset) - offset_feature = feature.shift(periods=offset, fill_value=0.0) - events[key] = offset_feature - return events + offsets = range(config.lower_window, config.upper_window + 1) + offset_features = pd.concat( + { + utils.create_event_names_for_offsets(event, offset): feature.shift(periods=offset, fill_value=0.0) + for offset in offsets + }, + axis=1, + ) + return offset_features def add_event_features_to_df( @@ -680,52 +676,58 @@ def add_event_features_to_df( def normalize_holiday_name(name): # Handle cases like "Independence Day (observed)" -> "Independence Day" - if "(observed)" in name: - return name.replace(" (observed)", "") - return name + return name.replace(" (observed)", "") if "(observed)" in name else name + + def add_offset_features(feature, event_name, config): + additive_names = [] + multiplicative_names = [] + for offset in range(config.lower_window, config.upper_window + 1): + event_offset_name = utils.create_event_names_for_offsets(event_name, offset) + df[event_offset_name] = feature.shift(periods=offset, fill_value=0.0) + if config.mode == "additive": + additive_names.append(event_offset_name) + else: + multiplicative_names.append(event_offset_name) + return additive_names, multiplicative_names - # create all additional user specified offest events + # Create all additional user-specified offset events additive_events_names = [] multiplicative_events_names = [] + if config_events is not None: - for event in sorted(list(config_events.keys())): + for event in sorted(config_events.keys()): feature = df[event] config = config_events[event] - mode = config.mode - for offset in range(config.lower_window, config.upper_window + 1): - event_offset_name = utils.create_event_names_for_offsets(event, offset) - df[event_offset_name] = feature.shift(periods=offset, fill_value=0.0) - if mode == "additive": - additive_events_names.append(event_offset_name) - else: - multiplicative_events_names.append(event_offset_name) + additive_names, multiplicative_names = add_offset_features(feature, event, config) + additive_events_names.extend(additive_names) + multiplicative_events_names.extend(multiplicative_names) - # create all country specific holidays and their offsets. + # Create all country-specific holidays and their offsets additive_holiday_names = [] multiplicative_holiday_names = [] + if config_country_holidays is not None: - year_list = list({x.year for x in df.ds}) + year_list = df["ds"].dt.year.unique() country_holidays_dict = get_all_holidays(year_list, config_country_holidays.country) config = config_country_holidays - mode = config.mode + for holiday in config_country_holidays.holiday_names: - feature = pd.Series(np.zeros(df.shape[0], dtype=np.float32)) - holiday = normalize_holiday_name(holiday) - if holiday in country_holidays_dict.keys(): - dates = country_holidays_dict[holiday] - feature[df.ds.isin(dates)] = 1.0 + feature = pd.Series(np.zeros(len(df)), index=df.index, dtype=np.float32) + normalized_holiday = normalize_holiday_name(holiday) + + if normalized_holiday in country_holidays_dict: + dates = country_holidays_dict[normalized_holiday] + feature.loc[df["ds"].isin(dates)] = 1.0 else: raise ValueError(f"Holiday {holiday} not found in {config_country_holidays.country} holidays") - for offset in range(config.lower_window, config.upper_window + 1): - holiday_offset_name = utils.create_event_names_for_offsets(holiday, offset) - df[holiday_offset_name] = feature.shift(periods=offset, fill_value=0.0) - if mode == "additive": - additive_holiday_names.append(holiday_offset_name) - else: - multiplicative_holiday_names.append(holiday_offset_name) - # Future TODO: possibly undo merge of events and holidays. + + additive_names, multiplicative_names = add_offset_features(feature, normalized_holiday, config) + additive_holiday_names.extend(additive_names) + multiplicative_holiday_names.extend(multiplicative_names) + additive_event_and_holiday_names = sorted(additive_events_names + additive_holiday_names) multiplicative_event_and_holiday_names = sorted(multiplicative_events_names + multiplicative_holiday_names) + return df, additive_event_and_holiday_names, multiplicative_event_and_holiday_names @@ -763,35 +765,26 @@ def create_prediction_frequency_filter_mask(timestamps, prediction_frequency=Non Returns boolean mask where prediction origin indexes to be included are True, and the rest False. """ - mask = torch.ones(len(timestamps), dtype=torch.bool) - - # Basic case: no filter if prediction_frequency is None: - return mask - else: - assert isinstance(prediction_frequency, dict) + return torch.ones(len(timestamps), dtype=torch.bool) timestamps = pd.to_datetime(timestamps.numpy(), unit="s") - filter_masks = [] + mask = torch.ones(len(timestamps), dtype=torch.bool) + + filters = { + "hourly-minute": timestamps.minute, + "daily-hour": timestamps.hour, + "weekly-day": timestamps.dayofweek, + "monthly-day": timestamps.day, + "yearly-month": timestamps.month, + } + for key, value in prediction_frequency.items(): - if key == "hourly-minute": - filter_mask = timestamps.minute == value - elif key == "daily-hour": - filter_mask = timestamps.hour == value - elif key == "weekly-day": - filter_mask = timestamps.dayofweek == value - elif key == "monthly-day": - filter_mask = timestamps.day == value - elif key == "yearly-month": - filter_mask = timestamps.month == value - else: + if key not in filters: raise ValueError(f"Invalid prediction frequency: {key}") - filter_masks.append(filter_mask) + mask &= filters[key] == value - combined_mask = filter_masks[0] - for m in filter_masks[1:]: - combined_mask = combined_mask & m - return torch.tensor(combined_mask, dtype=torch.bool) + return torch.tensor(mask, dtype=torch.bool) def create_nan_mask( @@ -831,7 +824,7 @@ def create_nan_mask( targets_nan = torch.cat([targets_nan, torch.ones(n_forecasts, dtype=torch.bool)]) targets_valid = ~targets_nan - valid_origins = valid_origins & targets_valid + valid_origins &= targets_valid # AR LAGS if n_lags > 0: @@ -841,7 +834,7 @@ def create_nan_mask( # as there are missing lags for the corresponding origin_indexes y_lags_nan = torch.cat([torch.ones(n_lags - 1, dtype=torch.bool), y_lags_nan]) y_lags_valid = ~y_lags_nan - valid_origins = valid_origins & y_lags_valid + valid_origins &= y_lags_valid # LAGGED REGRESSORS if config_lagged_regressors is not None: # and max_lags > 0: @@ -856,16 +849,15 @@ def create_nan_mask( # fill first n_reg_lags -1 positions with True, # as there are missing lags for the corresponding origin_indexes reg_lags_nan = torch.cat([torch.ones(n_reg_lags - 1, dtype=torch.bool), reg_lags_nan]) - reg_lags_valid_i = ~reg_lags_nan - reg_lags_valid = reg_lags_valid & reg_lags_valid_i - valid_origins = valid_origins & reg_lags_valid + reg_lags_valid &= ~reg_lags_nan + valid_origins &= reg_lags_valid # TIME: TREND & SEASONALITY: the time at each sample's lags and forecasts # FUTURE REGRESSORS # EVENTS names = ["t"] + future_regressor_names + event_names valid_columns = mask_origin_without_nan_for_columns(tensor_isna, names, max_lags, n_lags, n_forecasts) - valid_origins = valid_origins & valid_columns + valid_origins &= valid_columns return valid_origins