diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 5386fbf7cb..6263e4490f 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -60,6 +60,7 @@ def _nn_conv1d(bb: BlockBuilder, call: Call) -> Expr:
         strides=call.attrs.strides,
         padding=call.attrs.padding,
         dilation=call.attrs.dilation,
+        groups=call.attrs.groups,
         data_layout=call.attrs.data_layout,
         kernel_layout=call.attrs.kernel_layout,
         out_dtype=call.attrs.out_dtype if call.attrs.out_dtype != "" else None,
diff --git a/python/tvm/topi/nn/conv1d.py b/python/tvm/topi/nn/conv1d.py
index ee388b4297..af856c2ed5 100644
--- a/python/tvm/topi/nn/conv1d.py
+++ b/python/tvm/topi/nn/conv1d.py
@@ -25,6 +25,7 @@ def conv1d(
     strides=1,
     padding="VALID",
     dilation=1,
+    groups=1,
     data_layout="NCW",
     kernel_layout="",
     out_dtype=None,
@@ -60,7 +61,9 @@ def conv1d(
     out_dtype : str
         The output data type. If None then output is same type as input.
     """
-    return conv(data, kernel, strides, padding, dilation, 1, data_layout, kernel_layout, out_dtype)
+    return conv(
+        data, kernel, strides, padding, dilation, groups, data_layout, kernel_layout, out_dtype
+    )
 
 
 def conv1d_nwc(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index f70d749e0f..0bb298ecce 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -885,7 +885,7 @@ def conv(
     # compute the output shape
     out_channel = num_filter
     out_dimensions = [
-        simplify(d - (k - 1) * dil - 1 + pb + pe) // stride + 1
+        simplify((d - (k - 1) * dil - 1 + pb + pe) // stride + 1)
         for d, k, dil, pb, pe, stride in zip(
             dimensions, kernel_dimensions, dilations, pad_begin, pad_end, strides
         )
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index d750901b59..311ca9d487 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -44,23 +44,23 @@ def main(x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((64, 16, 3), dt
             return gv
 
         @T.prim_func
-        def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(16), T.int64(3)), "float32"), conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), T.int64(13)), "float32")):
+        def conv1d(A: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), B: T.Buffer((T.int64(64), T.int64(16), T.int64(3)), "float32"), group_conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), T.int64(13)), "float32")):
             T.func_attr({"tir.noalias": True})
             pad_temp = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(30)))
             for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(30)):
                 with T.block("pad_temp"):
                     v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
-                    T.reads(rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1)])
+                    T.reads(A[v_i0, v_i1, v_i2 - T.int64(1)])
                     T.writes(pad_temp[v_i0, v_i1, v_i2])
-                    pad_temp[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(29), rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0))
-            for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(64), T.int64(13), T.int64(128), T.int64(3)):
-                with T.block("conv1d_ncw"):
+                    pad_temp[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(29), A[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0))
+            for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(64), T.int64(13), T.int64(16), T.int64(3)):
+                with T.block("group_conv1d_ncw"):
                     v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry])
-                    T.reads(pad_temp[v_nn, v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)], rxplaceholder_1[v_ff, v_rc, v_ry])
-                    T.writes(conv1d_ncw[v_nn, v_ff, v_yy])
+                    T.reads(pad_temp[v_nn, v_ff // T.int64(8) * T.int64(16) + v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)], B[v_ff, v_rc, v_ry])
+                    T.writes(group_conv1d_ncw[v_nn, v_ff, v_yy])
                     with T.init():
-                        conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0)
-                    conv1d_ncw[v_nn, v_ff, v_yy] = conv1d_ncw[v_nn, v_ff, v_yy] + pad_temp[v_nn, v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)] * rxplaceholder_1[v_ff, v_rc, v_ry]
+                        group_conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0)
+                    group_conv1d_ncw[v_nn, v_ff, v_yy] = group_conv1d_ncw[v_nn, v_ff, v_yy] + pad_temp[v_nn, v_ff // T.int64(8) * T.int64(16) + v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)] * B[v_ff, v_rc, v_ry]
     # fmt: on
 
     mod = LegalizeOps()(Conv1d)
@@ -171,7 +171,7 @@ def main(x: R.Tensor(("n", "c", "w"), dtype="float32"), kernel: R.Tensor(("f", "
             w = T.int64()
             kw = T.int64()
             c = T.int64()
-            gv = R.call_tir(Expected.conv1d, (x, kernel), out_sinfo=R.Tensor((n, f, w - kw + 1), dtype="float32"))
+            gv = R.call_tir(Expected.conv1d, (x, kernel), out_sinfo=R.Tensor((n, f, w + 1 - kw), dtype="float32"))
             return gv
 
         @T.prim_func
@@ -181,7 +181,7 @@ def conv1d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv1
             rxplaceholder = T.match_buffer(var_rxplaceholder, (n, c, w))
             f, kw = T.int64(), T.int64()
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (f, c, kw))
-            conv1d_ncw = T.match_buffer(var_conv1d_ncw, (n, f, w - kw + T.int64(1)))
+            conv1d_ncw = T.match_buffer(var_conv1d_ncw, (n, f, w + T.int64(1) - kw))
             # with T.block("root"):
             pad_temp = T.alloc_buffer((n, c, w))
             for i0, i1, i2 in T.grid(n, c, w):
@@ -349,7 +349,7 @@ def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c
             kh = T.int64()
             w = T.int64()
             kw = T.int64()
-            gv = R.call_tir(Expected.conv2d, (x, kernel), R.Tensor((n, f, ((h - kh) + 1), ((w - kw) + 1)), dtype="float32"))
+            gv = R.call_tir(Expected.conv2d, (x, kernel), R.Tensor((n, f, h + 1 - kh, w + 1 - kw), dtype="float32"))
             return gv
 
         @T.prim_func
@@ -364,7 +364,7 @@ def conv2d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv2
             w = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [n, c, h, w], dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [f, c, kh, kw], dtype="float32")
-            conv2d_nchw = T.match_buffer(var_conv2d_nchw, [n, f, h - kh + T.int64(1), w - kw + T.int64(1)], dtype="float32")
+            conv2d_nchw = T.match_buffer(var_conv2d_nchw, [n, f, h + T.int64(1) - kh, w + T.int64(1) - kw], dtype="float32")
             pad_temp = T.alloc_buffer([n, c, h, w], dtype="float32")
             for i0, i1, i2, i3 in T.grid(n, c, h, w):
                 with T.block("pad_temp"):
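
Note on the expected TIR above: legalization now lowers to a grouped kernel. With the test's attributes (stride 2, padding 1, dilation 2, and groups inferred as 8 from the 128 input channels against the kernel's 16 channels per group), the reduction runs over only one group's 16 channels, and the input channel is recovered as v_ff // T.int64(8) * T.int64(16) + v_rc, i.e. out_channel // (out_channels / groups) * (in_channels / groups) + reduction_channel. The NumPy sketch below mirrors that index arithmetic under those assumed shapes; it is an illustrative reference, not part of the patch.

import numpy as np

def group_conv1d_ncw_ref(data, kernel, stride=1, padding=0, dilation=1, groups=1):
    # Grouped 1-D convolution in NCW layout, written to mirror the TIR's indexing.
    n, cin, w = data.shape
    cout, cin_per_group, kw = kernel.shape
    assert cin == cin_per_group * groups and cout % groups == 0
    cout_per_group = cout // groups
    padded = np.pad(data, ((0, 0), (0, 0), (padding, padding)))
    out_w = (w + 2 * padding - dilation * (kw - 1) - 1) // stride + 1
    out = np.zeros((n, cout, out_w), dtype=data.dtype)
    for ff in range(cout):
        # Same expression the TIR reads: v_ff // cout_per_group * cin_per_group + v_rc
        base = ff // cout_per_group * cin_per_group
        for yy in range(out_w):
            for rc in range(cin_per_group):
                for ry in range(kw):
                    out[:, ff, yy] += (
                        padded[:, base + rc, yy * stride + ry * dilation]
                        * kernel[ff, rc, ry]
                    )
    return out

# Test-case shapes: (2, 128, 28) data and (64, 16, 3) kernel with stride 2,
# padding 1, dilation 2, groups 8 give a (2, 64, 13) output, matching both
# the out_sinfo and the T.grid extents in the expected TIR.
x = np.random.rand(2, 128, 28).astype("float32")
k = np.random.rand(64, 16, 3).astype("float32")
assert group_conv1d_ncw_ref(x, k, stride=2, padding=1, dilation=2, groups=8).shape == (2, 64, 13)

The conv2d.py hunk is related cleanup: moving the closing parenthesis so that simplify() wraps the whole extent expression lets the arithmetic canonicalizer rewrite symbolic output shapes, which is why the expected shapes in the dynamic-shape tests change form (e.g. w - kw + 1 becomes w + 1 - kw).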