From 84ca3d2afff304e67bddc37cbb916828da31becc Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 5 Sep 2023 22:03:02 -0400 Subject: [PATCH] [BugFix][Arith] IterMapRewriter abort rewriting once failure This PR fixes an issue of the IterMapRewriter. Prior to this PR, the mutation function of the rewriter class returns the mutation results even when an invalid PrimExpr pattern was detected. Returning the mutation results when failure is not expected, and in such cases we should "abort" the mutation, and return the input PrimExpr which is not mutated, since insisting on returning the mutation results sometimes it incurs further error on other arith components like simplification. One unit test is added to ensure the rewriter behaves as expectation. --- src/arith/iter_affine_map.cc | 1 + .../unittest/test_arith_iter_affine_map.py | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index af1128aa27..57b2b4aad9 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -322,6 +322,7 @@ class IterMapRewriter : public ExprMutator { ErrorLogger(this) << "IterMapExpr or subclasses should only result from calls in " << "IterMapRewriter using DirectMutate. " << "Indirect return occurred in " << input_expr; + return input_expr; } return expr; } diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index cee9922e86..22043cded2 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -1285,5 +1285,25 @@ def test_normalize_to_iter_sum(): ) +def test_detect_iter_map_with_bufferload_recursion(): + n = tvm.tir.Var("n", "int32") + m = tvm.tir.Var("m", "int32") + divisor = tvm.tir.Var("divisor", "int32") + + i = tvm.tir.Var("i", "int32") + j = tvm.tir.Var("j", "int32") + + buffer = tvm.tir.decl_buffer((n,), "int32", name="seqlen") + + indices = [(buffer[i] + j) // divisor] + iter_vars = { + i: tvm.ir.Range(tvm.tir.const(0, "int32"), n), + j: tvm.ir.Range(tvm.tir.const(0, "int32"), m), + } + + result = tvm.arith.detect_iter_map(indices, iter_vars) + assert len(result.indices) == 0 + + if __name__ == "__main__": tvm.testing.main()