From 2176e14aa34b8b51f1dd43063fa930d92b2f2719 Mon Sep 17 00:00:00 2001 From: MaiBe-ctrl Date: Tue, 25 Jun 2024 12:00:30 -0700 Subject: [PATCH] adjust tests --- neuralprophet/event_utils.py | 20 +++--- tests/test_event_utils.py | 116 +++++++++++++++++------------------ 2 files changed, 68 insertions(+), 68 deletions(-) diff --git a/neuralprophet/event_utils.py b/neuralprophet/event_utils.py index 6c1b16780..d4ab2eb0d 100644 --- a/neuralprophet/event_utils.py +++ b/neuralprophet/event_utils.py @@ -34,7 +34,7 @@ def get_holiday_names(country: Union[str, Iterable[str]], df=None): all_holidays = get_all_holidays(years=years, country=country) - return list(all_holidays.keys()) + return set(all_holidays.keys()) def get_all_holidays(years, country): @@ -70,13 +70,13 @@ def get_all_holidays(years, country): for date, name in single_country_specific_holidays.items(): all_holidays[name].append(pd.to_datetime(date)) - # merge holidays in different languages (having the exact same datatimes) into one holiday item - unique_holidays = {} - for key1, val1 in all_holidays.items(): - holiday_names = [key1] - for key2, val2 in all_holidays.items(): - if set(val1) == set(val2) and key1 != key2: - holiday_names.append(key2) - unique_holidays["_".join(list(holiday_names))] = val1 + # # merge holidays in different languages (having the exact same datatimes) into one holiday item + # unique_holidays = {} + # for key1, val1 in all_holidays.items(): + # holiday_names = [key1] + # for key2, val2 in all_holidays.items(): + # if set(val1) == set(val2) and key1 != key2: + # holiday_names.append(key2) + # unique_holidays["_".join(list(holiday_names))] = val1 - return unique_holidays + return all_holidays diff --git a/tests/test_event_utils.py b/tests/test_event_utils.py index 22f0f39a7..6d9ed6f50 100644 --- a/tests/test_event_utils.py +++ b/tests/test_event_utils.py @@ -29,64 +29,64 @@ PLOT = False -# def test_get_country_holidays(): -# # deprecated -# # assert issubclass(event_utils.get_country_holidays("TU").__class__, holidays.countries.turkey.TR) is True -# # new format -# assert issubclass(event_utils.get_all_holidays(country=["TU", "US"], years=2025).__class__, dict) is True - -# for country in ("UnitedStates", "US", "USA"): -# us_holidays = event_utils.get_all_holidays(country=country, years=[2019, 2020]) -# assert len(us_holidays) == 10 - -# with pytest.raises(NotImplementedError): -# event_utils.get_holiday_names("NotSupportedCountry") - - -# def test_get_country_holidays_with_subdivisions(): -# # Test US holidays with a subdivision -# us_ca_holidays = country_holidays("US", years=2019, subdiv="CA") -# assert issubclass(us_ca_holidays.__class__, holidays.countries.united_states.UnitedStates) is True -# assert len(us_ca_holidays) > 0 # Assuming there are holidays specific to CA - -# # Test Canada holidays with a subdivision -# ca_on_holidays = country_holidays("CA", years=2019, subdiv="ON") -# assert issubclass(ca_on_holidays.__class__, holidays.countries.canada.CA) is True -# assert len(ca_on_holidays) > 0 # Assuming there are holidays specific to ON - - -# def test_add_country_holiday_multiple_calls_warning(caplog): -# m = NeuralProphet( -# epochs=EPOCHS, -# batch_size=BATCH_SIZE, -# learning_rate=LR, -# ) -# m.add_country_holidays(["US", "Germany"]) -# error_message = "Country holidays can only be added once." -# assert error_message not in caplog.text - -# with pytest.raises(AssertionError): -# m.add_country_holidays("Germany") -# # assert error_message in caplog.text - - -# def test_multiple_countries(): -# # test if multiple countries are added -# df = pd.read_csv(PEYTON_FILE, nrows=NROWS) -# m = NeuralProphet( -# epochs=EPOCHS, -# batch_size=BATCH_SIZE, -# learning_rate=LR, -# ) -# m.add_country_holidays(country_name=["US", "Germany"]) -# m.fit(df, freq="D") -# m.predict(df) -# # get the name of holidays and compare that no holiday is repeated -# holiday_names = m.model.config_holidays.holiday_names -# assert "Independence Day" in holiday_names -# assert "Christmas Day" in holiday_names -# assert "Erster Weihnachtstag" not in holiday_names -# assert "Neujahr" not in holiday_names +def test_get_country_holidays(): + # deprecated + # assert issubclass(event_utils.get_country_holidays("TU").__class__, holidays.countries.turkey.TR) is True + # new format + assert issubclass(event_utils.get_all_holidays(country=["TU", "US"], years=2025).__class__, dict) is True + + for country in ("UnitedStates", "US", "USA"): + us_holidays = event_utils.get_all_holidays(country=country, years=[2019, 2020]) + assert len(us_holidays) == 10 + + with pytest.raises(NotImplementedError): + event_utils.get_holiday_names("NotSupportedCountry") + + +def test_get_country_holidays_with_subdivisions(): + # Test US holidays with a subdivision + us_ca_holidays = country_holidays("US", years=2019, subdiv="CA") + assert issubclass(us_ca_holidays.__class__, holidays.countries.united_states.UnitedStates) is True + assert len(us_ca_holidays) > 0 # Assuming there are holidays specific to CA + + # Test Canada holidays with a subdivision + ca_on_holidays = country_holidays("CA", years=2019, subdiv="ON") + assert issubclass(ca_on_holidays.__class__, holidays.countries.canada.CA) is True + assert len(ca_on_holidays) > 0 # Assuming there are holidays specific to ON + + +def test_add_country_holiday_multiple_calls_warning(caplog): + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + learning_rate=LR, + ) + m.add_country_holidays(["US", "Germany"]) + error_message = "Country holidays can only be added once." + assert error_message not in caplog.text + + with pytest.raises(AssertionError): + m.add_country_holidays("Germany") + # assert error_message in caplog.text + + +def test_multiple_countries(): + # test if multiple countries are added + df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + learning_rate=LR, + ) + m.add_country_holidays(country_name=["US", "Germany"]) + m.fit(df, freq="D") + m.predict(df) + # get the name of holidays and compare that no holiday is repeated + holiday_names = m.model.config_holidays.holiday_names + assert "Independence Day" in holiday_names + assert "Christmas Day" in holiday_names + assert "Erster Weihnachtstag" not in holiday_names + assert "Neujahr" not in holiday_names def test_events():