Skip to content

Commit

Permalink
switch to dict (uxlfoundation#2111)
Browse files Browse the repository at this point in the history
  • Loading branch information
icfaust authored Oct 15, 2024
1 parent 89a37a8 commit 2dd89cd
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions sklearnex/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,12 @@ def estimator_trace(estimator, method, cache, capsys, monkeypatch):
cache.set("key", key)
cache.set(
"text",
[
re.findall(regex_func, text),
text,
[i.replace(os.sep, ".") for i in re.findall(regex_mod, text)],
[""] + re.findall(regex_callingline, text),
],
{
"funcs": re.findall(regex_func, text),
"trace": text,
"modules": [i.replace(os.sep, ".") for i in re.findall(regex_mod, text)],
"callingline": [""] + re.findall(regex_callingline, text),
},
)

return cache.get("text", None)
Expand All @@ -322,8 +322,8 @@ def call_validate_data(text, estimator, method):
called once before offloading to oneDAL in sklearnex"""
try:
# get last to_table call showing end of oneDAL input portion of code
idx = len(text[0]) - 1 - text[0][::-1].index("to_table")
validfuncs = text[0][:idx]
idx = len(text["funcs"]) - 1 - text["funcs"][::-1].index("to_table")
validfuncs = text["funcs"][:idx]
except ValueError:
pytest.skip("onedal backend not used in this function")

Expand All @@ -341,16 +341,17 @@ def n_jobs_check(text, estimator, method):
"""verify the n_jobs is being set if '_get_backend' or 'to_table' is called"""
# remove the _get_backend function from sklearnex from considered _get_backend
count = max(
text[0].count("to_table"),
text["funcs"].count("to_table"),
len(
[
i
for i in range(len(text[0]))
if text[0][i] == "_get_backend" and "sklearnex" not in text[2][i]
for i in range(len(text["funcs"]))
if text["funcs"][i] == "_get_backend"
and "sklearnex" not in text["modules"][i]
]
),
)
n_jobs_count = text[0].count("n_jobs_wrapper")
n_jobs_count = text["funcs"].count("n_jobs_wrapper")

assert bool(count) == bool(
n_jobs_count
Expand All @@ -360,7 +361,7 @@ def n_jobs_check(text, estimator, method):
def runtime_property_check(text, estimator, method):
"""use of Python's 'property' should not be used at runtime, only at class instantiation"""
assert (
len(re.findall(r"property\(", text[1])) == 0
len(re.findall(r"property\(", text["trace"])) == 0
), f"{estimator}.{method} should only use 'property' at instantiation"


Expand Down

0 comments on commit 2dd89cd

Please sign in to comment.