Skip to content

Commit

Permalink
Unused imports backport (successor to #1808) (#1816)
Browse files Browse the repository at this point in the history
Author: @romanc

@romanc is on leave for the next few days, thus I have replayed the
changes made in #1808 here for faster
turnaround.
  • Loading branch information
phschaad authored Dec 7, 2024
1 parent 00f05e6 commit 3466973
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 22 deletions.
32 changes: 20 additions & 12 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,7 +1319,7 @@ def _views_to_data(state: SDFGState, nodes: List[dace.nodes.AccessNode]) -> List
self.sdfg.replace_dict(repl_dict)

propagate_states(self.sdfg)
for state, memlet, inner_indices in itertools.chain(self.inputs.values(), self.outputs.values()):
for state, memlet, _inner_indices in itertools.chain(self.inputs.values(), self.outputs.values()):
if state is not None and state.dynamic_executions:
memlet.dynamic = True

Expand Down Expand Up @@ -2366,8 +2366,11 @@ def visit_For(self, node: ast.For):
init_expr='%s = %s' % (indices[0], astutils.unparse(ast_ranges[0][0])),
update_expr=incr[indices[0]],
inverted=False)
_, first_subblock, _, _ = self._recursive_visit(node.body, f'for_{node.lineno}', node.lineno,
extra_symbols=extra_syms, parent=loop_region,
_, first_subblock, _, _ = self._recursive_visit(node.body,
f'for_{node.lineno}',
node.lineno,
extra_symbols=extra_syms,
parent=loop_region,
unconnected_last_block=False)
loop_region.start_block = loop_region.node_id(first_subblock)
self._connect_break_blocks(loop_region)
Expand Down Expand Up @@ -2449,7 +2452,10 @@ def visit_While(self, node: ast.While):
loop_region = self._add_loop_region(loop_cond, label=f'while_{node.lineno}', inverted=False)

# Parse body
self._recursive_visit(node.body, f'while_{node.lineno}', node.lineno, parent=loop_region,
self._recursive_visit(node.body,
f'while_{node.lineno}',
node.lineno,
parent=loop_region,
unconnected_last_block=False)

if test_region is not None:
Expand Down Expand Up @@ -2540,7 +2546,6 @@ def _has_loop_ancestor(self, node: ControlFlowBlock) -> bool:
node = node.parent_graph
return False


def visit_Break(self, node: ast.Break):
if not self._has_loop_ancestor(self.cfg_target):
raise DaceSyntaxError(self, node, "Break block outside loop region")
Expand Down Expand Up @@ -2572,8 +2577,7 @@ def visit_If(self, node: ast.If):

# Process 'else'/'elif' statements
if len(node.orelse) > 0:
else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}',
sdfg=self.sdfg)
else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}', sdfg=self.sdfg)
cond_block.add_branch(None, else_body)
# Visit recursively
self._recursive_visit(node.orelse, 'else', node.lineno, else_body, False)
Expand Down Expand Up @@ -2934,7 +2938,6 @@ def _add_aug_assignment(self,
wsqueezed = [i for i in range(len(wtarget_subset)) if i not in wsqz]
rsqueezed = [i for i in range(len(rtarget_subset)) if i not in rsqz]


if (boolarr or indirect_indices
or (sqz_wsub.size() == sqz_osub.size() and sqz_wsub.size() == sqz_rsub.size())):
map_range = {i: rng for i, rng in all_idx_tuples}
Expand Down Expand Up @@ -3358,8 +3361,11 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):

new_data, rng = None, None
dtype_keys = tuple(dtypes.dtype_to_typeclass().keys())
if not (result in self.sdfg.symbols or symbolic.issymbolic(result) or isinstance(result, dtype_keys) or
(isinstance(result, str) and any(result in x for x in [self.sdfg.arrays, self.sdfg._pgrids, self.sdfg._subarrays, self.sdfg._rdistrarrays]))):
if not (
result in self.sdfg.symbols or symbolic.issymbolic(result) or isinstance(result, dtype_keys) or
(isinstance(result, str) and any(
result in x
for x in [self.sdfg.arrays, self.sdfg._pgrids, self.sdfg._subarrays, self.sdfg._rdistrarrays]))):
raise DaceSyntaxError(
self, node, "In assignments, the rhs may only be "
"data, numerical/boolean constants "
Expand Down Expand Up @@ -3467,7 +3473,9 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
cname = self.sdfg.find_new_constant(f'__ind{i}_{true_name}')
self.sdfg.add_constant(cname, carr)
# Add constant to descriptor repository
self.sdfg.add_array(cname, carr.shape, dtypes.dtype_to_typeclass(carr.dtype.type),
self.sdfg.add_array(cname,
carr.shape,
dtypes.dtype_to_typeclass(carr.dtype.type),
transient=True)
if numpy.array(arr).dtype == numpy.bool_:
boolarr = cname
Expand Down Expand Up @@ -4769,7 +4777,7 @@ def visit_With(self, node: ast.With, is_async=False):
evald = astutils.evalnode(node.items[0].context_expr, self.globals)
if hasattr(evald, "name"):
named_region_name: str = evald.name
else:
else:
named_region_name = f"Named Region {node.lineno}"
named_region = NamedRegion(named_region_name, debuginfo=self.current_lineinfo)
self.cfg_target.add_node(named_region)
Expand Down
22 changes: 12 additions & 10 deletions dace/frontend/python/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ def _get_locals_and_globals(f):
result.update(f.__globals__)
# grab the free variables (i.e. locals)
if f.__closure__ is not None:
result.update(
{k: v
for k, v in zip(f.__code__.co_freevars, [_get_cell_contents_or_none(x) for x in f.__closure__])})
result.update({
k: v
for k, v in zip(f.__code__.co_freevars, [_get_cell_contents_or_none(x) for x in f.__closure__])
})

return result

Expand Down Expand Up @@ -142,6 +143,7 @@ def infer_symbols_from_datadescriptor(sdfg: SDFG,
class DaceProgram(pycommon.SDFGConvertible):
""" A data-centric program object, obtained by decorating a function with
``@dace.program``. """

def __init__(self,
f,
args,
Expand Down Expand Up @@ -405,9 +407,10 @@ def _create_sdfg_args(self, sdfg: SDFG, args: Tuple[Any], kwargs: Dict[str, Any]

# Update arguments with symbols in data shapes
result.update(
infer_symbols_from_datadescriptor(
sdfg, {k: create_datadescriptor(v)
for k, v in result.items() if k not in self.constant_args}))
infer_symbols_from_datadescriptor(sdfg, {
k: create_datadescriptor(v)
for k, v in result.items() if k not in self.constant_args
}))
return result

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -487,9 +490,6 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF
:param validate: If True, validates the resulting SDFG after creation.
:return: The generated SDFG object.
"""
# Avoid import loop
from dace.transformation.passes import scalar_to_symbol as scal2sym
from dace.transformation import helpers as xfh

# Obtain DaCe program as SDFG
sdfg, cached = self._generate_pdp(args, kwargs, simplify=simplify)
Expand Down Expand Up @@ -812,7 +812,9 @@ def get_program_hash(self, *args, **kwargs) -> cached_program.ProgramCacheKey:
_, key = self._load_sdfg(None, *args, **kwargs)
return key

def _generate_pdp(self, args: Tuple[Any], kwargs: Dict[str, Any],
def _generate_pdp(self,
args: Tuple[Any],
kwargs: Dict[str, Any],
simplify: Optional[bool] = None) -> Tuple[SDFG, bool]:
""" Generates the parsed AST representation of a DaCe program.
Expand Down

0 comments on commit 3466973

Please sign in to comment.