Skip to content

Commit

Permalink
Enable more arguments for with dace.tasklet
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Sep 5, 2023
1 parent f95f816 commit 171ddca
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 2 deletions.
14 changes: 14 additions & 0 deletions dace/frontend/python/astutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,3 +705,17 @@ def escape_string(value: Union[bytes, str]):
return value.encode("unicode_escape").decode("utf-8")
# Python 2.x
return value.encode('string_escape')


def parse_function_arguments(node: ast.Call, argnames: List[str]) -> Dict[str, ast.AST]:
"""
Parses function arguments (both positional and keyword) from a Call node,
based on the function's argument names. If an argument was not given, it will
not be in the result.
"""
result = {}
for arg, aname in zip(node.args, argnames):
result[aname] = arg
for kw in node.keywords:
result[kw.arg] = kw.value
return result
3 changes: 2 additions & 1 deletion dace/frontend/python/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,11 @@ class tasklet(metaclass=TaskletMetaclass):
The DaCe framework cannot analyze these tasklets for optimization.
"""

def __init__(self, language: Union[str, dtypes.Language] = dtypes.Language.Python):
def __init__(self, language: Union[str, dtypes.Language] = dtypes.Language.Python, side_effects: bool = False):
if isinstance(language, str):
language = dtypes.Language[language]
self.language = language
self.side_effects = side_effects

def __enter__(self):
if self.language != dtypes.Language.Python:
Expand Down
17 changes: 17 additions & 0 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2510,6 +2510,7 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None):

# Looking for the first argument in a tasklet annotation: @dace.tasklet(STRING HERE)
langInf = None
side_effects = None
if isinstance(node, ast.FunctionDef) and \
hasattr(node, 'decorator_list') and \
isinstance(node.decorator_list, list) and \
Expand All @@ -2522,6 +2523,19 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None):
langArg = node.decorator_list[0].args[0].value
langInf = dtypes.Language[langArg]

# Extract arguments from with statement
if isinstance(node, ast.With):
expr = node.items[0].context_expr
if isinstance(expr, ast.Call):
args = astutils.parse_function_arguments(expr, ['language', 'side_effects'])
langArg = args.get('language', None)
side_effects = args.get('side_effects', None)
langInf = astutils.evalnode(langArg, {**self.globals, **self.defined})
if isinstance(langInf, str):
langInf = dtypes.Language[langInf]

side_effects = astutils.evalnode(side_effects, {**self.globals, **self.defined})

ttrans = TaskletTransformer(self,
self.defined,
self.sdfg,
Expand All @@ -2536,6 +2550,9 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None):
symbols=self.symbols)
node, inputs, outputs, self.accesses = ttrans.parse_tasklet(node, name)

if side_effects is not None:
node.side_effects = side_effects

# Convert memlets to their actual data nodes
for i in inputs.values():
if not isinstance(i, tuple) and i.data in self.scope_vars.keys():
Expand Down
2 changes: 1 addition & 1 deletion dace/frontend/python/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,7 +1268,7 @@ def _convert_to_ast(contents: Any):
node)
else:
# Augment closure with new value
newnode = self.resolver.global_value_to_node(e, node, f'inlined_{id(contents)}', True, keep_object=True)
newnode = self.resolver.global_value_to_node(contents, node, f'inlined_{id(contents)}', True, keep_object=True)
return newnode

return _convert_to_ast(contents)
Expand Down

0 comments on commit 171ddca

Please sign in to comment.