Skip to content

Commit

Permalink
Unskip unit tests and provide reasons for skipped tests (#1742)
Browse files Browse the repository at this point in the history
* Brings back 36 tests that were skipped due to prior regressions that
are now fixed.
* Fixes codegen generating non-atomic WCR w.r.t. neighboring edges
* Contains minor modifications that were used to unskip certain tests
(e.g., use of the no-longer-existent `symbol.get()` method)
* Gives valid reasons for all other skipped tests
  • Loading branch information
tbennun authored Nov 11, 2024
1 parent 1b99fe2 commit cb6391f
Show file tree
Hide file tree
Showing 47 changed files with 159 additions and 163 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/general-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:
else
export DACE_optimizer_automatic_simplification=${{ matrix.simplify }}
fi
pytest -n auto --cov-report=xml --cov=dace --tb=short -m "not gpu and not verilator and not tensorflow and not mkl and not sve and not papi and not mlir and not lapack and not fpga and not mpi and not rtl_hardware and not scalapack and not datainstrument"
pytest -n auto --cov-report=xml --cov=dace --tb=short -m "not gpu and not verilator and not tensorflow and not mkl and not sve and not papi and not mlir and not lapack and not fpga and not mpi and not rtl_hardware and not scalapack and not datainstrument and not long"
./codecov
- name: Test OpenBLAS LAPACK
Expand Down
2 changes: 1 addition & 1 deletion dace/codegen/compiled_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
arg_ctypes = tuple(at.dtype.as_ctypes() for at in argtypes)

constants = self.sdfg.constants
callparams = tuple((actype(arg.get()) if isinstance(arg, symbolic.symbol) else arg, actype, atype, aname)
callparams = tuple((arg, actype, atype, aname)
for arg, actype, atype, aname in zip(arglist, arg_ctypes, argtypes, argnames)
if not (symbolic.issymbolic(arg) and (hasattr(arg, 'name') and arg.name in constants)))

Expand Down
42 changes: 34 additions & 8 deletions dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from dace.frontend import operations
from dace.frontend.python import astutils
from dace.frontend.python.astutils import ExtNodeTransformer, rname, unparse
from dace.sdfg import nodes, graph as gr, utils
from dace.sdfg import nodes, graph as gr, utils, propagation
from dace.properties import LambdaProperty
from dace.sdfg import SDFG, is_devicelevel_gpu, SDFGState
from dace.codegen.targets import fpga
Expand Down Expand Up @@ -713,6 +713,31 @@ def _check_map_conflicts(map, edge):
return True


def _check_neighbor_conflicts(dfg, edge):
"""
Checks for other memlets writing to edges that may overlap in subsets.
Returns True if there are no conflicts, False if there may be.
"""
outer = propagation.propagate_memlet(dfg, edge.data, edge.dst, False)
siblings = dfg.in_edges(edge.dst)
for sibling in siblings:
if sibling is edge:
continue
if sibling.data.data != edge.data.data:
continue
# Check if there is definitely no overlap in the propagated memlet
sibling_outer = propagation.propagate_memlet(dfg, sibling.data, edge.dst, False)
if subsets.intersects(outer.subset, sibling_outer.subset) == False:
# In that case, continue
continue

# Other cases are indeterminate and will be atomic
return False
# No overlaps in current scope
return True


def write_conflicted_map_params(map, edge):
result = []
for itervar, (_, _, mapskip) in zip(map.params, map.range):
Expand Down Expand Up @@ -769,6 +794,8 @@ def is_write_conflicted_with_reason(dfg, edge, datanode=None, sdfg_schedule=None
for e in path:
if (isinstance(e.dst, nodes.ExitNode) and (e.dst.map.schedule != dtypes.ScheduleType.Sequential
and e.dst.map.schedule != dtypes.ScheduleType.Snitch)):
if not _check_neighbor_conflicts(dfg, e):
return e.dst
if _check_map_conflicts(e.dst.map, e):
# This map is parallel w.r.t. WCR
# print('PAR: Continuing from map')
Expand Down Expand Up @@ -984,10 +1011,9 @@ def unparse_tasklet(sdfg, cfg, state_id, dfg, node, function_stream, callsite_st
# To prevent variables-redefinition, build dictionary with all the previously defined symbols
defined_symbols = state_dfg.symbols_defined_at(node)

defined_symbols.update({
k: v.dtype if hasattr(v, 'dtype') else dtypes.typeclass(type(v))
for k, v in sdfg.constants.items()
})
defined_symbols.update(
{k: v.dtype if hasattr(v, 'dtype') else dtypes.typeclass(type(v))
for k, v in sdfg.constants.items()})

for connector, (memlet, _, _, conntype) in memlets.items():
if connector is not None:
Expand Down Expand Up @@ -1038,7 +1064,7 @@ def _Name(self, t: ast.Name):
# Replace values with their code-generated names (for example, persistent arrays)
desc = self.sdfg.arrays[t.id]
self.write(ptr(t.id, desc, self.sdfg, self.codegen))

def _Attribute(self, t: ast.Attribute):
from dace.frontend.python.astutils import rname
name = rname(t)
Expand Down Expand Up @@ -1325,8 +1351,8 @@ def visit_BinOp(self, node: ast.BinOp):
evaluated_constant = symbolic.evaluate(unparsed, self.constants)
evaluated = symbolic.symstr(evaluated_constant, cpp_mode=True)
value = ast.parse(evaluated).body[0].value
if isinstance(evaluated_node, numbers.Number) and evaluated_node != (value.value if sys.version_info
>= (3, 8) else value.n):
if isinstance(evaluated_node, numbers.Number) and evaluated_node != (value.value if sys.version_info >=
(3, 8) else value.n):
raise TypeError
node.right = ast.parse(evaluated).body[0].value
except (TypeError, AttributeError, NameError, KeyError, ValueError, SyntaxError):
Expand Down
2 changes: 1 addition & 1 deletion dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,7 +1208,7 @@ def cast(dtype: dt.Data, value: Any):
if isinstance(dtype, dt.Array):
return value
elif isinstance(dtype, dt.Scalar):
return dtype.dtype(value)
return dtype.dtype.type(value)
raise TypeError('Unsupported data type %s' % dtype)

result.update({k: cast(*v) for k, v in self.constants_prop.items()})
Expand Down
4 changes: 4 additions & 0 deletions dace/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,8 @@ def data_dims(self):
for ts in self.tile_sizes))

def offset(self, other, negative, indices=None, offset_end=True):
if other is None:
return
if not isinstance(other, Subset):
if isinstance(other, (list, tuple)):
other = Indices(other)
Expand All @@ -420,6 +422,8 @@ def offset(self, other, negative, indices=None, offset_end=True):
self.ranges[i] = (rb + mult * off[i], re, rs)

def offset_new(self, other, negative, indices=None, offset_end=True):
if other is None:
return Range(self.ranges)
if not isinstance(other, Subset):
if isinstance(other, (list, tuple)):
other = Indices(other)
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ markers =
scalapack: Test requires ScaLAPACK (Intel MKL and OpenMPI). (select with '-m scalapack')
datainstrument: Test uses data instrumentation (select with '-m datainstrument')
hptt: Test requires the HPTT library (select with '-m "hptt')
long: Test runs for a long time and is skipped in CI (select with '-m "long"')
python_files =
*_test.py
*_cudatest.py
Expand Down
6 changes: 3 additions & 3 deletions samples/fpga/jacobi_fpga_systolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,10 @@ def run_jacobi(w: int, h: int, t: int, p: int, specialize_all: bool = False):
print("Specializing H and T...")

jacobi = make_sdfg(specialize_all, h, w, t, p)
jacobi.specialize(dict(W=W, P=P))
jacobi.specialize(dict(W=w, P=p))

if specialize_all:
jacobi.specialize(dict(H=H, T=T))
jacobi.specialize(dict(H=h, T=t))

if t % p != 0:
raise ValueError("Iteration must be divisable by number of processing elements")
Expand All @@ -301,7 +301,7 @@ def run_jacobi(w: int, h: int, t: int, p: int, specialize_all: bool = False):
if specialize_all:
jacobi(A=A)
else:
jacobi(A=A, H=H, T=T)
jacobi(A=A, H=h, T=t)

# Regression
kernel = np.array([[0, 0.2, 0], [0.2, 0.2, 0.2], [0, 0.2, 0]], dtype=np.float32)
Expand Down
2 changes: 1 addition & 1 deletion tests/codegen/allocation_lifetime_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def test_branched_allocation(mode):
sdfg.compile()


@pytest.mark.skip
@pytest.mark.skip('Dynamic array resize is not yet supported')
def test_scope_multisize():
""" An array that needs to be allocated multiple times with different sizes. """
sdfg = dace.SDFG('test')
Expand Down
4 changes: 2 additions & 2 deletions tests/fpga/jacobi_fpga_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

# This kernel does not work with the Intel FPGA codegen, because it uses the
# constant systolic array index in the connector on the nested SDFG.
@pytest.mark.skip
@xilinx_test()
@pytest.mark.skip('Xilinx failure due to unresolved phi nodes, Intel FPGA failure due to systolic array index')
@xilinx_test(assert_ii_1=False)
def test_jacobi_fpga():
jacobi = import_sample(Path("fpga") / "jacobi_fpga_systolic.py")
return jacobi.run_jacobi(64, 512, 16, 4)
Expand Down
5 changes: 3 additions & 2 deletions tests/fpga/map_unroll_processing_elements_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dace.config import set_temporary


@pytest.mark.skip
@pytest.mark.skip('Xilinx HLS fails due to unresolved phi nodes')
@xilinx_test(assert_ii_1=False)
def test_map_unroll_processing_elements():
# Grab the systolic GEMM implementation the samples directory
Expand Down Expand Up @@ -56,7 +56,7 @@ def test_map_unroll_processing_elements():
return sdfg


@pytest.mark.skip
@pytest.mark.skip('Test no longer achieves II=1')
@xilinx_test(assert_ii_1=True)
def test_map_unroll_processing_elements_decoupled():
# Grab the systolic GEMM implementation the samples directory
Expand Down Expand Up @@ -105,3 +105,4 @@ def test_map_unroll_processing_elements_decoupled():

if __name__ == "__main__":
test_map_unroll_processing_elements(None)
test_map_unroll_processing_elements_decoupled(None)
6 changes: 4 additions & 2 deletions tests/fpga/matmul_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def test_gemm_vectorized():
return sdfg


@pytest.mark.skip
@pytest.mark.skip('Xilinx HLS fails due to unresolved phi nodes')
@xilinx_test(assert_ii_1=True)
def test_gemm_vectorized_decoupled():
# Test with vectorization
Expand Down Expand Up @@ -201,7 +201,7 @@ def test_gemm_size_not_multiples_of():
return sdfg


@pytest.mark.skip
@pytest.mark.skip('Xilinx HLS fails due to unresolved phi nodes')
@xilinx_test()
def test_gemm_size_not_multiples_of_decoupled():
# Test with matrix sizes that are not a multiple of #PEs and Tile sizes
Expand Down Expand Up @@ -249,5 +249,7 @@ def matmul_np(A: dace.float64[128, 64], B: dace.float64[64, 32], C: dace.float64
test_naive_matmul_fpga(None)
test_systolic_matmul_fpga(None)
test_gemm_vectorized(None)
test_gemm_vectorized_decoupled(None)
test_gemm_size_not_multiples_of(None)
test_gemm_size_not_multiples_of_decoupled(None)
test_matmul_np(None)
36 changes: 18 additions & 18 deletions tests/fpga/streaming_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def test_streaming_and_composition():
return sdfg


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_vec_add_1():
# Make SDFG
sdfg: dace.SDFG = vecadd_1_streaming.to_sdfg()
Expand Down Expand Up @@ -408,7 +408,7 @@ def test_mem_buffer_vec_add_1():
return sdfg


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_vec_add_1_symbolic():
# Make SDFG
sdfg: dace.SDFG = vecadd_1_streaming_symbol.to_sdfg()
Expand Down Expand Up @@ -495,55 +495,55 @@ def mem_buffer_vec_add_types(dace_type0, dace_type1, dace_type2, np_type0, np_ty
return sdfg


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
# def test_mem_buffer_vec_add_float16():
# return mem_buffer_vec_add_types(dace.float16, dace.float16, dace.float16, np.float16, np.float16, np.float16)
@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_vec_add_float32():
return mem_buffer_vec_add_types(dace.float32, dace.float32, dace.float32, np.float32, np.float32, np.float32)


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_vec_add_float64():
return mem_buffer_vec_add_types(dace.float64, dace.float64, dace.float64, np.float64, np.float64, np.float64)


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_vec_add_int8():
return mem_buffer_vec_add_types(dace.int8, dace.int8, dace.int8, np.int8, np.int8, np.int8)


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_vec_add_int16():
return mem_buffer_vec_add_types(dace.int16, dace.int16, dace.int16, np.int16, np.int16, np.int16)


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_vec_add_int32():
return mem_buffer_vec_add_types(dace.int32, dace.int32, dace.int32, np.int32, np.int32, np.int32)


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_vec_add_int64():
return mem_buffer_vec_add_types(dace.int64, dace.int64, dace.int64, np.int64, np.int64, np.int64)


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_vec_add_complex64():
return mem_buffer_vec_add_types(dace.complex64, dace.complex64, dace.complex64, np.complex64, np.complex64,
np.complex64)


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_vec_add_complex128():
return mem_buffer_vec_add_types(dace.complex128, dace.complex128, dace.complex128, np.complex128, np.complex128,
np.complex128)


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
# def test_mem_buffer_vec_add_mixed_float():
# return mem_buffer_vec_add_types(dace.float16, dace.float32, dace.float64, np.float16, np.float32, np.float64)
@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_vec_add_mixed_int():
return mem_buffer_vec_add_types(dace.int16, dace.int32, dace.int64, np.int16, np.int32, np.int64)

Expand Down Expand Up @@ -575,7 +575,7 @@ def test_mem_buffer_mat_add():
return sdfg


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_mat_add_symbol():
# Make SDFG
sdfg: dace.SDFG = matadd_streaming_symbol.to_sdfg()
Expand All @@ -602,7 +602,7 @@ def test_mem_buffer_mat_add_symbol():
return sdfg


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_tensor_add():
# Make SDFG
sdfg: dace.SDFG = tensoradd_streaming.to_sdfg()
Expand Down Expand Up @@ -688,7 +688,7 @@ def test_mem_buffer_multistream_with_deps():
return sdfg


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_mat_mul():
# Make SDFG
sdfg: dace.SDFG = matmul_streaming.to_sdfg()
Expand Down Expand Up @@ -799,7 +799,7 @@ def test_mem_buffer_not_applicable():
return []


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_atax():

A = np.random.rand(M, N).astype(np.float32)
Expand Down Expand Up @@ -843,7 +843,7 @@ def test_mem_buffer_atax():
return sdfg


@pytest.mark.skip(reason="Save time")
@pytest.mark.long
def test_mem_buffer_bicg():

A = np.random.rand(N, M).astype(np.float32)
Expand Down
6 changes: 3 additions & 3 deletions tests/fpga/vec_sum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,16 @@ def test_vec_sum_vectorize_first_decoupled_interfaces():
return run_vec_sum(True)


@pytest.mark.skip
@xilinx_test(assert_ii_1=True)
def test_vec_sum_fpga_transform_first_decoupled_interfaces():
# For this test, decoupled read/write interfaces are needed to achieve II=1
with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True):
return run_vec_sum(True)
with set_temporary('testing', 'serialization', value=False):
return run_vec_sum(True)


if __name__ == "__main__":
test_vec_sum_vectorize_first(None)
test_vec_sum_fpga_transform_first(None)

test_vec_sum_fpga_transform_first_decoupled_interfaces(None)

4 changes: 2 additions & 2 deletions tests/inlining_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test():
myprogram.compile(dace.float32[W, H], dace.float32[H, W], dace.int32)


@pytest.mark.skip
@pytest.mark.skip('CI failure that cannot be reproduced outside CI')
def test_regression_reshape_unsqueeze():
nsdfg = dace.SDFG("nested_reshape_node")
nstate = nsdfg.add_state()
Expand Down Expand Up @@ -456,7 +456,7 @@ def test(A: dace.float64[96, 32], B: dace.float64[42, 32]):

if __name__ == "__main__":
test()
# Skipped to to bug that cannot be reproduced
# Skipped due to bug that cannot be reproduced outside CI
# test_regression_reshape_unsqueeze()
test_empty_memlets()
test_multistate_inline()
Expand Down
Loading

0 comments on commit cb6391f

Please sign in to comment.