Skip to content

Commit

Permalink
feat: adding a new fn_name argument to `_test_source_to_source_func…
Browse files Browse the repository at this point in the history
…tion` to correctly retrieve the object from the lazily transpiled kornia module.
  • Loading branch information
YushaArif99 committed Sep 27, 2024
1 parent 7d7a584 commit d56fd86
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import copy
import inspect
import gast
import ivy
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -168,6 +170,23 @@ def _target_to_simplified(target: str):
return "pt"
return target

def _get_fn_name_from_stack():
# Get the previous calling stack frame
stack = inspect.stack()
caller_frame = stack[2] # The function two levels above (eg: _test_function -> test_rgb_to_grayscale)

source_code = inspect.getsource(caller_frame.frame)
parsed_ast = gast.parse(source_code)

# Traverse the AST to find the call to _test_function and extract the first argument
for node in gast.walk(parsed_ast):
if isinstance(node, gast.Call) and hasattr(node.func, 'id') and node.func.id == '_test_function':
# Found the call to _test_function, extract the first argument (the fn)
first_arg = node.args[0] # This is the first argument passed to _test_function
if isinstance(first_arg, gast.Attribute):
# Get the full name of the function (e.g., "kornia.color.rgb_to_grayscale")
return gast.unparse(first_arg).strip()
return None

def _test_trace_function(
fn,
Expand Down Expand Up @@ -232,6 +251,7 @@ def _test_transpile_function(

def _test_source_to_source_function(
fn,
fn_name,
trace_args,
trace_kwargs,
test_args,
Expand All @@ -245,7 +265,10 @@ def _test_source_to_source_function(
pytest.skip()

transpiled_kornia = ivy.transpile(kornia, source="torch", target=target)
translated_fn = eval("transpiled_" + f"{fn.__module__}.{fn.__name__}")
if fn_name:
translated_fn = eval("transpiled_" + f"{fn_name}")
else:
translated_fn = eval("transpiled_" + f"{fn.__module__}.{fn.__name__}")

if backend_compile:
try:
Expand Down Expand Up @@ -297,14 +320,15 @@ def _test_function(
):
# print out the full function module/name, so it will appear in the test_report.json
print(f"{fn.__module__}.{fn.__name__}")

fn_name = _get_fn_name_from_stack()
if skip and mode != "s2s":
# any skipped due to DCF issues should still work with ivy.source_to_source
pytest.skip()

if mode == "s2s":
_test_source_to_source_function(
fn,
fn_name,
trace_args,
trace_kwargs,
test_args,
Expand Down

0 comments on commit d56fd86

Please sign in to comment.