
Commit

Fix TIR TVMScript to adapt apache/tvm#15214
MasterJH5574 committed Aug 1, 2023
1 parent 7a603c9 commit e1e5333
Showing 19 changed files with 314 additions and 314 deletions.
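
Every hunk below makes the same mechanical change: each test-local TIR function switches from the bare @T.prim_func decorator to @T.prim_func(private=True). As a stand-alone sketch of the new decorator form (assuming a TVM build that already includes apache/tvm#15214; private=True is expected to keep the TVMScript parser from attaching a "global_symbol" attribute to the parsed PrimFunc, and add_one below is an illustrative name, not a function from this commit):

from tvm.script import tir as T


@T.prim_func(private=True)
def add_one(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")):
    # Trivial elementwise body, just enough to make the function parse.
    for i in range(10):
        with T.block("add"):
            vi = T.axis.spatial(10, i)
            B[vi] = A[vi] + T.float32(1)


if __name__ == "__main__":
    # With private=True the function should carry no "global_symbol" attribute
    # (behavior inferred from the commit title; verify against your TVM
    # version). Without the flag, the parser attaches global_symbol="add_one".
    print(add_one.attrs)

Presumably this keeps the tests' before/expected functions free of global symbols, so the structural comparisons in the test suite behave as they did before the upstream change.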
70 changes: 35 additions & 35 deletions tests/python/relax/test_analysis_suggest_layout_transforms.py
@@ -42,7 +42,7 @@ def apply_transformations(func, suggested_transfoms, print_transformation=False)


def test_nested_blocks():
-@T.prim_func
+@T.prim_func(private=True)
def nested_block(
arg: T.Buffer((32, 64, 224, 224), "float32"),
relu: T.Buffer((32, 64, 224, 224), "float32"),
@@ -67,7 +67,7 @@ def nested_block(


def test_mismatch_transformations_and_num_params():
-@T.prim_func
+@T.prim_func(private=True)
def elemwise(
arg: T.Buffer((32, 64, 224, 224), "float32"),
relu: T.Buffer((32, 64, 224, 224), "float32"),
@@ -91,7 +91,7 @@ def elemwise(


def test_empty_write_transformations():
-@T.prim_func
+@T.prim_func(private=True)
def elemwise(
arg: T.Buffer((32, 64, 224, 224), "float32"),
relu: T.Buffer((32, 64, 224, 224), "float32"),
@@ -110,7 +110,7 @@ def elemwise(


def test_non_bijective_block_transform():
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg: T.Buffer((32, 64), "float32"),
output: T.Buffer((32, 64), "float32"),
@@ -129,7 +129,7 @@ def before(


def test_non_affine_access():
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg: T.Buffer((32, 64), "float32"),
output: T.Buffer((32 * 64, 10), "float32"),
@@ -148,7 +148,7 @@ def before(


def test_unsupported_write_spatial_layout():
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg: T.Buffer((4, 4), "float32"),
output: T.Buffer((16), "float32"),
@@ -167,7 +167,7 @@ def before(


def test_unpacked_iter_used_in_read_access():
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg: T.Buffer((8, 4), "float32"),
output: T.Buffer((4, 8), "float32"),
@@ -179,7 +179,7 @@ def before(
T.writes(output[v_ax0, v_ax1])
output[v_ax0, v_ax1] = arg[v_ax1, v_ax2]

-@T.prim_func
+@T.prim_func(private=True)
def expected(
arg: T.Buffer((8, 4), "float32"),
output: T.Buffer((32), "float32"),
@@ -199,7 +199,7 @@ def expected(


def test_invalid_index_map():
-@T.prim_func
+@T.prim_func(private=True)
def elemwise(
arg: T.Buffer((32, 64, 224, 224), "float32"),
relu: T.Buffer((32, 64, 224, 224), "float32"),
@@ -220,7 +220,7 @@ def elemwise(


def test_SRSR_block():
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg: T.Buffer((32, 224, 64, 224), "float32"),
sum: T.Buffer((32, 64), "float32"),
@@ -234,7 +234,7 @@ def before(
sum[v_ax0, v_ax1] = T.float32(0)
sum[v_ax0, v_ax1] = sum[v_ax0, v_ax1] + arg[v_ax0, v_k2, v_ax1, v_k3]

-@T.prim_func
+@T.prim_func(private=True)
def expected(
arg: T.Buffer((32, 224, 16, 224, 4), "float32"),
sum: T.Buffer((32, 16, 4), "float32"),
@@ -256,7 +256,7 @@ def expected(


def test_op_elemwise_symbolic():
-@T.prim_func
+@T.prim_func(private=True)
def before(arg: T.handle, relu: T.handle):
N = T.int64()
C = T.int64()
@@ -271,7 +271,7 @@ def before(arg: T.handle, relu: T.handle):
T.writes(Relu[v_i0, v_i1, v_i2, v_i3])
Relu[v_i0, v_i1, v_i2, v_i3] = T.max(Arg[v_i0, v_i1, v_i2, v_i3], T.float32(0))

-@T.prim_func
+@T.prim_func(private=True)
def expected(arg: T.handle, relu: T.handle):
N = T.int64()
C = T.int64()
@@ -295,7 +295,7 @@ def expected(arg: T.handle, relu: T.handle):


def test_op_elemwise():
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg: T.Buffer((32, 64, 224, 224), "float32"),
relu: T.Buffer((32, 64, 224, 224), "float32"),
@@ -307,7 +307,7 @@ def before(
T.writes(relu[v_i0, v_i1, v_i2, v_i3])
relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0))

-@T.prim_func
+@T.prim_func(private=True)
def expected(
arg: T.Buffer((32, 224, 224, 64), "float32"),
relu: T.Buffer((32, 224, 224, 64), "float32"),
@@ -327,7 +327,7 @@ def expected(


def test_op_pool_nchw_nhwc():
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg: T.Buffer((32, 64, 224, 224), "float32"),
pool_max: T.Buffer((32, 64, 111, 223), "float32"),
@@ -359,7 +359,7 @@ def before(
],
)

-@T.prim_func
+@T.prim_func(private=True)
def expected(
arg: T.Buffer((32, 224, 224, 64), "float32"),
pool_max: T.Buffer((32, 111, 223, 64), "float32"),
@@ -387,7 +387,7 @@ def expected(


def test_op_pool_nchw16c_nhwc():
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg: T.Buffer(
(32, 4, 224, 224, 16),
@@ -413,7 +413,7 @@ def before(
arg[v_ax0, v_ax1, v_ax2 * 2 + v_rv0, v_ax3 + v_rv1, v_ax4],
)

-@T.prim_func
+@T.prim_func(private=True)
def expected(
arg: T.Buffer((32, 224, 224, 64), "float32"),
pool_max: T.Buffer((32, 110, 220, 64), "float32"),
@@ -440,7 +440,7 @@ def expected(


def test_op_reduce():
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg: T.Buffer((32, 64, 224, 224), "float32"),
sum: T.Buffer((32, 64), "float32"),
@@ -454,7 +454,7 @@ def before(
sum[v_ax0, v_ax1] = T.float32(0)
sum[v_ax0, v_ax1] = sum[v_ax0, v_ax1] + arg[v_ax0, v_ax1, v_k2, v_k3]

-@T.prim_func
+@T.prim_func(private=True)
def expected(
arg: T.Buffer((32, 4, 224, 224, 16), "float32"),
sum: T.Buffer((32, 4, 16), "float32"),
@@ -477,7 +477,7 @@ def expected(

def test_op_upsampling():
# relay materializes the layout if H, W or D dimensions are moved or tiled.
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg: T.Buffer((32, 64, 224, 224), "float32"),
resize: T.Buffer((32, 64, 202, 246), "float32"),
@@ -518,7 +518,7 @@ def before(
),
]

-@T.prim_func
+@T.prim_func(private=True)
def expected(
arg: T.Buffer((32, 64, 224, 224), "float32"),
resize: T.Buffer((32, 202, 246, 64), "float32"),
@@ -568,7 +568,7 @@ def expected(


def test_op_strided_slice():
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg: T.Buffer((32, 64, 224, 224), "float32"),
T_strided_slice_with_axes: T.Buffer((32, 64, 10, 8), "float32"),
@@ -592,7 +592,7 @@ def before(
v_ax3 * 7 + 4,
]

-@T.prim_func
+@T.prim_func(private=True)
def expected(
arg: T.Buffer((32, 224, 224, 16, 4), "float32"),
T_strided_slice_with_axes: T.Buffer((32, 10, 8, 16, 4), "float32"),
@@ -615,7 +615,7 @@ def expected(


def test_op_binary_broadcast():
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg0: T.Buffer((32, 64, 224, 224), "float32"),
arg1: T.Buffer((64, 224, 224), "float32"),
@@ -635,7 +635,7 @@ def before(
arg0[v_ax0, v_ax1, v_ax2, v_ax3] + arg1[v_ax1, v_ax2, v_ax3]
)

-@T.prim_func
+@T.prim_func(private=True)
def expected(
arg0: T.Buffer((32, 224, 224, 16, 4), "float32"),
arg1: T.Buffer((224, 224, 16, 4), "float32"),
@@ -658,7 +658,7 @@ def expected(


def test_op_transpose():
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg: T.Buffer((32, 64, 224, 224), "float32"),
T_transpose: T.Buffer((32, 224, 224, 64), "float32"),
@@ -670,7 +670,7 @@ def before(
T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax3, v_ax1, v_ax2]

-@T.prim_func
+@T.prim_func(private=True)
def expected(
arg: T.Buffer((32, 64, 224, 224), "float32"),
T_transpose: T.Buffer((32, 224, 64, 224), "float32"),
@@ -690,7 +690,7 @@ def expected(


def test_op_pad():
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg: T.Buffer((32, 64, 224, 224), "float32"),
PadInput: T.Buffer((32, 64, 230, 230), "float32"),
@@ -706,7 +706,7 @@ def before(
T.float32(2),
)

-@T.prim_func
+@T.prim_func(private=True)
def expected(
arg: T.Buffer((32, 224, 224, 16, 4), "float32"),
PadInput: T.Buffer((32, 230, 230, 16, 4), "float32"),
@@ -730,7 +730,7 @@ def expected(


def test_op_split():
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg: T.Buffer((32, 64, 224, 224), "float32"),
split0: T.Buffer((32, 32, 224, 224), "float32"),
@@ -749,7 +749,7 @@ def before(
T.writes(split1[v_ax0, v_ax1, v_ax2, v_ax3])
split1[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3]

-@T.prim_func
+@T.prim_func(private=True)
def expected(
arg: T.Buffer((32, 224, 224, 64), "float32"),
split0: T.Buffer((32, 224, 224, 32), "float32"),
@@ -778,7 +778,7 @@ def expected(

@pytest.mark.skip("temp disable, due to minor arith regression")
def test_op_split_tiling_split_dim():
-@T.prim_func
+@T.prim_func(private=True)
def before(
arg: T.Buffer((32, 64, 224, 224), "float32"),
split0: T.Buffer((32, 32, 224, 224), "float32"),
@@ -797,7 +797,7 @@ def before(
T.writes(split1[v_ax0, v_ax1, v_ax2, v_ax3])
split1[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3]

-@T.prim_func
+@T.prim_func(private=True)
def expected(
arg: T.Buffer((32, 224, 224, 16, 4), "float32"),
split0: T.Buffer((32, 224, 224, 8, 4), "float32"),
2 changes: 1 addition & 1 deletion tests/python/relax/test_backend_transform_shape_lower.py
@@ -189,7 +189,7 @@ def main(

@tvm.script.ir_module
class Expected:
-@T.prim_func
+@T.prim_func(private=True)
def shape_func(H: T.Buffer(T.int64(4), "int64")):
# generated compute function
T.func_attr({"tir.is_host_func": 1})
2 changes: 1 addition & 1 deletion tests/python/relax/test_blockbuilder_emit_te.py
@@ -41,7 +41,7 @@ def te_func(A, offset):

@I.ir_module
class Expected:
-@T.prim_func
+@T.prim_func(private=True)
def te_func(
A: T.Buffer((T.int64(10),), "float32"),
B: T.Buffer((T.int64(10),), "float32"),
2 changes: 1 addition & 1 deletion tests/python/relax/test_frontend_nn_op.py
@@ -218,7 +218,7 @@ def test(self, x: Tensor):
# fmt: off
@I.ir_module
class Expected:
-@T.prim_func
+@T.prim_func(private=True)
def add_one(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), T_add: T.Buffer((T.int64(10), T.int64(10)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
6 changes: 3 additions & 3 deletions tests/python/relax/test_meta_schedule_relax_integration.py
@@ -54,7 +54,7 @@ def main(data: R.Tensor((1, 8, 8, 4), dtype="int32")) -> R.Tensor((1, 8, 8, 4),
# fmt: off
@I.ir_module
class Module:
-@T.prim_func
+@T.prim_func(private=True)
def conv2d(rxplaceholder: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32"), DepthwiseConv2d: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")):
T.func_attr({"op_pattern": 4, "tir.noalias": True})
# with T.block("root"):
@@ -76,7 +76,7 @@ def conv2d(rxplaceholder: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(
DepthwiseConv2d[v_b, v_i, v_j, v_c] = 0
DepthwiseConv2d[v_b, v_i, v_j, v_c] = DepthwiseConv2d[v_b, v_i, v_j, v_c] + PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c] * fused_constant_1[v_di, v_dj, v_c, T.int64(0)]

-@T.prim_func
+@T.prim_func(private=True)
def conv2d0(rxplaceholder0: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32"), DepthwiseConv2d0: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")):
T.func_attr({"op_pattern": 4, "tir.noalias": True})
# with T.block("root"):
@@ -98,7 +98,7 @@ def conv2d0(rxplaceholder0: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int6
DepthwiseConv2d0[v_b, v_i, v_j, v_c] = 0
DepthwiseConv2d0[v_b, v_i, v_j, v_c] = DepthwiseConv2d0[v_b, v_i, v_j, v_c] + PaddedInput0[v_b, v_i + v_di, v_j + v_dj, v_c] * fused_constant0_1[v_di, v_dj, v_c, T.int64(0)]

-@T.prim_func
+@T.prim_func(private=True)
def fused_conv2d_add(data: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32"), T_add: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
(Diff for the remaining 14 changed files not shown.)
