Skip to content

Commit

Permalink
Bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Sep 20, 2023
1 parent 41a0abf commit 89fe7b1
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 57 deletions.
41 changes: 21 additions & 20 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,7 +1210,7 @@ def parse_program(self, program: ast.FunctionDef, is_tasklet: bool = False):
for stmt in program.body:
self.visit_TopLevel(stmt)
if len(self.sdfg.nodes()) == 0:
self.cfg_target.add_state('EmptyState')
self.sdfg.add_state('EmptyState')

# Handle return values
# Assignments to return values become __return* arrays
Expand Down Expand Up @@ -1277,7 +1277,7 @@ def _views_to_data(state: SDFGState, nodes: List[dace.nodes.AccessNode]) -> List
return new_nodes

# Map view access nodes to their respective data
for state in self.sdfg.nodes():
for state in self.sdfg.states():
# NOTE: We need to support views of views
nodes = list(state.data_nodes())
while nodes:
Expand Down Expand Up @@ -2023,7 +2023,7 @@ def _add_dependencies(self,
else:
name = memlet.data
vname = "{c}_in_from_{s}{n}".format(c=conn,
s=self.sdfg.nodes().index(state),
s=self.sdfg.states().index(state),
n=('_%s' % state.node_id(entry_node) if entry_node else ''))
self.accesses[(name, scope_memlet.subset, 'r')] = (vname, orng)
orig_shape = orng.size()
Expand Down Expand Up @@ -2113,7 +2113,7 @@ def _add_dependencies(self,
else:
name = memlet.data
vname = "{c}_out_of_{s}{n}".format(c=conn,
s=self.sdfg.nodes().index(state),
s=self.sdfg.states().index(state),
n=('_%s' % state.node_id(exit_node) if exit_node else ''))
self.accesses[(name, scope_memlet.subset, 'w')] = (vname, orng)
orig_shape = orng.size()
Expand Down Expand Up @@ -2210,7 +2210,8 @@ def _recursive_visit(self,
# Restore previous target
self.cfg_target = previous_target
self.last_cfg_target = previous_last_cfg_target
self.last_block = previous_block
if not unconnected_last_block:
self.last_block = previous_block

return previous_block, first_innner_block, last_inner_block, has_return_statement

Expand Down Expand Up @@ -2358,7 +2359,7 @@ def visit_For(self, node: ast.For):
# The state that all "break" edges go to
state = self.cfg_target.add_state(f'postloop_{node.lineno}')
if self.last_block is not None:
self.sdfg.add_edge(self.last_block, state, dace.InterstateEdge())
self.cfg_target.add_edge(self.last_block, state, dace.InterstateEdge())
self.last_block = state
return state

Expand Down Expand Up @@ -2473,25 +2474,25 @@ def visit_If(self, node: ast.If):

# Visit recursively
laststate, first_if_state, last_if_state, return_stmt = \
self._recursive_visit(node.body, 'if', node.lineno)
self._recursive_visit(node.body, 'if', node.lineno, self.cfg_target, True)
end_if_state = self.last_block

# Connect the states
self.sdfg.add_edge(laststate, first_if_state, dace.InterstateEdge(cond))
self.sdfg.add_edge(last_if_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}"))
self.cfg_target.add_edge(laststate, first_if_state, dace.InterstateEdge(cond))
self.cfg_target.add_edge(last_if_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}"))

# Process 'else'/'elif' statements
if len(node.orelse) > 0:
# Visit recursively
_, first_else_state, last_else_state, return_stmt = \
self._recursive_visit(node.orelse, 'else', node.lineno, False)
self._recursive_visit(node.orelse, 'else', node.lineno, self.cfg_target, False)

# Connect the states
self.sdfg.add_edge(laststate, first_else_state, dace.InterstateEdge(cond_else))
self.sdfg.add_edge(last_else_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}"))
self.last_block = end_if_state
self.cfg_target.add_edge(laststate, first_else_state, dace.InterstateEdge(cond_else))
self.cfg_target.add_edge(last_else_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}"))
else:
self.sdfg.add_edge(laststate, end_if_state, dace.InterstateEdge(cond_else))
self.cfg_target.add_edge(laststate, end_if_state, dace.InterstateEdge(cond_else))
self.last_block = end_if_state

def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None):

Expand Down Expand Up @@ -3323,7 +3324,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
# Handle output indirection
output_indirection = None
if _subset_has_indirection(rng, self):
output_indirection = self.sdfg.add_state('wslice_%s_%d' % (new_name, node.lineno))
output_indirection = self.cfg_target.add_state('wslice_%s_%d' % (new_name, node.lineno))
wnode = output_indirection.add_write(new_name, debuginfo=self.current_lineinfo)
memlet = Memlet.simple(new_name, str(rng))
# Dependent augmented assignments need WCR in the
Expand Down Expand Up @@ -3369,7 +3370,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):

# Connect states properly when there is output indirection
if output_indirection:
self.sdfg.add_edge(self.last_block, output_indirection, dace.sdfg.InterstateEdge())
self.cfg_target.add_edge(self.last_block, output_indirection, dace.sdfg.InterstateEdge())
self.last_block = output_indirection

def visit_AugAssign(self, node: ast.AugAssign):
Expand Down Expand Up @@ -3814,8 +3815,8 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no
for sym, local in mapping.items():
if isinstance(local, str) and local in self.sdfg.arrays:
# Add assignment state and inter-state edge
symassign_state = self.sdfg.add_state_before(state)
isedge = self.sdfg.edges_between(symassign_state, state)[0]
symassign_state = self.cfg_target.add_state_before(state)
isedge = self.cfg_target.edges_between(symassign_state, state)[0]
newsym = self.sdfg.find_new_symbol(f'sym_{local}')
desc = self.sdfg.arrays[local]
self.sdfg.add_symbol(newsym, desc.dtype)
Expand Down Expand Up @@ -3879,7 +3880,7 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no
# Delete the old read descriptor
if not isinput:
conn_used = False
for s in self.sdfg.nodes():
for s in self.sdfg.states():
for n in s.data_nodes():
if n.data == aname:
conn_used = True
Expand Down Expand Up @@ -4866,7 +4867,7 @@ def _promote(node: ast.AST) -> Union[Any, str, symbolic.symbol]:
# `not sym` returns True. This exception is benign.
pass
state = self._add_state(f'promote_{scalar}_to_{str(sym)}')
edge = self.sdfg.in_edges(state)[0]
edge = state.parent.in_edges(state)[0]
edge.data.assignments = {str(sym): scalar}
return sym
return scalar
Expand Down
11 changes: 6 additions & 5 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,8 @@ def used_symbols(self, all_symbols: bool) -> Tuple[Set[str], Set[str], Set[str]]
# Add free state symbols
used_before_assignment = set()

b_free_syms, b_defined_syms, b_used_before_syms = super().used_symbols(all_symbols)
b_free_syms, b_defined_syms, b_used_before_syms = super().used_symbols(all_symbols, defined_syms, free_syms,
used_before_assignment)
free_syms |= b_free_syms
defined_syms |= b_defined_syms
used_before_assignment |= b_used_before_syms
Expand Down Expand Up @@ -2030,17 +2031,17 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG':
############################
# DaCe Compilation Process #

# Convert any scope blocks to old-school state machines for now.
# TODO: Adapt codegen to deal wiht scope blocks instead.
sdutils.inline_loop_blocks(self)

if self._regenerate_code or not os.path.isdir(build_folder):
# Clone SDFG as the other modules may modify its contents
sdfg = copy.deepcopy(self)
# Fix the build folder name on the copied SDFG to avoid it changing
# if the codegen modifies the SDFG (thereby changing its hash)
sdfg.build_folder = build_folder

# Convert any scope blocks to old-school state machines for now.
# TODO: Adapt codegen to deal wiht scope blocks instead.
sdutils.inline_loop_blocks(sdfg)

# Rename SDFG to avoid runtime issues with clashing names
index = 0
while sdfg.is_loaded():
Expand Down
12 changes: 8 additions & 4 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2368,10 +2368,14 @@ def add_loop(
return before_state, loop_scope, after_state

@abc.abstractmethod
def used_symbols(self, all_symbols: bool) -> Tuple[Set[str], Set[str], Set[str]]:
defined_syms = set()
free_syms = set()
used_before_assignment = set()
def used_symbols(self,
all_symbols: bool,
defined_syms: Optional[Set]=None,
free_syms: Optional[Set]=None,
used_before_assignment: Optional[Set]=None) -> Tuple[Set[str], Set[str], Set[str]]:
defined_syms = set() if defined_syms is None else defined_syms
free_syms = set() if free_syms is None else free_syms
used_before_assignment = set() if used_before_assignment is None else used_before_assignment

try:
ordered_blocks = self.topological_sort(self.start_block)
Expand Down
54 changes: 27 additions & 27 deletions dace/transformation/interstate/multistate_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,31 +237,30 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG):

# All transients become transients of the parent (if data already
# exists, find new name)
for cf in nsdfg.all_state_scopes_recursive():
for nblock in cf.nodes():
for node in nblock.nodes():
if isinstance(node, nodes.AccessNode):
datadesc = nsdfg.arrays[node.data]
if node.data not in transients and datadesc.transient:
new_name = node.data
for state in nsdfg.states():
for node in state.nodes():
if isinstance(node, nodes.AccessNode):
datadesc = nsdfg.arrays[node.data]
if node.data not in transients and datadesc.transient:
new_name = node.data
if (new_name in sdfg.arrays or new_name in outer_symbols or new_name in sdfg.constants):
new_name = f'{nsdfg.label}_{node.data}'

name = sdfg.add_datadesc(new_name, datadesc, find_new_name=True)
transients[node.data] = name

# All transients of edges between code nodes are also added to parent
for edge in state.edges():
if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)):
if edge.data.data is not None:
datadesc = nsdfg.arrays[edge.data.data]
if edge.data.data not in transients and datadesc.transient:
new_name = edge.data.data
if (new_name in sdfg.arrays or new_name in outer_symbols or new_name in sdfg.constants):
new_name = f'{nsdfg.label}_{node.data}'
new_name = f'{nsdfg.label}_{edge.data.data}'

name = sdfg.add_datadesc(new_name, datadesc, find_new_name=True)
transients[node.data] = name

# All transients of edges between code nodes are also added to parent
for edge in nblock.edges():
if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)):
if edge.data.data is not None:
datadesc = nsdfg.arrays[edge.data.data]
if edge.data.data not in transients and datadesc.transient:
new_name = edge.data.data
if (new_name in sdfg.arrays or new_name in outer_symbols or new_name in sdfg.constants):
new_name = f'{nsdfg.label}_{edge.data.data}'

name = sdfg.add_datadesc(new_name, datadesc, find_new_name=True)
transients[edge.data.data] = name
transients[edge.data.data] = name


# All constants (and associated transients) become constants of the parent
Expand Down Expand Up @@ -413,11 +412,12 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG):
# e.data, outer_edge.data)

# Replace nested SDFG parents and SDFG pointers.
for cf in nsdfg.all_state_scopes_recursive():
for nblock in cf.nodes():
nblock.parent = cf
nblock.sdfg = sdfg
for node in nblock.nodes():
for n in nsdfg.nodes():
n.parent = outer_state.parent
for block in nsdfg.all_control_flow_blocks_recursive(recurse_into_sdfgs=False):
block.sdfg = outer_state.sdfg
if isinstance(block, SDFGState):
for node in block.nodes():
if isinstance(node, nodes.NestedSDFG):
node.sdfg.parent_sdfg = sdfg
node.sdfg.parent_nsdfg_node = node
Expand Down
2 changes: 1 addition & 1 deletion tests/memlet_propagation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def sparse(A: dace.float32[M, N], ind: dace.int32[M, N]):
propagate_memlets_sdfg(sdfg)

# Verify all memlet subsets and volumes in the main state of the program, i.e. around the NSDFG.
map_state = sdfg.states()[1]
map_state = list(sdfg.states())[1]
i = dace.symbol('i')
j = dace.symbol('j')

Expand Down

0 comments on commit 89fe7b1

Please sign in to comment.