Skip to content

Commit

Permalink
Merge pull request #1120 from spcl/frontend-fix
Browse files Browse the repository at this point in the history
Fix string literals in elementwise and disallowed keywords in decorators
  • Loading branch information
tbennun authored Oct 11, 2022
2 parents d12b902 + 8a8a3af commit f1afce3
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 8 deletions.
20 changes: 14 additions & 6 deletions dace/frontend/python/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,14 +1413,22 @@ def visit_AugAssign(self, node: ast.AugAssign) -> ast.Assign:

def find_disallowed_statements(node: ast.AST):
from dace.frontend.python.newast import DISALLOWED_STMTS # Avoid import loop
for subnode in ast.walk(node):
# Found disallowed statement
if type(subnode).__name__ in DISALLOWED_STMTS:
return type(subnode).__name__
# Skip everything until the function contents (in case there are disallowed statements in a decorator)
if isinstance(node, ast.Module) and isinstance(node.body[0], ast.FunctionDef):
nodes = node.body[0].body
else:
nodes = [node]

if isinstance(subnode, ast.Call):
if any(k.arg is None for k in subnode.keywords):
for topnode in nodes:
for subnode in ast.walk(topnode):
# Found disallowed statement
if type(subnode).__name__ in DISALLOWED_STMTS:
return type(subnode).__name__

# Calls with double-starred arguments (**args)
if isinstance(subnode, ast.Call):
if any(k.arg is None for k in subnode.keywords):
return type(subnode).__name__
return None


Expand Down
4 changes: 2 additions & 2 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def _arange(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, *args, **kwargs)

@oprepo.replaces('elementwise')
@oprepo.replaces('dace.elementwise')
def _elementwise(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, func: str, in_array: str, out_array=None):
def _elementwise(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, func: Union[StringLiteral, str], in_array: str, out_array=None):
"""Apply a lambda function to each element in the input"""

inparr = sdfg.arrays[in_array]
Expand All @@ -567,7 +567,7 @@ def _elementwise(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, func: str,
else:
outarr = sdfg.arrays[out_array]

func_ast = ast.parse(func)
func_ast = ast.parse(func.value if isinstance(func, StringLiteral) else func)
try:
lambda_ast = func_ast.body[0].value
if len(lambda_ast.args.args) != 1:
Expand Down
17 changes: 17 additions & 0 deletions tests/python_frontend/preparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@


def test_nested_objects_same_name():

class ObjA:

def __init__(self, q) -> None:
self.q = np.full([20], q)

Expand All @@ -16,6 +18,7 @@ def __call__(self, A):
return A + self.q

class ObjB:

def __init__(self, q) -> None:
self.q = np.full([20], q)
self.obja = ObjA(q * 2)
Expand Down Expand Up @@ -62,7 +65,9 @@ def outer(self, A):


def test_calltree():

class ObjA:

def __init__(self, q) -> None:
self.q = np.full([20], q)

Expand All @@ -71,6 +76,7 @@ def __call__(self, A):
return A + self.q

class ObjB:

def __init__(self, q) -> None:
self.q = np.full([20], q)
self.obja = ObjA(q * 2)
Expand Down Expand Up @@ -112,7 +118,18 @@ def mainprog(A: dace.float64[20]):
assert len(mainprog.resolver.closure_arrays) == 2


def test_program_kwargs():
kwargs = dict(auto_optimize=False)

@dace.program(**kwargs)
def tester():
pass

tester.to_sdfg()


if __name__ == '__main__':
test_nested_objects_same_name()
test_calltree()
test_same_function_different_closure()
test_program_kwargs()

0 comments on commit f1afce3

Please sign in to comment.