Skip to content

Commit

Permalink
Infer aliasing information for nested SDFGs (#1121)
Browse files Browse the repository at this point in the history
* Infer aliasing information and generate `__restrict__` keywords in code generation
  • Loading branch information
tbennun authored Oct 12, 2022
1 parent f1afce3 commit 85843f0
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 4 deletions.
2 changes: 1 addition & 1 deletion dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def emit_memlet_reference(dispatcher,
else:
datadef = ptr(memlet.data, desc, sdfg, dispatcher.frame)

def make_const(expr):
def make_const(expr: str) -> str:
# check whether const has already been added before
if not expr.startswith("const "):
return "const " + expr
Expand Down
16 changes: 15 additions & 1 deletion dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1444,7 +1444,21 @@ def generate_nsdfg_header(self, sdfg, state, state_id, node, memlet_references,
toplevel_sdfg: SDFG = sdfg.sdfg_list[0]
arguments.append(f'{toplevel_sdfg.name}_t *__state')

arguments += [f'{atype} {aname}' for atype, aname, _ in memlet_references]
# Add "__restrict__" keywords to arguments that do not alias with others in the context of this SDFG
restrict_args = []
for atype, aname, _ in memlet_references:
def make_restrict(expr: str) -> str:
# Check whether "restrict" has already been added before and can be added
if expr.strip().endswith('*'):
return '__restrict__'
else:
return ''
if aname in node.sdfg.arrays and not node.sdfg.arrays[aname].may_alias:
restrict_args.append(make_restrict(atype))
else:
restrict_args.append('')

arguments += [f'{atype} {restrict} {aname}' for (atype, aname, _), restrict in zip(memlet_references, restrict_args)]
arguments += [
f'{node.sdfg.symbols[aname].as_arg(aname)}' for aname in sorted(node.symbol_mapping.keys())
if aname not in sdfg.constants
Expand Down
8 changes: 8 additions & 0 deletions dace/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,10 @@ def optional(self) -> bool:
def pool(self) -> bool:
return False

@property
def may_alias(self) -> bool:
return False

def is_equivalent(self, other):
if not isinstance(other, Scalar):
return False
Expand Down Expand Up @@ -765,6 +769,10 @@ def start_offset(self):
def optional(self) -> bool:
return False

@property
def may_alias(self) -> bool:
return False

def clone(self):
return type(self)(self.dtype, self.buffer_size, self.shape, self.transient, self.storage, self.location,
self.offset, self.lifetime, self.debuginfo)
Expand Down
72 changes: 71 additions & 1 deletion dace/sdfg/infer_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
from collections import defaultdict
from dace import data, dtypes
from dace.codegen.tools import type_inference
from dace.memlet import Memlet
from dace.sdfg import SDFG, SDFGState, nodes
from dace.sdfg import nodes
from dace.sdfg.graph import Edge
from dace.sdfg.utils import dfs_topological_sort
from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional, Set

#############################################################################
# Connector type inference
Expand Down Expand Up @@ -268,3 +271,70 @@ def _set_default_storage_types(sdfg: SDFG, toplevel_schedule: dtypes.ScheduleTyp
desc.storage = sdfg.arrays[e.data.data].storage
break
_set_default_storage_types(node.sdfg, node.schedule)


def infer_aliasing(node: nodes.NestedSDFG, sdfg: SDFG, state: SDFGState) -> None:
"""
Infers aliasing information on nested SDFG arrays based on external edges and connectors.
Operates in-place on nested SDFG node.
:param node: The nested SDFG node.
:param sdfg: Parent SDFG of the nested SDFG node.
:param state: Parent state of the nested SDFG node.
"""
data_to_conn: Dict[str, Set[str]] = defaultdict(set)

def _infer_aliased_connectors(
get_edges: Callable[[nodes.NestedSDFG], List[Edge[Memlet]]],
get_conn: Callable[[Edge[Memlet]], str],
outgoing: bool,
):
for e in get_edges(node):
if e.data.is_empty(): # Skip empty memlets
continue

# Get all addressed arrays (through views)
dnames = _get_addressed_arrays(state, e, outgoing=outgoing)

# Register data name mapping to matching connectors
conn = get_conn(e)
for dname in dnames:
data_to_conn[dname].add(conn)

# Infer for input arrays
_infer_aliased_connectors(state.in_edges, lambda e: e.dst_conn, False)

# Infer for output arrays
_infer_aliased_connectors(state.out_edges, lambda e: e.src_conn, True)

# If array is already connected to the nested SDFG in multiple, different connector names;
# it may alias with others.
for dname, conns in data_to_conn.items():
# If the original array may alias already, set the child to alias too
if len(conns) > 1 or sdfg.arrays[dname].may_alias:
for aname in conns:
# Modify internal arrays
if aname in node.sdfg.arrays:
desc = node.sdfg.arrays[aname]
if isinstance(desc, data.Array): # The only data type where may_alias can be set
desc.may_alias = True


def _get_addressed_arrays(state: SDFGState, edge: Edge[Memlet], outgoing: bool) -> Set[str]:
"""
Helper function that returns the actual array data descriptor name from a memlet.
Traces the memlet path out, including through views.
"""
# Avoid import loop
from dace.sdfg import utils as sdutil

mpath = state.memlet_path(edge)
last_node = mpath[-1].dst if outgoing else mpath[0].src
if not isinstance(last_node, nodes.AccessNode):
return {edge.data.data}

# If access node, find viewed node
last_node = sdutil.get_all_view_nodes(state, last_node)
if last_node is None:
return {edge.data.data}
return set(n.data for n in last_node)
6 changes: 5 additions & 1 deletion dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,11 @@ def free_symbols(self) -> Set[str]:

def infer_connector_types(self, sdfg, state):
# Avoid import loop
from dace.sdfg.infer_types import infer_connector_types
from dace.sdfg.infer_types import infer_connector_types, infer_aliasing

# Propagate aliasing information into SDFG
infer_aliasing(self, sdfg, state)

# Infer internal connector types
infer_connector_types(self.sdfg)

Expand Down
18 changes: 18 additions & 0 deletions dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,24 @@ def get_last_view_node(state: SDFGState, view: nd.AccessNode) -> nd.AccessNode:
return node


def get_all_view_nodes(state: SDFGState, view: nd.AccessNode) -> List[nd.AccessNode]:
"""
Given a view access node, returns a list of viewed access nodes
if existent, else None
"""
sdfg = state.parent
node = view
desc = sdfg.arrays[node.data]
result = [node]
while isinstance(desc, dt.View):
node = get_view_node(state, node)
if node is None or not isinstance(node, nd.AccessNode):
return None
desc = sdfg.arrays[node.data]
result.append(node)
return result


def get_view_edge(state: SDFGState, view: nd.AccessNode) -> gr.MultiConnectorEdge[mm.Memlet]:
"""
Given a view access node, returns the
Expand Down
70 changes: 70 additions & 0 deletions tests/codegen/alias_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
""" Tests aliasing analysis. """
import pytest
import dace

AliasedArray = dace.data.Array(dace.float64, (20, ), may_alias=True)


@pytest.mark.parametrize('may_alias', (False, True))
def test_simple_program(may_alias):
desc = AliasedArray if may_alias else dace.float64[20]

@dace.program
def tester(a: desc, b: desc, c: desc):
c[:] = a + b

code = tester.to_sdfg().generate_code()[0]

if may_alias:
assert code.clean_code.count('__restrict__') == 0
else:
assert code.clean_code.count('__restrict__') >= 3


def test_multi_nested():

@dace.program
def nested(a: dace.float64[20], b: dace.float64[20]):
b[:] = a + 1

@dace.program
def interim(a: dace.float64[20], b: dace.float64[20]):
nested(a, b)

@dace.program
def tester(a: AliasedArray, b: dace.float64[20]):
interim(a, b)

code = tester.to_sdfg(simplify=False).generate_code()[0]

# Restrict keyword should show up once per aliased array, even if nested programs say otherwise
assert code.clean_code.count('__restrict__') == 4 # = [__program, tester, interim, nested]


def test_inference():

@dace.program
def nested(a: dace.float64[2, 20], b: dace.float64[2, 20]):
b[:] = a + 1

@dace.program
def interim(a: dace.float64[3, 20]):
nested(a[:2], a[1:])

@dace.program
def tester(a: dace.float64[20]):
interim(a)

code = tester.to_sdfg(simplify=False).generate_code()[0]

# Restrict keyword should never show up in "nested", since arrays are aliased,
# but should show up in [__program, tester, interim]
assert code.clean_code.count('__restrict__') == 3


if __name__ == '__main__':
test_simple_program(False)
test_simple_program(True)
test_multi_nested()
test_inference()

0 comments on commit 85843f0

Please sign in to comment.