Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nice errors and tasklet arguments for the Python frontend #1365

Merged
merged 4 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions dace/frontend/python/astutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,3 +705,17 @@ def escape_string(value: Union[bytes, str]):
return value.encode("unicode_escape").decode("utf-8")
# Python 2.x
return value.encode('string_escape')


def parse_function_arguments(node: ast.Call, argnames: List[str]) -> Dict[str, ast.AST]:
"""
Parses function arguments (both positional and keyword) from a Call node,
based on the function's argument names. If an argument was not given, it will
alexnick83 marked this conversation as resolved.
Show resolved Hide resolved
not be in the result.
"""
result = {}
for arg, aname in zip(node.args, argnames):
result[aname] = arg
for kw in node.keywords:
result[kw.arg] = kw.value
return result
3 changes: 2 additions & 1 deletion dace/frontend/python/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,11 @@ class tasklet(metaclass=TaskletMetaclass):
The DaCe framework cannot analyze these tasklets for optimization.
"""

def __init__(self, language: Union[str, dtypes.Language] = dtypes.Language.Python):
def __init__(self, language: Union[str, dtypes.Language] = dtypes.Language.Python, side_effects: bool = False):
if isinstance(language, str):
language = dtypes.Language[language]
self.language = language
self.side_effects = side_effects
alexnick83 marked this conversation as resolved.
Show resolved Hide resolved

def __enter__(self):
if self.language != dtypes.Language.Python:
Expand Down
6 changes: 5 additions & 1 deletion dace/frontend/python/memlet_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,11 @@ def ParseMemlet(visitor,
if len(node.value.args) >= 2:
write_conflict_resolution = node.value.args[1]

subset, new_axes, arrdims = parse_memlet_subset(array, node, das, parsed_slice)
try:
alexnick83 marked this conversation as resolved.
Show resolved Hide resolved
subset, new_axes, arrdims = parse_memlet_subset(array, node, das, parsed_slice)
except IndexError:
raise DaceSyntaxError(visitor, node, 'Failed to parse memlet expression due to dimensionality. '
f'Array dimensions: {array.shape}, expression in code: {astutils.unparse(node)}')

# If undefined, default number of accesses is the slice size
if num_accesses is None:
Expand Down
17 changes: 17 additions & 0 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2510,6 +2510,7 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None):

# Looking for the first argument in a tasklet annotation: @dace.tasklet(STRING HERE)
langInf = None
side_effects = None
if isinstance(node, ast.FunctionDef) and \
hasattr(node, 'decorator_list') and \
isinstance(node.decorator_list, list) and \
Expand All @@ -2522,6 +2523,19 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None):
langArg = node.decorator_list[0].args[0].value
langInf = dtypes.Language[langArg]

# Extract arguments from with statement
if isinstance(node, ast.With):
expr = node.items[0].context_expr
if isinstance(expr, ast.Call):
args = astutils.parse_function_arguments(expr, ['language', 'side_effects'])
langArg = args.get('language', None)
side_effects = args.get('side_effects', None)
langInf = astutils.evalnode(langArg, {**self.globals, **self.defined})
if isinstance(langInf, str):
langInf = dtypes.Language[langInf]

side_effects = astutils.evalnode(side_effects, {**self.globals, **self.defined})

ttrans = TaskletTransformer(self,
self.defined,
self.sdfg,
Expand All @@ -2536,6 +2550,9 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None):
symbols=self.symbols)
node, inputs, outputs, self.accesses = ttrans.parse_tasklet(node, name)

if side_effects is not None:
node.side_effects = side_effects

# Convert memlets to their actual data nodes
for i in inputs.values():
if not isinstance(i, tuple) and i.data in self.scope_vars.keys():
Expand Down
2 changes: 1 addition & 1 deletion dace/frontend/python/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,7 +1268,7 @@ def _convert_to_ast(contents: Any):
node)
else:
# Augment closure with new value
newnode = self.resolver.global_value_to_node(e, node, f'inlined_{id(contents)}', True, keep_object=True)
newnode = self.resolver.global_value_to_node(contents, node, f'inlined_{id(contents)}', True, keep_object=True)
return newnode

return _convert_to_ast(contents)
Expand Down
Loading
Loading