Refactor
phschaad committed Sep 5, 2023
1 parent 125c1b6 commit a320d09
Showing 31 changed files with 305 additions and 397 deletions.
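Most of the hunks below apply one mechanical change: code that previously reached a state's owning SDFG through `state.parent` now uses the `state.sdfg` property. A minimal sketch of the pattern (the helper function and its name are illustrative only, not part of this commit):

import dace

def describe_access(state: dace.SDFGState, node: dace.nodes.AccessNode) -> str:
    # Before this commit: sdfg = state.parent
    # After this commit, the owning SDFG is reached via the `sdfg` property:
    sdfg: dace.SDFG = state.sdfg
    desc = node.desc(sdfg)  # data descriptors live on the owning SDFG
    return f'{node.data}: storage = {desc.storage}'
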
8 changes: 4 additions & 4 deletions dace/codegen/codegen.py
@@ -100,13 +100,13 @@ def _get_codegen_targets(sdfg: SDFG, frame: framecode.DaCeCodeGenerator):
frame.targets.add(disp.get_scope_dispatcher(node.schedule))
elif isinstance(node, dace.nodes.Node):
state: SDFGState = parent
- nsdfg = state.parent
+ nsdfg = state.sdfg
frame.targets.add(disp.get_node_dispatcher(nsdfg, state, node))

# Array allocation
if isinstance(node, dace.nodes.AccessNode):
state: SDFGState = parent
- nsdfg = state.parent
+ nsdfg = state.sdfg
desc = node.desc(nsdfg)
frame.targets.add(disp.get_array_dispatcher(desc.storage))

@@ -124,13 +124,13 @@ def _get_codegen_targets(sdfg: SDFG, frame: framecode.DaCeCodeGenerator):
dst_node = leaf_e.dst
if leaf_e.data.is_empty():
continue
- tgt = disp.get_copy_dispatcher(node, dst_node, leaf_e, state.parent, state)
+ tgt = disp.get_copy_dispatcher(node, dst_node, leaf_e, state.sdfg, state)
if tgt is not None:
frame.targets.add(tgt)
else:
# Rooted at dst_node
dst_node = mtree.root().edge.dst
- tgt = disp.get_copy_dispatcher(node, dst_node, e, state.parent, state)
+ tgt = disp.get_copy_dispatcher(node, dst_node, e, state.sdfg, state)
if tgt is not None:
frame.targets.add(tgt)

8 changes: 4 additions & 4 deletions dace/codegen/control_flow.py
@@ -121,7 +121,7 @@ class SingleState(ControlFlow):
last_state: bool = False

def as_cpp(self, codegen, symbols) -> str:
- sdfg = self.state.parent
+ sdfg = self.state.sdfg

expr = '__state_{}_{}:;\n'.format(sdfg.sdfg_id, self.state.label)
if self.state.number_of_nodes() > 0:
@@ -218,7 +218,7 @@ def as_cpp(self, codegen, symbols) -> str:
# In a general block, emit transitions and assignments after each
# individual state
if isinstance(elem, SingleState):
- sdfg = elem.state.parent
+ sdfg = elem.state.sdfg
out_edges = sdfg.out_edges(elem.state)
for j, e in enumerate(out_edges):
if e not in self.gotos_to_ignore:
@@ -361,7 +361,7 @@ def as_cpp(self, codegen, symbols) -> str:
init = f'{symbols[self.itervar]} {self.itervar}'
init += ' = ' + self.init

- sdfg = self.guard.parent
+ sdfg = self.guard.sdfg

preinit = ''
if self.init_edges:
@@ -403,7 +403,7 @@ class WhileScope(ControlFlow):

def as_cpp(self, codegen, symbols) -> str:
if self.test is not None:
- sdfg = self.guard.parent
+ sdfg = self.guard.sdfg
test = unparse_interstate_edge(self.test.code[0], sdfg, codegen=codegen)
else:
test = 'true'
10 changes: 5 additions & 5 deletions dace/codegen/targets/cuda.py
@@ -204,9 +204,9 @@ def preprocess(self, sdfg: SDFG) -> None:
for state, node, defined_syms in sdutil.traverse_sdfg_with_defined_symbols(sdfg, recursive=True):
if (isinstance(node, nodes.MapEntry)
and node.map.schedule in (dtypes.ScheduleType.GPU_Device, dtypes.ScheduleType.GPU_Persistent)):
- if state.parent not in shared_transients:
- shared_transients[state.parent] = state.parent.shared_transients()
- self._arglists[node] = state.scope_subgraph(node).arglist(defined_syms, shared_transients[state.parent])
+ if state.sdfg not in shared_transients:
+ shared_transients[state.sdfg] = state.sdfg.shared_transients()
+ self._arglists[node] = state.scope_subgraph(node).arglist(defined_syms, shared_transients[state.sdfg])

def _compute_pool_release(self, top_sdfg: SDFG):
"""
@@ -831,7 +831,7 @@ def increment(streams):
# Remove CUDA streams from paths of non-gpu copies and CPU tasklets
for node, graph in sdfg.all_nodes_recursive():
if isinstance(graph, SDFGState):
- cur_sdfg = graph.parent
+ cur_sdfg = graph.sdfg

if (isinstance(node, (nodes.EntryNode, nodes.ExitNode)) and node.schedule in dtypes.GPU_SCHEDULES):
# Node must have GPU stream, remove childpath and continue
@@ -1421,7 +1421,7 @@ def generate_scope(self, sdfg, dfg_scope, state_id, function_stream, callsite_st
visited = set()
for node, parent in dfg_scope.all_nodes_recursive():
if isinstance(node, nodes.AccessNode):
- nsdfg: SDFG = parent.parent
+ nsdfg: SDFG = parent.sdfg
desc = node.desc(nsdfg)
if (nsdfg, node.data) in visited:
continue
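
One detail from the `preprocess` hunk at the top of this file: shared transients are now cached per SDFG (keyed by `state.sdfg`), so `shared_transients()` is computed once per SDFG rather than once per GPU map entry. The caching idiom in isolation (the wrapper function is illustrative, not part of the diff):

import dace

shared_transients = {}  # SDFG -> list of shared transient array names

def shared_transients_of(state: dace.SDFGState):
    sdfg = state.sdfg
    if sdfg not in shared_transients:
        shared_transients[sdfg] = sdfg.shared_transients()  # computed once per SDFG
    return shared_transients[sdfg]
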
2 changes: 1 addition & 1 deletion dace/codegen/targets/fpga.py
@@ -1332,7 +1332,7 @@ def partition_kernels(self, state: dace.SDFGState, default_kernel: int = 0):
"""

concurrent_kernels = 0 # Max number of kernels
- sdfg = state.parent
+ sdfg = state.sdfg

def increment(kernel_id):
if concurrent_kernels > 0:
4 changes: 2 additions & 2 deletions dace/codegen/targets/framecode.py
@@ -491,7 +491,7 @@ def _get_schedule(self, scope: Union[nodes.EntryNode, SDFGState, SDFG]) -> dtype
elif isinstance(scope, nodes.EntryNode):
return scope.schedule
elif isinstance(scope, (SDFGState, SDFG)):
- sdfg: SDFG = (scope if isinstance(scope, SDFG) else scope.parent)
+ sdfg: SDFG = (scope if isinstance(scope, SDFG) else scope.sdfg)
if sdfg.parent_nsdfg_node is None:
return TOP_SCHEDULE

@@ -721,7 +721,7 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG):
if curscope is None:
curscope = curstate
elif isinstance(curscope, (SDFGState, SDFG)):
- cursdfg: SDFG = (curscope if isinstance(curscope, SDFG) else curscope.parent)
+ cursdfg: SDFG = (curscope if isinstance(curscope, SDFG) else curscope.sdfg)
# Go one SDFG up
if cursdfg.parent_nsdfg_node is None:
curscope = None
2 changes: 1 addition & 1 deletion dace/frontend/python/newast.py
@@ -1335,7 +1335,7 @@ def _add_block(self, block: ControlFlowBlock):
self.current_state = block

def _add_state(self, label=None) -> SDFGState:
- state = self.cfg_target.add_state(label, False, self.sdfg)
+ state = self.cfg_target.add_state(label, False)
self._add_block(state)
return state

4 changes: 2 additions & 2 deletions dace/memlet.py
@@ -400,11 +400,11 @@ def try_initialize(self, sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGState',
self.subset = subsets.Range.from_array(sdfg.arrays[self.data])

def get_src_subset(self, edge: 'dace.sdfg.graph.MultiConnectorEdge', state: 'dace.sdfg.SDFGState'):
- self.try_initialize(state.parent, state, edge)
+ self.try_initialize(state.sdfg, state, edge)
return self.src_subset

def get_dst_subset(self, edge: 'dace.sdfg.graph.MultiConnectorEdge', state: 'dace.sdfg.SDFGState'):
- self.try_initialize(state.parent, state, edge)
+ self.try_initialize(state.sdfg, state, edge)
return self.dst_subset

@staticmethod
32 changes: 16 additions & 16 deletions dace/sdfg/analysis/cutout.py
@@ -193,7 +193,7 @@ def singlestate_cutout(cls,
if reduce_input_config:
nodes = _reduce_in_configuration(state, nodes, use_alibi_nodes, symbols_map)
create_element = copy.deepcopy if make_copy else (lambda x: x)
- sdfg = state.parent
+ sdfg = state.sdfg
subgraph: StateSubgraphView = StateSubgraphView(state, nodes)
subgraph = _extend_subgraph_with_access_nodes(state, subgraph, use_alibi_nodes)

@@ -341,8 +341,8 @@ def multistate_cutout(cls,
create_element = copy.deepcopy

# Check that all states are inside the same SDFG.
- sdfg = list(states)[0].parent
- if any(i.parent != sdfg for i in states):
+ sdfg = list(states)[0].sdfg
+ if any(i.sdfg != sdfg for i in states):
raise Exception('Not all cutout states reside in the same SDFG')

cutout_states: Set[SDFGState] = set(states)
@@ -423,13 +423,13 @@ def multistate_cutout(cls,
in_translation[is_edge.src] = new_el
out_translation[new_el] = is_edge.src
cutout.add_node(new_el, is_start_state=(is_edge.src == start_state))
- new_el.parent = cutout
+ new_el.sdfg = cutout
if is_edge.dst not in in_translation:
new_el: SDFGState = create_element(is_edge.dst)
in_translation[is_edge.dst] = new_el
out_translation[new_el] = is_edge.dst
cutout.add_node(new_el, is_start_state=(is_edge.dst == start_state))
- new_el.parent = cutout
+ new_el.sdfg = cutout
new_isedge: InterstateEdge = create_element(is_edge.data)
in_translation[is_edge.data] = new_isedge
out_translation[new_isedge] = is_edge.data
@@ -442,7 +442,7 @@ def multistate_cutout(cls,
in_translation[state] = new_el
out_translation[new_el] = state
cutout.add_node(new_el, is_start_state=(state == start_state))
- new_el.parent = cutout
+ new_el.sdfg = cutout

in_translation[sdfg.sdfg_id] = cutout.sdfg_id
out_translation[cutout.sdfg_id] = sdfg.sdfg_id
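
Each state copied into the cutout is registered on the new SDFG and has its owning-SDFG back-reference repointed (now `new_el.sdfg = cutout` rather than `new_el.parent = cutout`). A minimal sketch of that idiom, with an illustrative wrapper function that is not part of this commit:

import copy
import dace

def copy_state_into(cutout: dace.SDFG, state: dace.SDFGState, start: bool) -> dace.SDFGState:
    new_el = copy.deepcopy(state)                  # detached copy of the original state
    cutout.add_node(new_el, is_start_state=start)  # register it in the cutout graph
    new_el.sdfg = cutout                           # repoint the owning-SDFG back-reference
    return new_el
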
@@ -574,8 +574,8 @@ def _reduce_in_configuration(state: SDFGState, affected_nodes: Set[nd.Node], use

# For the given state, determine what should count as the input configuration if we were to cut out the entire
# state.
- state_reachability_dict = StateReachability().apply_pass(state.parent, None)
- state_reach = state_reachability_dict[state.parent.sdfg_id]
+ state_reachability_dict = StateReachability().apply_pass(state.sdfg, None)
+ state_reach = state_reachability_dict[state.sdfg.sdfg_id]
reaching_cutout: Set[SDFGState] = set()
for k, v in state_reach.items():
if state in v:
@@ -586,7 +586,7 @@ def _reduce_in_configuration(state: SDFGState, affected_nodes: Set[nd.Node], use
if state.out_degree(dn) > 0:
# This is read from, add to the system state if it is written anywhere else in the graph.
# Except if it is also written to at the same time and is scalar or of size 1.
- array = state.parent.arrays[dn.data]
+ array = state.sdfg.arrays[dn.data]
if state.in_degree(dn) > 0 and (array.total_size == 1 or isinstance(array, data.Scalar)):
continue
elif not array.transient:
@@ -608,8 +608,8 @@ def _reduce_in_configuration(state: SDFGState, affected_nodes: Set[nd.Node], use
# about symbol values. Not sure how to do that yet.
if symbols_map is None:
symbols_map = dict()
- consts = state.parent.constants
- for s in state.parent.symbols:
+ consts = state.sdfg.constants
+ for s in state.sdfg.symbols:
if s in consts:
symbols_map[s] = consts[s]
else:
@@ -730,8 +730,8 @@ def _reduce_in_configuration(state: SDFGState, affected_nodes: Set[nd.Node], use

for node in scope_nodes:
if isinstance(node, nd.AccessNode) and node.data in state_input_configuration:
- if not proxy_graph.has_edge(source, node) and node.data in state.parent.arrays:
- vol = state.parent.arrays[node.data].total_size
+ if not proxy_graph.has_edge(source, node) and node.data in state.sdfg.arrays:
+ vol = state.sdfg.arrays[node.data].total_size
if isinstance(vol, sp.Expr):
vol = vol.subs(symbols_map)
proxy_graph.add_edge(source, node, capacity=vol)
@@ -767,7 +767,7 @@ def _stateset_predecessor_frontier(states: Set[SDFGState]) -> Tuple[Set[SDFGStat
pred_frontier = set()
pred_frontier_edges = set()
for state in states:
- for iedge in state.parent.in_edges(state):
+ for iedge in state.sdfg.in_edges(state):
if iedge.src not in states:
if iedge.src not in pred_frontier:
pred_frontier.add(iedge.src)
@@ -819,7 +819,7 @@ def _create_alibi_access_node_for_edge(target_sdfg: SDFG, target_state: SDFGStat
def _extend_subgraph_with_access_nodes(state: SDFGState, subgraph: StateSubgraphView,
use_alibi_nodes: bool) -> StateSubgraphView:
""" Expands a subgraph view to include necessary input/output access nodes, using memlet paths. """
- sdfg = state.parent
+ sdfg = state.sdfg
result: List[nd.Node] = copy.copy(subgraph.nodes())
queue: Deque[nd.Node] = deque(subgraph.nodes())

@@ -1014,7 +1014,7 @@ def _cutout_determine_output_configuration(ct: SDFG, cutout_reach: Set[SDFGState
check_for_read_after.add(dn.data)

original_state: SDFGState = out_translation[state]
- for edge in original_state.parent.out_edges(original_state):
+ for edge in original_state.sdfg.out_edges(original_state):
if edge.dst in cutout_reach:
border_out_edges.add(edge.data)

6 changes: 3 additions & 3 deletions dace/sdfg/infer_types.py
@@ -257,7 +257,7 @@ def _determine_schedule_from_storage(state: SDFGState, node: nodes.Node) -> Opti

# From memlets, use non-scalar data descriptors for decision
constraints: Set[dtypes.ScheduleType] = set()
- sdfg = state.parent
+ sdfg = state.sdfg
for dname in memlets:
if isinstance(sdfg.arrays[dname], data.Scalar):
continue # Skip scalars
@@ -276,7 +276,7 @@ def _determine_schedule_from_storage(state: SDFGState, node: nodes.Node) -> Opti
raise validation.InvalidSDFGNodeError(
f'Cannot determine default schedule for node {node}. '
'Multiple arrays that point to it say that it should be the following schedules: '
- f'{constraints}', state.parent, state.parent.node_id(state), state.node_id(node))
+ f'{constraints}', state.sdfg, state.sdfg.node_id(state), state.node_id(node))
else:
child_schedule = next(iter(constraints))

@@ -338,7 +338,7 @@ def _set_default_storage_in_scope(state: SDFGState, parent_node: Optional[nodes.
parent_schedules = parent_schedules + [dtypes.ScheduleType.GPU_ThreadBlock]
# End of special case

- sdfg = state.parent
+ sdfg = state.sdfg
child_storage = _determine_child_storage(parent_schedules)
if child_storage is None:
child_storage = dtypes.SCOPEDEFAULT_STORAGE[None]
4 changes: 2 additions & 2 deletions dace/sdfg/nodes.py
@@ -265,7 +265,7 @@ def __label__(self, sdfg, state):
def desc(self, sdfg):
from dace.sdfg import SDFGState, ScopeSubgraphView
if isinstance(sdfg, (SDFGState, ScopeSubgraphView)):
- sdfg = sdfg.parent
+ sdfg = sdfg.sdfg
return sdfg.arrays[self.data]

def validate(self, sdfg, state):
@@ -588,7 +588,7 @@ def used_symbols(self, all_symbols: bool) -> Set[str]:

# Filter out unused internal symbols from symbol mapping
if not all_symbols:
- internally_used_symbols = self.sdfg.used_symbols(all_symbols=False)
+ internally_used_symbols = self.sdfg.used_symbols(all_symbols=False)[0]
free_syms &= internally_used_symbols

return free_syms
29 changes: 15 additions & 14 deletions dace/sdfg/replace.py
@@ -168,17 +168,18 @@ def replace_datadesc_names(sdfg, repl: Dict[str, str]):
sdfg.constants_prop[repl[aname]] = sdfg.constants_prop[aname]
del sdfg.constants_prop[aname]

- # Replace in interstate edges
- for e in sdfg.edges():
- e.data.replace_dict(repl, replace_keys=False)
-
- for state in sdfg.nodes():
- # Replace in access nodes
- for node in state.data_nodes():
- if node.data in repl:
- node.data = repl[node.data]
-
- # Replace in memlets
- for edge in state.edges():
- if edge.data.data in repl:
- edge.data.data = repl[edge.data.data]
+ for cf in sdfg.all_cfgs_recursive(recurse_into_sdfgs=False):
+ # Replace in interstate edges
+ for e in cf.edges():
+ e.data.replace_dict(repl, replace_keys=False)
+
+ for block in cf.nodes():
+ if isinstance(block, dace.SDFGState):
+ # Replace in access nodes
+ for node in block.data_nodes():
+ if node.data in repl:
+ node.data = repl[node.data]
+ # Replace in memlets
+ for edge in block.edges():
+ if edge.data.data in repl:
+ edge.data.data = repl[edge.data.data]
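
The renaming pass above now iterates over every control-flow graph of the SDFG via `all_cfgs_recursive(recurse_into_sdfgs=False)`, updating interstate edges as well as the access nodes and memlets inside each state block. A sketch of that traversal with indentation restored (the wrapper name `rename_descriptors` is illustrative; `repl` maps old data names to new ones):

import dace

def rename_descriptors(sdfg: dace.SDFG, repl: dict) -> None:
    # Visit this SDFG's control-flow graphs, but do not descend into nested SDFGs.
    for cf in sdfg.all_cfgs_recursive(recurse_into_sdfgs=False):
        for e in cf.edges():                           # interstate edges
            e.data.replace_dict(repl, replace_keys=False)
        for block in cf.nodes():
            if isinstance(block, dace.SDFGState):
                for node in block.data_nodes():        # access nodes
                    if node.data in repl:
                        node.data = repl[node.data]
                for edge in block.edges():             # memlets
                    if edge.data.data in repl:
                        edge.data.data = repl[edge.data.data]
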
9 changes: 6 additions & 3 deletions dace/sdfg/scope.py
@@ -122,7 +122,10 @@ def node_id_or_none(node):
if node is None: return -1
return state.node_id(node)

- return {node_id_or_none(k): [node_id_or_none(vi) for vi in v] for k, v in scope_dict.items()}
+ res = {}
+ for k, v in scope_dict.items():
+ res[node_id_or_none(k)] = [node_id_or_none(vi) for vi in v] if v is not None else []
+ return res
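
The rewritten loop above tolerates scope-dictionary entries whose value is `None`: such a scope now serializes to an empty children list instead of failing when the old dict comprehension iterated over `None`. A self-contained toy illustration of the guard (node ids and dictionary contents are made up):

node_ids = {'map_entry': 0, 'tasklet': 1}  # made-up node -> id mapping
scope_dict = {None: ['map_entry'], 'map_entry': ['tasklet'], 'tasklet': None}

def node_id_or_none(node):
    if node is None:
        return -1  # marker for the top-level scope, as in the function above
    return node_ids[node]

res = {}
for k, v in scope_dict.items():
    res[node_id_or_none(k)] = [node_id_or_none(vi) for vi in v] if v is not None else []
# res == {-1: [0], 0: [1], 1: []}
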


def scope_contains_scope(sdict: ScopeDictType, node: NodeType, other_node: NodeType) -> bool:
@@ -246,7 +249,7 @@ def is_devicelevel_gpu_kernel(sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGStat
if is_parent_nested:
return is_devicelevel_gpu(sdfg.parent.parent, sdfg.parent, sdfg.parent_nsdfg_node, with_gpu_default=True)
else:
- return is_devicelevel_gpu(state.parent, state, node, with_gpu_default=True)
+ return is_devicelevel_gpu(state.sdfg, state, node, with_gpu_default=True)


def is_devicelevel_fpga(sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGState', node: NodeType) -> bool:
@@ -294,7 +297,7 @@ def devicelevel_block_size(sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGState',
# Traverse up nested SDFGs
if sdfg.parent is not None:
if isinstance(sdfg.parent, SDFGState):
- parent = sdfg.parent.parent
+ parent = sdfg.parent.sdfg
else:
parent = sdfg.parent
state, node = next((s, n) for s in parent.nodes() for n in s.nodes()