Skip to content

Commit

Permalink
adjust tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MaiBe-ctrl committed Jun 25, 2024
1 parent 3cbd5ab commit 2176e14
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 68 deletions.
20 changes: 10 additions & 10 deletions neuralprophet/event_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_holiday_names(country: Union[str, Iterable[str]], df=None):

all_holidays = get_all_holidays(years=years, country=country)

return list(all_holidays.keys())
return set(all_holidays.keys())


def get_all_holidays(years, country):
Expand Down Expand Up @@ -70,13 +70,13 @@ def get_all_holidays(years, country):
for date, name in single_country_specific_holidays.items():
all_holidays[name].append(pd.to_datetime(date))

# merge holidays in different languages (having the exact same datatimes) into one holiday item
unique_holidays = {}
for key1, val1 in all_holidays.items():
holiday_names = [key1]
for key2, val2 in all_holidays.items():
if set(val1) == set(val2) and key1 != key2:
holiday_names.append(key2)
unique_holidays["_".join(list(holiday_names))] = val1
# # merge holidays in different languages (having the exact same datatimes) into one holiday item
# unique_holidays = {}
# for key1, val1 in all_holidays.items():
# holiday_names = [key1]
# for key2, val2 in all_holidays.items():
# if set(val1) == set(val2) and key1 != key2:
# holiday_names.append(key2)
# unique_holidays["_".join(list(holiday_names))] = val1

return unique_holidays
return all_holidays
116 changes: 58 additions & 58 deletions tests/test_event_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,64 +29,64 @@
PLOT = False


# def test_get_country_holidays():
# # deprecated
# # assert issubclass(event_utils.get_country_holidays("TU").__class__, holidays.countries.turkey.TR) is True
# # new format
# assert issubclass(event_utils.get_all_holidays(country=["TU", "US"], years=2025).__class__, dict) is True

# for country in ("UnitedStates", "US", "USA"):
# us_holidays = event_utils.get_all_holidays(country=country, years=[2019, 2020])
# assert len(us_holidays) == 10

# with pytest.raises(NotImplementedError):
# event_utils.get_holiday_names("NotSupportedCountry")


# def test_get_country_holidays_with_subdivisions():
# # Test US holidays with a subdivision
# us_ca_holidays = country_holidays("US", years=2019, subdiv="CA")
# assert issubclass(us_ca_holidays.__class__, holidays.countries.united_states.UnitedStates) is True
# assert len(us_ca_holidays) > 0 # Assuming there are holidays specific to CA

# # Test Canada holidays with a subdivision
# ca_on_holidays = country_holidays("CA", years=2019, subdiv="ON")
# assert issubclass(ca_on_holidays.__class__, holidays.countries.canada.CA) is True
# assert len(ca_on_holidays) > 0 # Assuming there are holidays specific to ON


# def test_add_country_holiday_multiple_calls_warning(caplog):
# m = NeuralProphet(
# epochs=EPOCHS,
# batch_size=BATCH_SIZE,
# learning_rate=LR,
# )
# m.add_country_holidays(["US", "Germany"])
# error_message = "Country holidays can only be added once."
# assert error_message not in caplog.text

# with pytest.raises(AssertionError):
# m.add_country_holidays("Germany")
# # assert error_message in caplog.text


# def test_multiple_countries():
# # test if multiple countries are added
# df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
# m = NeuralProphet(
# epochs=EPOCHS,
# batch_size=BATCH_SIZE,
# learning_rate=LR,
# )
# m.add_country_holidays(country_name=["US", "Germany"])
# m.fit(df, freq="D")
# m.predict(df)
# # get the name of holidays and compare that no holiday is repeated
# holiday_names = m.model.config_holidays.holiday_names
# assert "Independence Day" in holiday_names
# assert "Christmas Day" in holiday_names
# assert "Erster Weihnachtstag" not in holiday_names
# assert "Neujahr" not in holiday_names
def test_get_country_holidays():
# deprecated
# assert issubclass(event_utils.get_country_holidays("TU").__class__, holidays.countries.turkey.TR) is True
# new format
assert issubclass(event_utils.get_all_holidays(country=["TU", "US"], years=2025).__class__, dict) is True

for country in ("UnitedStates", "US", "USA"):
us_holidays = event_utils.get_all_holidays(country=country, years=[2019, 2020])
assert len(us_holidays) == 10

with pytest.raises(NotImplementedError):
event_utils.get_holiday_names("NotSupportedCountry")


def test_get_country_holidays_with_subdivisions():
# Test US holidays with a subdivision
us_ca_holidays = country_holidays("US", years=2019, subdiv="CA")
assert issubclass(us_ca_holidays.__class__, holidays.countries.united_states.UnitedStates) is True
assert len(us_ca_holidays) > 0 # Assuming there are holidays specific to CA

# Test Canada holidays with a subdivision
ca_on_holidays = country_holidays("CA", years=2019, subdiv="ON")
assert issubclass(ca_on_holidays.__class__, holidays.countries.canada.CA) is True
assert len(ca_on_holidays) > 0 # Assuming there are holidays specific to ON


def test_add_country_holiday_multiple_calls_warning(caplog):
m = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
)
m.add_country_holidays(["US", "Germany"])
error_message = "Country holidays can only be added once."
assert error_message not in caplog.text

with pytest.raises(AssertionError):
m.add_country_holidays("Germany")
# assert error_message in caplog.text


def test_multiple_countries():
# test if multiple countries are added
df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
m = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
)
m.add_country_holidays(country_name=["US", "Germany"])
m.fit(df, freq="D")
m.predict(df)
# get the name of holidays and compare that no holiday is repeated
holiday_names = m.model.config_holidays.holiday_names
assert "Independence Day" in holiday_names
assert "Christmas Day" in holiday_names
assert "Erster Weihnachtstag" not in holiday_names
assert "Neujahr" not in holiday_names


def test_events():
Expand Down

0 comments on commit 2176e14

Please sign in to comment.