Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Symbol redeclaration fix #1788

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
14 changes: 10 additions & 4 deletions dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def dispatcher(self):
def preprocess(self, sdfg: SDFG) -> None:
"""
Called before code generation. Used for making modifications on the SDFG prior to code generation.

:note: Post-conditions assume that the SDFG will NOT be changed after this point.
:param sdfg: The SDFG to modify in-place.
"""
Expand Down Expand Up @@ -896,6 +896,8 @@ def generate_code(self,
# Allocate outer-level transients
self.allocate_arrays_in_scope(sdfg, sdfg, sdfg, global_stream, callsite_stream)

outside_symbols = sdfg.arglist() if is_top_level else set()

# Define constants as top-level-allocated
for cname, (ctype, _) in sdfg.constants_prop.items():
if isinstance(ctype, data.Array):
Expand Down Expand Up @@ -951,10 +953,14 @@ def generate_code(self,
and config.Config.get('compiler', 'fpga', 'vendor').lower() == 'intel_fpga'):
# Emit OpenCL type
callsite_stream.write(f'{isvarType.ocltype} {isvarName};\n', sdfg)
self.dispatcher.defined_vars.add(isvarName, disp.DefinedType.Scalar, isvarType.ctype)
else:
callsite_stream.write('%s;\n' % (isvar.as_arg(with_types=True, name=isvarName)), sdfg)
self.dispatcher.defined_vars.add(isvarName, disp.DefinedType.Scalar, isvarType.ctype)

# If the variable is passed as an input argument to the SDFG, do not need to declare it
if isvarName not in outside_symbols:
callsite_stream.write('%s;\n' % (isvar.as_arg(with_types=True, name=isvarName)), sdfg)
self.dispatcher.defined_vars.add(isvarName, disp.DefinedType.Scalar, isvarType.ctype)
else:
callsite_stream.write('//%s;\n' % (isvar.as_arg(with_types=True, name=isvarName)), sdfg)
callsite_stream.write('\n', sdfg)

#######################################################################
Expand Down
5 changes: 3 additions & 2 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def validate_control_flow_region(sdfg: 'SDFG',

def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context: bool):
""" Verifies the correctness of an SDFG by applying multiple tests.

:param sdfg: The SDFG to verify.
:param references: An optional set keeping seen IDs for object
miscopy validation.
Expand Down Expand Up @@ -313,14 +313,15 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context
for sym in desc.free_symbols:
symbols[str(sym)] = sym.dtype
validate_control_flow_region(sdfg, sdfg, initialized_transients, symbols, references, **context)


except InvalidSDFGError as ex:
# If the SDFG is invalid, save it
fpath = os.path.join('_dacegraphs', 'invalid.sdfgz')
sdfg.save(fpath, exception=ex, compress=True)
ex.path = fpath
raise


def _accessible(sdfg: 'dace.sdfg.SDFG', container: str, context: Dict[str, bool]):
"""
Helper function that returns False if a data container cannot be accessed in the current SDFG context.
Expand Down
52 changes: 52 additions & 0 deletions tests/interstate_assignment_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Dict
import dace

N = dace.symbol("N")

def _get_interstate_dependent_sdfg(assignments: Dict, symbols_at_start=False):
sdfg = dace.SDFG("interstate_dependent")
for k in assignments:
sdfg.add_symbol(k, dace.int32)

s1 = sdfg.add_state("s1")
s2 = sdfg.add_state("s2")

if not symbols_at_start:
s0 = sdfg.add_state("s0")
pre_assignments = dict()
for k,v in assignments.items():
pre_assignments[k] = v*2
sdfg.add_edge(s0, s1, dace.InterstateEdge(None, assignments=pre_assignments))

for sid, s in [("1", s1), ("2", s2)]:
sdfg.add_array(f"array{sid}", (N, ) , dace.int32, storage=dace.StorageType.CPU_Heap, transient=True)
an = s.add_access(f"array{sid}")
an2 = s.add_access(f"array{sid}")
t = s.add_tasklet(f"tasklet{sid}", {"_in"}, {"_out"}, "_out = _in * 2")
map_entry, map_exit = s.add_map(f"map{sid}", {"i":dace.subsets.Range([(0,N-1,1)])})
for m in [map_entry, map_exit]:
m.add_in_connector(f"IN_array{sid}")
m.add_out_connector(f"OUT_array{sid}")
s.add_edge(an, None, map_entry, f"IN_array{sid}", dace.memlet.Memlet(f"array{sid}[0:N]"))
s.add_edge(map_entry, f"OUT_array{sid}", t, "_in", dace.memlet.Memlet(f"array{sid}[i]"))
s.add_edge(t, "_out", map_exit, f"IN_array{sid}", dace.memlet.Memlet(f"array{sid}[i]"))
s.add_edge(map_exit, f"OUT_array{sid}", an2, None, dace.memlet.Memlet(f"array{sid}[0:N]"))

sdfg.add_edge(s1, s2, dace.InterstateEdge(None, assignments=assignments))
sdfg.save("s1.sdfg")
sdfg.validate()
return sdfg

def test_interstate_assignment():
sdfg = _get_interstate_dependent_sdfg({"N": 5}, False)
sdfg.validate()
sdfg()

def test_interstate_assignment_on_sdfg_input():
sdfg = _get_interstate_dependent_sdfg({"N": 5}, True)
sdfg.validate()
sdfg(N=10)

if __name__ == "__main__":
test_interstate_assignment()
test_interstate_assignment_on_sdfg_input()
Loading