Skip to content

Commit

Permalink
Adds legacy compatibility to loopscopeblocks
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Sep 7, 2023
1 parent 75195fa commit 0b0a3c8
Show file tree
Hide file tree
Showing 11 changed files with 152 additions and 36 deletions.
6 changes: 4 additions & 2 deletions dace/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from dace.codegen.targets import cpp, cpu

from dace.codegen.instrumentation import InstrumentationProvider
from dace.sdfg.state import SDFGState
from dace.sdfg.state import SDFGState, ScopeBlock


def generate_headers(sdfg: SDFG, frame: framecode.DaCeCodeGenerator) -> str:
Expand Down Expand Up @@ -102,6 +102,8 @@ def _get_codegen_targets(sdfg: SDFG, frame: framecode.DaCeCodeGenerator):
state: SDFGState = parent
nsdfg = state.sdfg
frame.targets.add(disp.get_node_dispatcher(nsdfg, state, node))
elif isinstance(node, ScopeBlock):
frame.targets.add(disp.get_state_scope_dispatcher(parent, node))

# Array allocation
if isinstance(node, dace.nodes.AccessNode):
Expand Down Expand Up @@ -149,7 +151,7 @@ def _get_codegen_targets(sdfg: SDFG, frame: framecode.DaCeCodeGenerator):
disp.instrumentation[sdfg.instrument] = provider_mapping[sdfg.instrument]


def generate_code(sdfg, validate=True) -> List[CodeObject]:
def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]:
"""
Generates code as a list of code objects for a given SDFG.
Expand Down
19 changes: 11 additions & 8 deletions dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def free_symbols(self, obj: Any):
if k in self.fsyms:
return self.fsyms[k]
if hasattr(obj, 'used_symbols'):
result = obj.used_symbols(all_symbols=False)
result = obj.used_symbols(all_symbols=False)[0]
else:
result = obj.free_symbols
self.fsyms[k] = result
Expand Down Expand Up @@ -395,9 +395,14 @@ def generate_external_memory_management(self, sdfg: SDFG, callsite_stream: CodeI
# Footer
callsite_stream.write('}', sdfg)

def generate_state(self, sdfg, state, global_stream, callsite_stream, generate_state_footer=True):
def generate_state(self,
sdfg: SDFG,
state: SDFGState,
global_stream: CodeIOStream,
callsite_stream: CodeIOStream,
generate_state_footer=True) -> None:

sid = sdfg.node_id(state)
sid = state.parent.node_id(state)

# Emit internal transient array allocation
self.allocate_arrays_in_scope(sdfg, state, global_stream, callsite_stream)
Expand Down Expand Up @@ -444,7 +449,7 @@ def generate_state(self, sdfg, state, global_stream, callsite_stream, generate_s
if instr is not None:
instr.on_state_end(sdfg, state, callsite_stream, global_stream)

def generate_states(self, sdfg, global_stream, callsite_stream):
def generate_states(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stream: CodeIOStream):
states_generated = set()

opbar = progress.OptionalProgressBar(sdfg.number_of_nodes(), title=f'Generating code (SDFG {sdfg.sdfg_id})')
Expand Down Expand Up @@ -526,8 +531,7 @@ def _can_allocate(self, sdfg: SDFG, state: SDFGState, desc: data.Data, scope: Un

def determine_allocation_lifetime(self, top_sdfg: SDFG):
"""
Determines where (at which scope/state/SDFG) each data descriptor
will be allocated/deallocated.
Determines where (at which scope/state/SDFG) each data descriptor will be allocated/deallocated.
:param top_sdfg: The top-level SDFG to determine for.
"""
Expand All @@ -543,8 +547,7 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG):
#############################################
# Look for all states in which a scope-allocated array is used in
instances: Dict[str, List[Tuple[SDFGState, nodes.AccessNode]]] = collections.defaultdict(list)
array_names = sdfg.arrays.keys(
) #set(k for k, v in sdfg.arrays.items() if v.lifetime == dtypes.AllocationLifetime.Scope)
array_names = sdfg.arrays.keys()
# Iterate topologically to get state-order
for state in sdfg.topological_sort():
for node in state.data_nodes():
Expand Down
16 changes: 9 additions & 7 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from dace.sdfg.propagation import propagate_memlet, propagate_subset, propagate_states
from dace.memlet import Memlet
from dace.properties import LambdaProperty, CodeBlock
from dace.sdfg import SDFG, SDFGState, ControlFlowGraph, ControlFlowBlock, LoopScopeBlock
from dace.sdfg import SDFG, SDFGState, ControlFlowGraph, ControlFlowBlock, LoopScopeBlock, ScopeBlock
from dace.sdfg.replace import replace_datadesc_names
from dace.symbolic import pystr_to_symbolic, inequal_symbols

Expand Down Expand Up @@ -1046,8 +1046,8 @@ class ProgramVisitor(ExtNodeVisitor):

sdfg: SDFG
last_block: ControlFlowBlock
cfg_target: ControlFlowGraph
last_cfg_target: ControlFlowGraph
cfg_target: ScopeBlock
last_cfg_target: ScopeBlock
current_state: SDFGState

def __init__(self,
Expand Down Expand Up @@ -1334,8 +1334,8 @@ def _add_block(self, block: ControlFlowBlock):
else:
self.current_state = block

def _add_state(self, label=None) -> SDFGState:
state = self.cfg_target.add_state(label, False)
def _add_state(self, label=None, is_start=False) -> SDFGState:
state = self.cfg_target.add_state(label, is_start_block=is_start)
self._add_block(state)
return state

Expand Down Expand Up @@ -2344,8 +2344,10 @@ def visit_For(self, node: ast.For):
initialize_expr=astutils.unparse(ast_ranges[0][0]),
update_expr=incr[indices[0]],
inverted=False)
self._recursive_visit(node.body, f'for_{node.lineno}', node.lineno, extra_symbols=extra_syms,
parent=loop_scope, unconnected_last_block=False)
_, first_subblock, _, _ = self._recursive_visit(node.body, f'for_{node.lineno}',
node.lineno, extra_symbols=extra_syms,
parent=loop_scope, unconnected_last_block=False)
loop_scope.start_block = loop_scope.node_id(first_subblock)

# Handle else clause
if node.orelse:
Expand Down
3 changes: 1 addition & 2 deletions dace/sdfg/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,7 @@ def sink_nodes(self) -> List[NodeT]:
return [n for n in self.nodes() if self.out_degree(n) == 0]

def topological_sort(self, source: NodeT = None) -> Sequence[NodeT]:
"""Returns nodes in topological order iff the graph contains exactly
one node with no incoming edges."""
"""Returns nodes in topological order iff the graph contains exactly one node with no incoming edges."""
if source is not None:
sources = [source]
else:
Expand Down
4 changes: 2 additions & 2 deletions dace/sdfg/infer_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def infer_connector_types(sdfg: SDFG):
:param sdfg: The SDFG to infer.
"""
# Loop over states, and in a topological sort over each state's nodes
for state in sdfg.nodes():
for state in sdfg.all_states_recursive():
for node in dfs_topological_sort(state):
# Try to infer input connector type from node type or previous edges
for e in state.in_edges(node):
Expand Down Expand Up @@ -167,7 +167,7 @@ def set_default_schedule_and_storage_types(scope: Union[SDFG, SDFGState, nodes.E

if isinstance(scope, SDFG):
# Set device for default top-level schedules and storages
for state in scope.nodes():
for state in scope.all_states_recursive():
set_default_schedule_and_storage_types(state,
parent_schedules,
use_parent_schedule=use_parent_schedule,
Expand Down
5 changes: 1 addition & 4 deletions dace/sdfg/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,7 @@ def node_id_or_none(node):
if node is None: return -1
return state.node_id(node)

res = {}
for k, v in scope_dict.items():
res[node_id_or_none(k)] = [node_id_or_none(vi) for vi in v] if v is not None else []
return res
return {node_id_or_none(k): [node_id_or_none(vi) for vi in v] for k, v in scope_dict.items()}


def scope_contains_scope(sdict: ScopeDictType, node: NodeType, other_node: NodeType) -> bool:
Expand Down
19 changes: 10 additions & 9 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1404,22 +1404,21 @@ def transients(self):
return result

def shared_transients(self, check_toplevel=True) -> List[str]:
""" Returns a list of transient data that appears in more than one
state. """
""" Returns a list of transient data that appears in more than one state. """
seen = {}
shared = []

# If a transient is present in an inter-state edge, it is shared
for interstate_edge in self.edges():
for interstate_edge in self.all_interstate_edges_recursive():
for sym in interstate_edge.data.free_symbols:
if sym in self.arrays and self.arrays[sym].transient:
seen[sym] = interstate_edge
shared.append(sym)

# If transient is accessed in more than one state, it is shared
for state in self.nodes():
for node in state.nodes():
if isinstance(node, nd.AccessNode) and node.desc(self).transient:
for state in self.all_states_recursive():
for node in state.data_nodes():
if node.desc(self).transient:
if (check_toplevel and node.desc(self).toplevel) or (node.data in seen
and seen[node.data] != state):
shared.append(node.data)
Expand Down Expand Up @@ -2255,8 +2254,10 @@ def __call__(self, *args, **kwargs):
def fill_scope_connectors(self):
""" Fills missing scope connectors (i.e., "IN_#"/"OUT_#" on entry/exit
nodes) according to data on the memlets. """
for state in self.nodes():
state.fill_scope_connectors()
for cf in self.all_cfgs_recursive():
for block in cf.nodes():
if isinstance(block, SDFGState):
block.fill_scope_connectors()

def predecessor_state_transitions(self, state):
""" Yields paths (lists of edges) that the SDFG can pass through
Expand Down Expand Up @@ -2555,7 +2556,7 @@ def expand_library_nodes(self, recursive=True):
including library nodes that expand to library nodes.
"""

states = list(self.states())
states = list(self.all_states_recursive())
while len(states) > 0:
state = states.pop()
expanded_something = False
Expand Down
8 changes: 7 additions & 1 deletion dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ def to_json(self, parent=None):
from dace.sdfg.scope import _scope_dict_to_ids
# Create scope dictionary with a failsafe
try:
scope_dict_regular = self.scope_dict()
scope_dict_regular = self.scope_children()
scope_dict_ids = _scope_dict_to_ids(self, scope_dict_regular)
scope_dict = {k: sorted(v) for k, v in sorted(scope_dict_ids.items())}
except (RuntimeError, ValueError):
Expand Down Expand Up @@ -2271,6 +2271,12 @@ def all_control_flow_blocks_recursive(self, recurse_into_sdfgs=True) -> Iterator
for block in cfg.nodes():
yield block

def all_interstate_edges_recursive(self, recurse_into_sdfgs=True) -> Iterator[Edge['dace.sdfg.InterstateEdge']]:
""" Iterate over all interstate edges in this control flow graph. """
for cfg in self.all_cfgs_recursive(recurse_into_sdfgs=recurse_into_sdfgs):
for edge in cfg.edges():
yield edge

###################################################################
# Getters & setters, overrides

Expand Down
27 changes: 26 additions & 1 deletion dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dace.sdfg.graph import MultiConnectorEdge
from dace.sdfg.sdfg import SDFG
from dace.sdfg.nodes import Node, NestedSDFG
from dace.sdfg.state import SDFGState, StateSubgraphView
from dace.sdfg.state import SDFGState, StateSubgraphView, LoopScopeBlock
from dace.sdfg.scope import ScopeSubgraphView
from dace.sdfg import nodes as nd, graph as gr, propagation
from dace import config, data as dt, dtypes, memlet as mm, subsets as sbs, symbolic
Expand Down Expand Up @@ -1246,6 +1246,31 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) ->
return counter


def inline_loop_blocks(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int:
# Avoid import loops
from dace.transformation.interstate import LoopScopeInline

counter = 0
blocks = [(n, p) for n, p in sdfg.all_nodes_recursive() if isinstance(n, LoopScopeBlock)]

for block, graph in optional_progressbar(reversed(blocks), title='Inlining Loops', n=len(blocks), progress=progress):
id = block.sdfg.sdfg_id

# We have to reevaluate every time due to changing IDs
block_id = graph.node_id(block)

candidate = {
LoopScopeInline.block: block,
}
inliner = LoopScopeInline()
inliner.setup_match(graph, id, block_id, candidate, 0, override=True)
if inliner.can_be_applied(graph, 0, block.sdfg, permissive=permissive):
inliner.apply(graph, block.sdfg)
counter += 1

return counter


def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, multistate: bool = True) -> int:
"""
Inlines all possible nested SDFGs (or sub-SDFGs) using an optimized
Expand Down
1 change: 1 addition & 0 deletions dace/transformation/interstate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .loop_unroll import LoopUnroll
from .loop_peeling import LoopPeeling
from .loop_to_map import LoopToMap
from .scope_inline import LoopScopeInline
from .move_loop_into_map import MoveLoopIntoMap
from .trivial_loop_elimination import TrivialLoopElimination
from .multistate_inline import InlineMultistateSDFG
80 changes: 80 additions & 0 deletions dace/transformation/interstate/scope_inline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
""" Inline all scope blocks in SDFGs. """

from typing import Any, Set, Optional

from dace.frontend.python import astutils
from dace.sdfg import SDFG, ControlFlowGraph, InterstateEdge, SDFGState
from dace.sdfg import utils as sdutil
from dace.sdfg.nodes import CodeBlock
from dace.sdfg.state import LoopScopeBlock, ScopeBlock
from dace.transformation import transformation


class LoopScopeInline(transformation.MultiStateTransformation):
"""
Inlines a loop scope block into a legacy-style state machine.
"""

block = transformation.PatternNode(LoopScopeBlock)

@staticmethod
def annotates_memlets():
return False

@classmethod
def expressions(cls):
return [sdutil.node_path_graph(cls.block)]

def can_be_applied(self, graph: ControlFlowGraph, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool:
return True

def apply(self, graph: ControlFlowGraph, sdfg: SDFG) -> Optional[int]:
parent: ScopeBlock = graph

internal_start = self.block.start_block

# Construct the basic loop state structure.
init_state = parent.add_state(self.block.label + '_init')
for b_edge in parent.in_edges(self.block):
parent.add_edge(b_edge.src, init_state, b_edge.data)
parent.remove_edge(b_edge)

guard_state = parent.add_state(self.block.label + '_guard')
init_edge = InterstateEdge()
if self.block.init_statement is not None:
init_edge.assignments = {
self.block.loop_variable: self.block.init_statement.as_string.rpartition('=')[2].strip()
}
parent.add_edge(init_state, guard_state, init_edge)

end_state = parent.add_state(self.block.label + '_end')
parent.add_edge(guard_state, end_state,
InterstateEdge(condition=CodeBlock(astutils.negate_expr(self.block.scope_condition.code))))
for a_edge in parent.out_edges(self.block):
parent.add_edge(end_state, a_edge.dst, a_edge.data)
parent.remove_edge(a_edge)

last_loop_state = parent.add_state(self.block.label + '_loop')
loop_edge = InterstateEdge()
if self.block.update_statement is not None:
loop_edge.assignments = {
self.block.loop_variable: self.block.update_statement.as_string.rpartition('=')[2].strip()
}
parent.add_edge(last_loop_state, guard_state, loop_edge)

to_connect: Set[SDFGState] = set()
for node in self.block.nodes():
parent.add_node(node)
if self.block.out_degree(node) == 0:
to_connect.add(node)
for edge in self.block.edges():
parent.add_edge(edge.src, edge.dst, edge.data)

# Connect the loop states
parent.add_edge(guard_state, internal_start,
InterstateEdge(condition=self.block.scope_condition.as_string))
for node in to_connect:
parent.add_edge(node, last_loop_state, InterstateEdge())

parent.remove_node(self.block)

0 comments on commit 0b0a3c8

Please sign in to comment.