Skip to content

Commit

Permalink
More codegen fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ThrudPrimrose committed Dec 3, 2024
1 parent 82cdfde commit 97bc728
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 10 deletions.
9 changes: 5 additions & 4 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,8 +747,8 @@ def validate_state(state: 'dace.sdfg.SDFGState',
if isinstance(dst_node, nd.AccessNode) and isinstance(sdfg.arrays[dst_node.data], dt.Structure):
name = None
# Special case: if the name is the size array of the src_node, then it is ok, checked with the "size_desc_name"
src_size_access = isinstance(src_node, nd.AccessNode) and name == sdfg.arrays[src_node.data].size_desc_name
dst_size_access = isinstance(dst_node, nd.AccessNode) and name == sdfg.arrays[dst_node.data].size_desc_name
src_size_access = isinstance(src_node, nd.AccessNode) and isinstance(sdfg.arrays[src_node.data], dt.Array) and name is not None and name == sdfg.arrays[src_node.data].size_desc_name
dst_size_access = isinstance(dst_node, nd.AccessNode) and isinstance(sdfg.arrays[dst_node.data], dt.Array) and name is not None and name == sdfg.arrays[dst_node.data].size_desc_name
sdict = state.scope_dict()
if src_size_access and dst_size_access:
raise InvalidSDFGEdgeError(
Expand All @@ -766,9 +766,10 @@ def validate_state(state: 'dace.sdfg.SDFGState',
)
if dst_size_access:
dst_arr = sdfg.arrays[dst_node.data]
if dst_arr.storage != dace.dtypes.StorageType.GPU_Global or dst_arr.storage != dace.dtypes.StorageType.CPU_Heap:
if (dst_arr.storage != dtypes.StorageType.GPU_Global and
dst_arr.storage != dtypes.StorageType.CPU_Heap):
raise InvalidSDFGEdgeError(
"Reallocating data (writing to the size connector) within a scope is not valid",
f"Reallocating data is allowed only to GPU_Global or CPU_Heap, the storage type is {dst_arr.storage}",
sdfg,
state_id,
eid,
Expand Down
34 changes: 28 additions & 6 deletions tests/deferred_alloc_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
import dace
import numpy
import cupy
import pytest

@pytest.fixture(params=[dace.dtypes.StorageType.CPU_Heap, dace.dtypes.StorageType.GPU_Global])
def storage_type(request):
return request.param

@pytest.fixture(params=[True, False])
def transient(request):
return request.param

@pytest.fixture
def schedule_type(storage_type):
if storage_type == dace.dtypes.StorageType.CPU_Heap:
return dace.dtypes.ScheduleType.Sequential
elif storage_type == dace.dtypes.StorageType.GPU_Global:
return dace.dtypes.ScheduleType.GPU_Device

def _get_trivial_alloc_sdfg(storage_type: dace.dtypes.StorageType, transient: bool, write_size="0:2"):
sdfg = dace.sdfg.SDFG(name="deferred_alloc_test")
Expand Down Expand Up @@ -126,26 +142,32 @@ def test_realloc_use(storage_type: dace.dtypes.StorageType, transient: bool, sch
assert ( arr.get()[0] == 3.0 )


def test_realloc_inside_map():
pass


def test_all_combinations(storage_type, transient, schedule_type):
test_trivial_realloc(storage_type, transient)
test_realloc_use(storage_type, transient, schedule_type)

def test_incomplete_write_dimensions_1():
sdfg = _get_trivial_alloc_sdfg(dace.dtypes.StorageType.CPU_Heap, True, "1:2")
sdfg = _get_trivial_alloc_sdfg(dace.dtypes.StorageType.CPU_Heap, True, "1:2")
try:
sdfg.validate()
except Exception:
return

raise AssertionError("Realloc-use with transient data and incomplete write did not fail when it was expected to.")
pytest.fail("Realloc-use with transient data and incomplete write did not fail when it was expected to.")

def test_incomplete_write_dimensions_2():
sdfg = _get_trivial_alloc_sdfg(dace.dtypes.StorageType.CPU_Heap, False, "1:2")
sdfg = _get_trivial_alloc_sdfg(dace.dtypes.StorageType.CPU_Heap, False, "1:2")
try:
sdfg.validate()
except Exception:
return

raise AssertionError("Realloc-use with non-transient data and incomplete write did not fail when it was expected to.")
pytest.fail("Realloc-use with non-transient data and incomplete write did not fail when it was expected to.")

def test_realloc_inside_map():
pass

if __name__ == "__main__":
for storage_type, schedule_type in [(dace.dtypes.StorageType.CPU_Heap, dace.dtypes.ScheduleType.Sequential),
Expand Down

0 comments on commit 97bc728

Please sign in to comment.