diff --git a/tests/sdfg/data/structure_test.py b/tests/sdfg/data/structure_test.py index 02b8f0c174..fa22420d53 100644 --- a/tests/sdfg/data/structure_test.py +++ b/tests/sdfg/data/structure_test.py @@ -443,6 +443,54 @@ def test_direct_read_structure(): assert np.allclose(B, ref) +def test_direct_read_structure_loops(): + + M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) + csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + name='CSRMatrix') + + sdfg = dace.SDFG('csr_to_dense_direct_loops') + + sdfg.add_datadesc('A', csr_obj) + sdfg.add_array('B', [M, N], dace.float32) + + state = sdfg.add_state() + + indices = state.add_access('A.indices') + data = state.add_access('A.data') + B = state.add_access('B') + + t = state.add_tasklet('indirection', {'j', '__val'}, {'__out'}, '__out[i, j] = __val') + state.add_edge(indices, None, t, 'j', dace.Memlet(data='A.indices', subset='idx')) + state.add_edge(data, None, t, '__val', dace.Memlet(data='A.data', subset='idx')) + state.add_edge(t, '__out', B, None, dace.Memlet(data='B', subset='0:M, 0:N', volume=1)) + + idx_before, idx_guard, idx_after = sdfg.add_loop(None, state, None, 'idx', 'A.indptr[i]', 'idx < A.indptr[i+1]', 'idx + 1') + i_before, i_guard, i_after = sdfg.add_loop(None, idx_before, None, 'i', '0', 'i < M', 'i + 1') + + sdfg.view() + + func = sdfg.compile() + + rng = np.random.default_rng(42) + A = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng) + B = np.zeros((20, 20), dtype=np.float32) + + inpA = csr_obj.dtype._typeclass.as_ctypes()(indptr=A.indptr.__array_interface__['data'][0], + indices=A.indices.__array_interface__['data'][0], + data=A.data.__array_interface__['data'][0], + rows=A.shape[0], + cols=A.shape[1], + M=A.shape[0], + N=A.shape[1], + nnz=A.nnz) + + func(A=inpA, B=B, M=20, N=20, nnz=A.nnz) + ref = A.toarray() + + assert np.allclose(B, ref) + + def test_direct_read_nested_structure(): M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), @@ -505,3 +553,4 @@ def test_direct_read_nested_structure(): test_write_nested_structure() test_direct_read_structure() test_direct_read_nested_structure() + test_direct_read_structure_loops()