Skip to content

Commit

Permalink
Merge pull request #1114 from spcl/transformation-fixes-2
Browse files Browse the repository at this point in the history
Transformation Fixes Round 2
  • Loading branch information
tbennun authored Oct 13, 2022
2 parents 85843f0 + 42cd81f commit 5153940
Show file tree
Hide file tree
Showing 12 changed files with 478 additions and 119 deletions.
1 change: 1 addition & 0 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2433,6 +2433,7 @@ def visit_Continue(self, node: ast.Continue):
def visit_If(self, node: ast.If):
# Add a guard state
self._add_state('if_guard')
self.last_state.debuginfo = self.current_lineinfo

# Generate conditions
cond, cond_else = self._visit_test(node.test)
Expand Down
11 changes: 6 additions & 5 deletions dace/sdfg/analysis/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,13 @@ def _stateorder_topological_sort(sdfg: SDFG,
"""
# Traverse states in custom order
visited = visited or set()
if stop is not None:
visited.add(stop)
stack = [start]
while stack:
node = stack.pop()
if node in visited:
if node in visited or node is stop:
continue
yield node
visited.add(node)

oe = sdfg.out_edges(node)
if len(oe) == 0: # End state
Expand Down Expand Up @@ -265,6 +264,9 @@ def _stateorder_topological_sort(sdfg: SDFG,
mergestate = stop

for branch in oe:
if branch.dst is mergestate:
# If we hit the merge state (if without else), defer to end of branch traversal
continue
for s in _stateorder_topological_sort(sdfg,
branch.dst,
ptree,
Expand All @@ -273,8 +275,7 @@ def _stateorder_topological_sort(sdfg: SDFG,
visited=visited):
yield s
visited.add(s)
if mergestate != stop:
stack.append(mergestate)
stack.append(mergestate)


def stateorder_topological_sort(sdfg: SDFG) -> Iterator[SDFGState]:
Expand Down
15 changes: 11 additions & 4 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,22 @@ def condition_sympy(self):
self._cond_sympy = symbolic.pystr_to_symbolic(self.condition.as_string)
return self._cond_sympy

@property
def free_symbols(self) -> Set[str]:
""" Returns a set of symbols used in this edge's properties. """
def read_symbols(self) -> Set[str]:
"""
Returns a set of symbols read in this edge (including symbols in the condition and assignment values).
"""
# Symbols in conditions and assignments
result = set(map(str, dace.symbolic.symbols_in_ast(self.condition.code[0])))
for assign in self.assignments.values():
result |= symbolic.free_symbols_and_functions(assign)

return result - set(self.assignments.keys())
return result

@property
def free_symbols(self) -> Set[str]:
""" Returns a set of symbols used in this edge's properties. """
return self.read_symbols() - set(self.assignments.keys())


def replace_dict(self, repl: Dict[str, str], replace_keys=True) -> None:
"""
Expand Down
36 changes: 24 additions & 12 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None):
undef_syms = set(edge.data.free_symbols) - set(symbols.keys())
if len(undef_syms) > 0:
eid = sdfg.edge_id(edge)
raise InvalidSDFGInterstateEdgeError("Undefined symbols in edge: %s" % undef_syms, sdfg, eid)
raise InvalidSDFGInterstateEdgeError(
f'Undefined symbols in edge: {undef_syms}. Add those with '
'`sdfg.add_symbol()` or define outside with `dace.symbol()`', sdfg, eid)

# Validate inter-state edge names
issyms = edge.data.new_symbols(sdfg, symbols)
Expand Down Expand Up @@ -231,8 +233,7 @@ def validate_state(state: 'dace.sdfg.SDFGState',
raise InvalidSDFGError("Invalid state name", sdfg, state_id)

if state._parent != sdfg:
raise InvalidSDFGError("State does not point to the correct "
"parent", sdfg, state_id)
raise InvalidSDFGError("State does not point to the correct " "parent", sdfg, state_id)

# Unreachable
########################################
Expand Down Expand Up @@ -618,7 +619,6 @@ def validate_state(state: 'dace.sdfg.SDFGState',

class InvalidSDFGError(Exception):
""" A class of exceptions thrown when SDFG validation fails. """

def __init__(self, message: str, sdfg: 'SDFG', state_id: int):
self.message = message
self.sdfg = sdfg
Expand All @@ -641,8 +641,7 @@ def _getlineinfo(self, obj) -> str:

if lineinfo.start_line >= 0:
if lineinfo.start_column > 0:
return (f'File "{lineinfo.filename}", line {lineinfo.start_line}, '
f'column {lineinfo.start_column}')
return (f'File "{lineinfo.filename}", line {lineinfo.start_line}, ' f'column {lineinfo.start_column}')
return f'File "{lineinfo.filename}", line {lineinfo.start_line}'

return f'File "{lineinfo.filename}"'
Expand Down Expand Up @@ -670,7 +669,6 @@ def __str__(self):

class InvalidSDFGInterstateEdgeError(InvalidSDFGError):
""" Exceptions of invalid inter-state edges in an SDFG. """

def __init__(self, message: str, sdfg: 'SDFG', edge_id: int):
self.message = message
self.sdfg = sdfg
Expand All @@ -687,15 +685,31 @@ def __str__(self):
str(e.src),
str(e.dst),
)
locinfo_src = self._getlineinfo(e.src)
locinfo_dst = self._getlineinfo(e.dst)
else:
edgestr = ''
locinfo_src = locinfo_dst = ''

if locinfo_src or locinfo_dst:
if locinfo_src == locinfo_dst:
locinfo = f'at {locinfo_src}'
elif locinfo_src and not locinfo_dst:
locinfo = f'at {locinfo_src}'
elif locinfo_dst and not locinfo_src:
locinfo = f'at {locinfo_src}'
else:
locinfo = f'between\n {locinfo_src}\n and\n {locinfo_dst}'

locinfo = f'\nOriginating from source code {locinfo}'
else:
edgestr = ""
locinfo = ''

return "%s%s" % (self.message, edgestr)
return f'{self.message}{edgestr}{locinfo}'


class InvalidSDFGNodeError(InvalidSDFGError):
""" Exceptions of invalid nodes in an SDFG state. """

def __init__(self, message: str, sdfg: 'SDFG', state_id: int, node_id: int):
self.message = message
self.sdfg = sdfg
Expand Down Expand Up @@ -729,14 +743,12 @@ class NodeNotExpandedError(InvalidSDFGNodeError):
Exception that is raised whenever a library node was not expanded
before code generation.
"""

def __init__(self, sdfg: 'SDFG', state_id: int, node_id: int):
super().__init__('Library node not expanded', sdfg, state_id, node_id)


class InvalidSDFGEdgeError(InvalidSDFGError):
""" Exceptions of invalid edges in an SDFG state. """

def __init__(self, message: str, sdfg: 'SDFG', state_id: int, edge_id: int):
self.message = message
self.sdfg = sdfg
Expand Down
38 changes: 34 additions & 4 deletions dace/transformation/dataflow/map_fission.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dace.sdfg import nodes, graph as gr
from dace.sdfg import utils as sdutil
from dace.sdfg.graph import OrderedDiGraph
from dace.sdfg.propagation import propagate_memlets_state
from dace.sdfg.propagation import propagate_memlets_state, propagate_subset
from dace.symbolic import pystr_to_symbolic
from dace.transformation import transformation, helpers
from typing import List, Optional, Tuple
Expand Down Expand Up @@ -413,6 +413,14 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG):
# Correct connectors and memlets in nested SDFGs to account for
# missing outside map
if self.expr_index == 1:

# NOTE: In the following scope dictionary, we mark the new MapEntries as existing in their own scope.
# This makes it easier to detect edges that are outside the new Map scopes (after MapFission).
scope_dict = state.scope_dict()
for k, v in scope_dict.items():
if isinstance(k, nodes.MapEntry) and k in new_map_entries and v is None:
scope_dict[k] = k

to_correct = ([(e, e.src) for e in external_edges_entry] + [(e, e.dst) for e in external_edges_exit])
corrected_nodes = set()
for edge, node in to_correct:
Expand Down Expand Up @@ -442,6 +450,12 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG):
for e in state.memlet_tree(internal_edge):
e.data.subset.offset(desc.offset, False)
e.data.subset = helpers.unsqueeze_memlet(e.data, outer_edge.data).subset
# NOTE: If the edge is outside of the new Map scope, then try to propagate it. This is
# needed for edges directly connecting AccessNodes, because the standard memlet
# propagation will stop at the first AccessNode outside the Map scope. For example, see
# `test.transformations.mapfission_test.MapFissionTest.test_array_copy_outside_scope`.
if not (scope_dict[e.src] and scope_dict[e.dst]):
e.data = propagate_subset([e.data], desc, outer_map.params, outer_map.range)

# Only after offsetting memlets we can modify the
# overall offset
Expand All @@ -455,9 +469,21 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG):
for edge in state.all_edges(node):
for e in state.memlet_tree(edge):
# Prepend map dimensions to memlet
e.data.subset = subsets.Range([(pystr_to_symbolic(d) - r[0], pystr_to_symbolic(d) - r[0], 1)
for d, r in zip(outer_map.params, outer_map.range)] +
e.data.subset.ranges)
# NOTE: Do this only for the subset corresponding to `node.data`. If the edge is copying
# to/from another AccessNode, the other data may not need extra dimensions. For example, see
# `test.transformations.mapfission_test.MapFissionTest.test_array_copy_outside_scope`.
if e.data.data == node.data:
if e.data.subset:
e.data.subset = subsets.Range([(pystr_to_symbolic(d) - r[0],
pystr_to_symbolic(d) - r[0], 1)
for d, r in zip(outer_map.params, outer_map.range)] +
e.data.subset.ranges)
else:
if e.data.other_subset:
e.data.other_subset = subsets.Range(
[(pystr_to_symbolic(d) - r[0], pystr_to_symbolic(d) - r[0], 1)
for d, r in zip(outer_map.params, outer_map.range)] +
e.data.other_subset.ranges)

# If nested SDFG, reconnect nodes around map and modify memlets
if self.expr_index == 1:
Expand Down Expand Up @@ -486,3 +512,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG):

# Remove outer map
graph.remove_nodes_from([map_entry, map_exit])

# NOTE: It is better to manually call memlet propagation here to ensure that all subsets are properly updated.
# This can solve issues when, e.g., applying MapFission through `SDFG.apply_transformations_repeated`.
propagate_memlets_state(sdfg, graph)
99 changes: 54 additions & 45 deletions dace/transformation/interstate/gpu_transform_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _recursive_out_check(node, state, gpu_scalars):
scalset = scalset.union(sset)
scalout = scalout and ssout
continue
if desc.shape == (1,): # Pseudo-scalar
if desc.shape == (1, ): # Pseudo-scalar
scalout = False
sset, ssout = _recursive_out_check(last_edge.dst, state, gpu_scalars)
scalset = scalset.union(sset)
Expand Down Expand Up @@ -66,7 +66,7 @@ def _recursive_in_check(node, state, gpu_scalars):
scalset = scalset.union(sset)
scalout = scalout and ssout
continue
if desc.shape == (1,): # Pseudo-scalar
if desc.shape == (1, ): # Pseudo-scalar
scalout = False
sset, ssout = _recursive_in_check(last_edge.src, state, gpu_scalars)
scalset = scalset.union(sset)
Expand All @@ -81,10 +81,6 @@ def _recursive_in_check(node, state, gpu_scalars):
return scalset, scalout


def _codenode_condition(node):
return isinstance(node, (nodes.LibraryNode, nodes.NestedSDFG)) and node.schedule == dtypes.ScheduleType.GPU_Default


@make_properties
class GPUTransformSDFG(transformation.MultiStateTransformation):
""" Implements the GPUTransformSDFG transformation.
Expand Down Expand Up @@ -305,33 +301,63 @@ def apply(self, _, sdfg: sd.SDFG):

#######################################################
# Step 5: Collect free tasklets and check for scalars that have to be moved to the GPU
# Also recursively call GPUTransformSDFG on NestedSDFGs that have GPU device schedule but are not actually
# inside a GPU kernel.

gpu_scalars = {}
nsdfgs = []
changed = True
# Iterates over Tasklets that not inside a GPU kernel. Such Tasklets must be moved inside a GPU kernel only
# if they write to GPU memory. The check takes into account the fact that GPU kernels can read host-based
# Scalars, but cannot write to them.
while changed:
changed = False
for node, state in sdfg.all_nodes_recursive():
if isinstance(node, nodes.Tasklet):
if node in global_code_nodes[state]:
continue
if state.entry_node(node) is None and not scope.is_devicelevel_gpu_kernel(
state.parent, state, node):
scalars, scalar_output = _recursive_out_check(node, state, gpu_scalars)
sset, ssout = _recursive_in_check(node, state, gpu_scalars)
scalars = scalars.union(sset)
scalar_output = scalar_output and ssout
csdfg = state.parent
# If the tasklet is not adjacent only to scalars or it is in a GPU scope.
# The latter includes NestedSDFGs that have a GPU-Device schedule but are not in a GPU kernel.
if (not scalar_output
or (csdfg.parent is not None
and csdfg.parent_nsdfg_node.schedule == dtypes.ScheduleType.GPU_Default)):
global_code_nodes[state].append(node)
gpu_scalars.update({k: None for k in scalars})
changed = True
for state in sdfg.states():
for node in state.nodes():
# Handle NestedSDFGs later.
if isinstance(node, nodes.NestedSDFG):
if state.entry_node(node) is None and not scope.is_devicelevel_gpu_kernel(
state.parent, state, node):
nsdfgs.append((node, state))
elif isinstance(node, nodes.Tasklet):
if node in global_code_nodes[state]:
continue
if state.entry_node(node) is None and not scope.is_devicelevel_gpu_kernel(
state.parent, state, node):
scalars, scalar_output = _recursive_out_check(node, state, gpu_scalars)
sset, ssout = _recursive_in_check(node, state, gpu_scalars)
scalars = scalars.union(sset)
scalar_output = scalar_output and ssout
csdfg = state.parent
# If the tasklet is not adjacent only to scalars or it is in a GPU scope.
# The latter includes NestedSDFGs that have a GPU-Device schedule but are not in a GPU kernel.
if (not scalar_output
or (csdfg.parent is not None
and csdfg.parent_nsdfg_node.schedule == dtypes.ScheduleType.GPU_Default)):
global_code_nodes[state].append(node)
gpu_scalars.update({k: None for k in scalars})
changed = True

# Apply GPUTransformSDFG recursively to NestedSDFGs.
for node, state in nsdfgs:
excl_copyin = set()
for e in state.in_edges(node):
src = state.memlet_path(e)[0].src
if isinstance(src, nodes.AccessNode) and sdfg.arrays[src.data].storage in gpu_storage:
excl_copyin.add(e.dst_conn)
node.sdfg.arrays[e.dst_conn].storage = sdfg.arrays[src.data].storage
excl_copyout = set()
for e in state.out_edges(node):
dst = state.memlet_path(e)[-1].dst
if isinstance(dst, nodes.AccessNode) and sdfg.arrays[dst.data].storage in gpu_storage:
excl_copyout.add(e.src_conn)
node.sdfg.arrays[e.src_conn].storage = sdfg.arrays[dst.data].storage
# TODO: Do we want to copy here the options from the top-level SDFG?
node.sdfg.apply_transformations(
GPUTransformSDFG, {
'exclude_copyin': ','.join([str(n) for n in excl_copyin]),
'exclude_copyout': ','.join([str(n) for n in excl_copyout])
})

#######################################################
# Step 6: Modify transient data storage
Expand All @@ -350,26 +376,9 @@ def apply(self, _, sdfg: sd.SDFG):

if sdict[node] is None and nodedesc.storage not in gpu_storage:

# Ensure that scalars not already GPU-marked are actually used in a GPU scope.
# Scalars were already checked.
if isinstance(nodedesc, data.Scalar) and not node.data in gpu_scalars:
used_in_gpu_scope = False
for e in state.in_edges(node):
if _codenode_condition(state.memlet_path(e)[0].src):
used_in_gpu_scope = True
break
if not used_in_gpu_scope:
for e in state.out_edges(node):
if _codenode_condition(state.memlet_path(e)[-1].dst):
used_in_gpu_scope = True
break
if not used_in_gpu_scope:
continue
for e in state.all_edges(node):
for node in (e.src, e.dst):
if isinstance(node, nodes.Tasklet):
if (state.entry_node(node) is None and not scope.is_devicelevel_gpu(
state.parent, state, node, with_gpu_default=True)):
global_code_nodes[state].append(node)
continue

# NOTE: the cloned arrays match too but it's the same storage so we don't care
nodedesc.storage = dtypes.StorageType.GPU_Global
Expand Down Expand Up @@ -470,5 +479,5 @@ def apply(self, _, sdfg: sd.SDFG):
# Step 9: Simplify
if not self.simplify:
return

sdfg.simplify()
Loading

0 comments on commit 5153940

Please sign in to comment.