From 171ddcae8becfd81cc5cace9846bdd40e0be1b6a Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 5 Sep 2023 15:20:37 -0700 Subject: [PATCH] Enable more arguments for `with dace.tasklet` --- dace/frontend/python/astutils.py | 14 ++++++++++++++ dace/frontend/python/interface.py | 3 ++- dace/frontend/python/newast.py | 17 +++++++++++++++++ dace/frontend/python/preprocessing.py | 2 +- 4 files changed, 34 insertions(+), 2 deletions(-) diff --git a/dace/frontend/python/astutils.py b/dace/frontend/python/astutils.py index 4a0ec88531..faf214fdeb 100644 --- a/dace/frontend/python/astutils.py +++ b/dace/frontend/python/astutils.py @@ -705,3 +705,17 @@ def escape_string(value: Union[bytes, str]): return value.encode("unicode_escape").decode("utf-8") # Python 2.x return value.encode('string_escape') + + +def parse_function_arguments(node: ast.Call, argnames: List[str]) -> Dict[str, ast.AST]: + """ + Parses function arguments (both positional and keyword) from a Call node, + based on the function's argument names. If an argument was not given, it will + not be in the result. + """ + result = {} + for arg, aname in zip(node.args, argnames): + result[aname] = arg + for kw in node.keywords: + result[kw.arg] = kw.value + return result diff --git a/dace/frontend/python/interface.py b/dace/frontend/python/interface.py index ea1970dafd..69e650beaa 100644 --- a/dace/frontend/python/interface.py +++ b/dace/frontend/python/interface.py @@ -293,10 +293,11 @@ class tasklet(metaclass=TaskletMetaclass): The DaCe framework cannot analyze these tasklets for optimization. """ - def __init__(self, language: Union[str, dtypes.Language] = dtypes.Language.Python): + def __init__(self, language: Union[str, dtypes.Language] = dtypes.Language.Python, side_effects: bool = False): if isinstance(language, str): language = dtypes.Language[language] self.language = language + self.side_effects = side_effects def __enter__(self): if self.language != dtypes.Language.Python: diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index c9d92b7860..b5d27e14f4 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2510,6 +2510,7 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): # Looking for the first argument in a tasklet annotation: @dace.tasklet(STRING HERE) langInf = None + side_effects = None if isinstance(node, ast.FunctionDef) and \ hasattr(node, 'decorator_list') and \ isinstance(node.decorator_list, list) and \ @@ -2522,6 +2523,19 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): langArg = node.decorator_list[0].args[0].value langInf = dtypes.Language[langArg] + # Extract arguments from with statement + if isinstance(node, ast.With): + expr = node.items[0].context_expr + if isinstance(expr, ast.Call): + args = astutils.parse_function_arguments(expr, ['language', 'side_effects']) + langArg = args.get('language', None) + side_effects = args.get('side_effects', None) + langInf = astutils.evalnode(langArg, {**self.globals, **self.defined}) + if isinstance(langInf, str): + langInf = dtypes.Language[langInf] + + side_effects = astutils.evalnode(side_effects, {**self.globals, **self.defined}) + ttrans = TaskletTransformer(self, self.defined, self.sdfg, @@ -2536,6 +2550,9 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): symbols=self.symbols) node, inputs, outputs, self.accesses = ttrans.parse_tasklet(node, name) + if side_effects is not None: + node.side_effects = side_effects + # Convert memlets to their actual data nodes for i in inputs.values(): if not isinstance(i, tuple) and i.data in self.scope_vars.keys(): diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py index 10a1ab120e..239875118f 100644 --- a/dace/frontend/python/preprocessing.py +++ b/dace/frontend/python/preprocessing.py @@ -1268,7 +1268,7 @@ def _convert_to_ast(contents: Any): node) else: # Augment closure with new value - newnode = self.resolver.global_value_to_node(e, node, f'inlined_{id(contents)}', True, keep_object=True) + newnode = self.resolver.global_value_to_node(contents, node, f'inlined_{id(contents)}', True, keep_object=True) return newnode return _convert_to_ast(contents)