Passes Interface Improvements #1124

Merged
14 commits merged on Oct 24, 2022
18 changes: 9 additions & 9 deletions dace/libraries/stencil/intel_fpga.py
@@ -88,15 +88,15 @@ def expansion(node, parent_state, parent_sdfg):

# Manually add pipeline entry and exit nodes
pipeline_range = dace.properties.SubsetProperty.from_string(', '.join(iterators.values()))
pipeline = dace.sdfg.nodes.Pipeline("compute_" + node.label,
list(iterators.keys()),
pipeline_range,
dace.dtypes.ScheduleType.FPGA_Device,
False,
init_size=init_size_max,
init_overlap=False,
drain_size=init_size_max,
drain_overlap=True)
pipeline = dace.sdfg.nodes.PipelineScope("compute_" + node.label,
list(iterators.keys()),
pipeline_range,
dace.dtypes.ScheduleType.FPGA_Device,
False,
init_size=init_size_max,
init_overlap=False,
drain_size=init_size_max,
drain_overlap=True)
entry = dace.sdfg.nodes.PipelineEntry(pipeline)
exit = dace.sdfg.nodes.PipelineExit(pipeline)
state.add_nodes_from([entry, exit])
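
For orientation, the pattern in this hunk — constructing the scope object directly and wrapping it in PipelineEntry/PipelineExit nodes — looks roughly like the following standalone sketch with the renamed class; the SDFG, state, label, and range below are placeholders, not part of the stencil library code.

import dace
from dace.sdfg import nodes as nd

# Minimal sketch: manually building a pipeline scope with the post-rename API.
sdfg = dace.SDFG("manual_pipeline_sketch")
state = sdfg.add_state("compute")

rng = dace.properties.SubsetProperty.from_string("0:64")
pipe = nd.PipelineScope("compute_example",
                        ["i"],                                 # iterator names
                        rng,                                   # iteration range
                        dace.dtypes.ScheduleType.FPGA_Device,  # schedule
                        False,                                 # unroll flag (positional, as above)
                        init_size=2,
                        init_overlap=False,
                        drain_size=2,
                        drain_overlap=True)
entry, exit_node = nd.PipelineEntry(pipe), nd.PipelineExit(pipe)
state.add_nodes_from([entry, exit_node])
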
10 changes: 5 additions & 5 deletions dace/sdfg/nodes.py
@@ -1080,7 +1080,7 @@ class PipelineEntry(MapEntry):

@staticmethod
def map_type():
return Pipeline
return PipelineScope

@property
def pipeline(self):
@@ -1113,7 +1113,7 @@ class PipelineExit(MapExit):

@staticmethod
def map_type():
return Pipeline
return PipelineScope

@property
def pipeline(self):
@@ -1125,7 +1125,7 @@ def pipeline(self, val):


@make_properties
class Pipeline(Map):
class PipelineScope(Map):
""" This a convenience-subclass of Map that allows easier implementation of
loop nests (using regular Map indices) that need a constant-sized
initialization and drain phase (e.g., N*M + c iterations), which would
@@ -1149,7 +1149,7 @@ def __init__(self,
drain_overlap=False,
additional_iterators={},
**kwargs):
super(Pipeline, self).__init__(*args, **kwargs)
super(PipelineScope, self).__init__(*args, **kwargs)
self.init_size = init_size
self.init_overlap = init_overlap
self.drain_size = drain_size
@@ -1184,7 +1184,7 @@ def drain_condition(self):
return self.iterator_str() + "_drain"


PipelineEntry = indirect_properties(Pipeline, lambda obj: obj.map)(PipelineEntry)
PipelineEntry = indirect_properties(PipelineScope, lambda obj: obj.map)(PipelineEntry)

# ------------------------------------------------------------------------------

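
Because this is a straight rename of the Map subclass (PipelineEntry and PipelineExit keep their names), the main effect on code outside this module is in type checks and direct references. A small, hypothetical illustration of the kind of check that changes — the helper below is illustrative only, not part of the PR:

from dace.sdfg import nodes as nd

def is_pipeline_entry(node) -> bool:
    # Before this PR the scope class was nd.Pipeline; afterwards it is nd.PipelineScope.
    # The entry/exit node classes are unchanged, so isinstance checks on them still work.
    return isinstance(node, nd.PipelineEntry) and isinstance(node.map, nd.PipelineScope)
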
20 changes: 10 additions & 10 deletions dace/sdfg/state.py
@@ -1427,16 +1427,16 @@ def add_pipeline(self,
:return: (map_entry, map_exit) node 2-tuple
"""
debuginfo = _getdebuginfo(debuginfo or self._default_lineinfo)
pipeline = nd.Pipeline(name,
*_make_iterators(ndrange),
init_size=init_size,
init_overlap=init_overlap,
drain_size=drain_size,
drain_overlap=drain_overlap,
additional_iterators=additional_iterators,
schedule=schedule,
debuginfo=debuginfo,
**kwargs)
pipeline = nd.PipelineScope(name,
*_make_iterators(ndrange),
init_size=init_size,
init_overlap=init_overlap,
drain_size=drain_size,
drain_overlap=drain_overlap,
additional_iterators=additional_iterators,
schedule=schedule,
debuginfo=debuginfo,
**kwargs)
pipeline_entry = nd.PipelineEntry(pipeline)
pipeline_exit = nd.PipelineExit(pipeline)
self.add_nodes_from([pipeline_entry, pipeline_exit])
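
The add_pipeline convenience method updated in this hunk remains the usual way to create such a scope from user code. A minimal usage sketch, assuming a one-dimensional iteration range; the names and sizes are illustrative:

import dace

sdfg = dace.SDFG("pipeline_via_api")
state = sdfg.add_state("compute")

# add_pipeline mirrors add_map but additionally records init/drain phase sizes,
# which are stored on the (renamed) PipelineScope object.
entry, exit_node = state.add_pipeline("compute_pipe",
                                      {"i": "0:64"},
                                      init_size=2,
                                      init_overlap=False,
                                      drain_size=2,
                                      drain_overlap=True,
                                      schedule=dace.dtypes.ScheduleType.FPGA_Device)
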
129 changes: 0 additions & 129 deletions dace/transformation/interstate/state_elimination.py
@@ -219,92 +219,6 @@ def _str_repl(s, d):
symbolic.safe_replace(repl_dict, lambda m: _str_repl(sdfg, m))


class ConstantPropagation(transformation.MultiStateTransformation):
"""
Removes constant assignments in interstate edges and replaces them in successor states.
"""

end_state = transformation.PatternNode(sdfg.SDFGState)

@classmethod
def expressions(cls):
return [sdutil.node_path_graph(cls.end_state)]

def can_be_applied(self, graph, expr_index, sdfg: SDFG, permissive=False):
state = self.end_state

out_edges = graph.out_edges(state)
in_edges = graph.in_edges(state)

# We only match states with one source and at least one assignment
if len(in_edges) != 1:
return False
edge = in_edges[0]
assignments_to_consider = _assignments_to_consider(sdfg, edge, True)

# No assignments to eliminate
if len(assignments_to_consider) == 0:
return False

# If this is an end state, there are no other edges to consider
if len(out_edges) == 0:
return True

# Otherwise, ensure the symbols are never set/used again in edges
akeys = set(assignments_to_consider.keys())
for e in sdfg.bfs_edges(state):
if e is edge:
continue
if e.data.assignments.keys() & akeys:
return False

return True

def apply(self, _, sdfg: SDFG):
state = self.end_state
edge = sdfg.in_edges(state)[0]
# Since inter-state assignments that use an assigned value leads to
# undefined behavior (e.g., {m: n, n: m}), we can replace each
# assignment separately.
assignments_to_consider = _assignments_to_consider(sdfg, edge, True)

def _str_repl(s, d, **kwargs):
for k, v in d.items():
s.replace(str(k), str(v), **kwargs)

# Replace in state, and all successors
symbolic.safe_replace(assignments_to_consider, lambda m: _str_repl(state, m))
visited = {edge}
for isedge in sdfg.bfs_edges(state):
if isedge not in visited:
symbolic.safe_replace(assignments_to_consider, lambda m: _str_repl(isedge.data, m, replace_keys=False))
visited.add(isedge)
if isedge.dst not in visited:
symbolic.safe_replace(assignments_to_consider, lambda m: _str_repl(isedge.dst, m))
visited.add(isedge.dst)

repl_dict = {}

for varname in assignments_to_consider.keys():
# Remove assignments from edge
del edge.data.assignments[varname]

for e in sdfg.edges():
if varname in e.data.free_symbols:
break
else:
# If removed assignment does not appear in any other edge,
# replace and remove symbol
if varname in sdfg.symbols:
sdfg.remove_symbol(varname)
# if assignments_to_consider[varname] in sdfg.symbols:
if varname in sdfg.free_symbols:
repl_dict[varname] = assignments_to_consider[varname]

if repl_dict:
symbolic.safe_replace(repl_dict, lambda m: _str_repl(sdfg, m))


def _alias_assignments(sdfg, edge):
assignments_to_consider = {}
for var, assign in edge.assignments.items():
@@ -570,49 +484,6 @@ def replfunc(m):
nsdfg.sdfg.start_state = nsdfg.sdfg.node_id(nisedge.dst)


class DeadStateElimination(transformation.MultiStateTransformation):
"""
Dead state elimination removes an unreachable state and all of its dominated
states.
"""

end_state = transformation.PatternNode(sdfg.SDFGState)

@classmethod
def expressions(cls):
return [sdutil.node_path_graph(cls.end_state)]

def can_be_applied(self, graph: SDFG, expr_index, sdfg: SDFG, permissive=False):
state: SDFGState = self.end_state
in_edges = graph.in_edges(state)

# We only match end states with one source and at least one assignment
if len(in_edges) != 1:
return False
edge = in_edges[0]

if edge.data.assignments:
return False
if edge.data.is_unconditional():
return False

# Evaluate condition
scond = edge.data.condition_sympy()
if scond == False:
return True

return False

def apply(self, _, sdfg: SDFG):
# Remove state and all dominated states
state = self.end_state

domset = cfg.all_dominators(sdfg)
states_to_remove = {k for k, v in domset.items() if state in v}
states_to_remove.add(state)
sdfg.remove_nodes_from(states_to_remove)


class TrueConditionElimination(transformation.MultiStateTransformation):
"""
If a state transition condition is always true, removes condition from edge.
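
The ConstantPropagation and DeadStateElimination classes deleted above were pattern-matching multi-state transformations. The sketch below shows how the equivalent cleanup is typically invoked through DaCe's pass interface; it assumes pass-based counterparts with the same names exist under dace.transformation.passes and accept the generic apply_pass(sdfg, pipeline_results) entry point — neither is shown in this diff.

import dace
# Hedged sketch: the module paths below are assumptions, not part of the diff.
from dace.transformation.passes.constant_propagation import ConstantPropagation
from dace.transformation.passes.dead_state_elimination import DeadStateElimination

sdfg = dace.SDFG("example")  # placeholder SDFG
ConstantPropagation().apply_pass(sdfg, {})
DeadStateElimination().apply_pass(sdfg, {})
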