Skip to content

Commit

Permalink
Added test.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexnick83 committed Sep 14, 2023
1 parent 0bbe5c4 commit 29b269b
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions tests/sdfg/data/structure_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,54 @@ def test_direct_read_structure():
assert np.allclose(B, ref)


def test_direct_read_structure_loops():

M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
name='CSRMatrix')

sdfg = dace.SDFG('csr_to_dense_direct_loops')

sdfg.add_datadesc('A', csr_obj)
sdfg.add_array('B', [M, N], dace.float32)

state = sdfg.add_state()

indices = state.add_access('A.indices')
data = state.add_access('A.data')
B = state.add_access('B')

t = state.add_tasklet('indirection', {'j', '__val'}, {'__out'}, '__out[i, j] = __val')
state.add_edge(indices, None, t, 'j', dace.Memlet(data='A.indices', subset='idx'))
state.add_edge(data, None, t, '__val', dace.Memlet(data='A.data', subset='idx'))
state.add_edge(t, '__out', B, None, dace.Memlet(data='B', subset='0:M, 0:N', volume=1))

idx_before, idx_guard, idx_after = sdfg.add_loop(None, state, None, 'idx', 'A.indptr[i]', 'idx < A.indptr[i+1]', 'idx + 1')
i_before, i_guard, i_after = sdfg.add_loop(None, idx_before, None, 'i', '0', 'i < M', 'i + 1')

sdfg.view()

func = sdfg.compile()

rng = np.random.default_rng(42)
A = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng)
B = np.zeros((20, 20), dtype=np.float32)

inpA = csr_obj.dtype._typeclass.as_ctypes()(indptr=A.indptr.__array_interface__['data'][0],
indices=A.indices.__array_interface__['data'][0],
data=A.data.__array_interface__['data'][0],
rows=A.shape[0],
cols=A.shape[1],
M=A.shape[0],
N=A.shape[1],
nnz=A.nnz)

func(A=inpA, B=B, M=20, N=20, nnz=A.nnz)
ref = A.toarray()

assert np.allclose(B, ref)


def test_direct_read_nested_structure():
M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
Expand Down Expand Up @@ -505,3 +553,4 @@ def test_direct_read_nested_structure():
test_write_nested_structure()
test_direct_read_structure()
test_direct_read_nested_structure()
test_direct_read_structure_loops()

0 comments on commit 29b269b

Please sign in to comment.