[Fix][Op] Add groups to conv1d (#270)
- Enable `groups` as a parameter of `conv1d`.
- Simplify the full output-shape expression in the generic `conv` helper used by `conv2d` (previously only the numerator was simplified), so symbolic output shapes canonicalize, e.g. `w - kw + 1` becomes `w + 1 - kw`.
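For context: grouped convolution splits the `C` input channels into `groups` independent slices, so each filter convolves only `C // groups` channels. A quick sanity check of the shapes used in the tests below (illustrative, not part of the commit):

```python
C, O, groups = 128, 64, 8   # input channels, output channels, groups
assert C % groups == 0 and O % groups == 0
# Each of the 64 filters sees C // groups == 16 input channels,
# which is why the kernel in the test has shape (64, 16, 3).
```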
LeshengJin authored and tqchen committed Aug 1, 2023
1 parent 4d37532 commit cca3bf0
Showing 4 changed files with 19 additions and 15 deletions.
1 change: 1 addition & 0 deletions python/tvm/relax/transform/legalize_ops/nn.py
@@ -60,6 +60,7 @@ def _nn_conv1d(bb: BlockBuilder, call: Call) -> Expr:
strides=call.attrs.strides,
padding=call.attrs.padding,
dilation=call.attrs.dilation,
+ groups=call.attrs.groups,
data_layout=call.attrs.data_layout,
kernel_layout=call.attrs.kernel_layout,
out_dtype=call.attrs.out_dtype if call.attrs.out_dtype != "" else None,
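With `groups` now forwarded from the call attributes, `LegalizeOps` can lower a grouped relax `conv1d` to TIR. A minimal sketch of exercising this path (BlockBuilder and op names per the unity-era relax API at the time of this commit; treat as illustrative, not as the commit's own test):

```python
import tvm
from tvm import relax

# Shapes match the grouped test case below: NCW data, OIW kernel, groups=8.
x = relax.Var("x", relax.TensorStructInfo((2, 128, 28), "float32"))
w = relax.Var("w", relax.TensorStructInfo((64, 16, 3), "float32"))

bb = relax.BlockBuilder()
with bb.function("main", [x, w]):
    y = bb.emit(relax.op.nn.conv1d(x, w, strides=2, padding=1, dilation=2, groups=8))
    bb.emit_func_output(y)

# _nn_conv1d above forwards groups into topi.nn.conv1d during legalization.
mod = relax.transform.LegalizeOps()(bb.get())
mod.show()  # should contain a group_conv1d_ncw PrimFunc
```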
5 changes: 4 additions & 1 deletion python/tvm/topi/nn/conv1d.py
@@ -25,6 +25,7 @@ def conv1d(
strides=1,
padding="VALID",
dilation=1,
+ groups=1,
data_layout="NCW",
kernel_layout="",
out_dtype=None,
@@ -60,7 +61,9 @@ def conv1d(
out_dtype : str
The output data type. If None then output is same type as input.
"""
- return conv(data, kernel, strides, padding, dilation, 1, data_layout, kernel_layout, out_dtype)
+ return conv(
+     data, kernel, strides, padding, dilation, groups, data_layout, kernel_layout, out_dtype
+ )


def conv1d_nwc(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
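At the TOPI level, `conv1d` now forwards `groups` to the generic `conv` helper instead of hard-coding `1`. A shape-only sketch (values chosen to match the test case below; passing the padding as an int is an assumption about the generic helper's pad handling):

```python
from tvm import te, topi

data = te.placeholder((2, 128, 28), name="data")      # NCW
kernel = te.placeholder((64, 16, 3), name="kernel")   # OIW, 16 == 128 // 8
out = topi.nn.conv1d(data, kernel, strides=2, padding=1, dilation=2, groups=8)
print(out.shape)  # expected [2, 64, 13]
```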
2 changes: 1 addition & 1 deletion python/tvm/topi/nn/conv2d.py
@@ -885,7 +885,7 @@ def conv(
# compute the output shape
out_channel = num_filter
out_dimensions = [
- simplify(d - (k - 1) * dil - 1 + pb + pe) // stride + 1
+ simplify((d - (k - 1) * dil - 1 + pb + pe) // stride + 1)
for d, k, dil, pb, pe, stride in zip(
dimensions, kernel_dimensions, dilations, pad_begin, pad_end, strides
)
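Moving the closing parenthesis is not cosmetic: previously `simplify` saw only the numerator, so with symbolic shapes the division and trailing `+ 1` (e.g. `(w - kw) // 1 + 1`) survived unsimplified; wrapping the whole expression lets the arithmetic simplifier canonicalize it to forms like `w + 1 - kw`, which is exactly why the expected shapes in the tests below change. The formula itself, as a plain-Python sketch:

```python
def conv_out_dim(d, k, dil, pb, pe, stride):
    # floor((d + pad_begin + pad_end - dilated_kernel_extent) / stride) + 1,
    # i.e. the expression now wrapped by simplify() in conv().
    return (d - (k - 1) * dil - 1 + pb + pe) // stride + 1

assert conv_out_dim(28, 3, 2, 1, 1, 2) == 13  # the grouped conv1d test case
assert conv_out_dim(28, 3, 1, 0, 0, 1) == 26  # stride 1, no pad: w - kw + 1
```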
26 changes: 13 additions & 13 deletions tests/python/relax/test_transform_legalize_ops_nn.py
@@ -44,23 +44,23 @@ def main(x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((64, 16, 3), dt
return gv

@T.prim_func
- def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(16), T.int64(3)), "float32"), conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), T.int64(13)), "float32")):
+ def conv1d(A: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), B: T.Buffer((T.int64(64), T.int64(16), T.int64(3)), "float32"), group_conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), T.int64(13)), "float32")):
T.func_attr({"tir.noalias": True})
pad_temp = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(30)))
for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(30)):
with T.block("pad_temp"):
v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
- T.reads(rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1)])
+ T.reads(A[v_i0, v_i1, v_i2 - T.int64(1)])
T.writes(pad_temp[v_i0, v_i1, v_i2])
- pad_temp[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(29), rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0))
- for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(64), T.int64(13), T.int64(128), T.int64(3)):
- with T.block("conv1d_ncw"):
+ pad_temp[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(29), A[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0))
+ for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(64), T.int64(13), T.int64(16), T.int64(3)):
+ with T.block("group_conv1d_ncw"):
v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry])
- T.reads(pad_temp[v_nn, v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)], rxplaceholder_1[v_ff, v_rc, v_ry])
- T.writes(conv1d_ncw[v_nn, v_ff, v_yy])
+ T.reads(pad_temp[v_nn, v_ff // T.int64(8) * T.int64(16) + v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)], B[v_ff, v_rc, v_ry])
+ T.writes(group_conv1d_ncw[v_nn, v_ff, v_yy])
with T.init():
- conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0)
- conv1d_ncw[v_nn, v_ff, v_yy] = conv1d_ncw[v_nn, v_ff, v_yy] + pad_temp[v_nn, v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)] * rxplaceholder_1[v_ff, v_rc, v_ry]
+ group_conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0)
+ group_conv1d_ncw[v_nn, v_ff, v_yy] = group_conv1d_ncw[v_nn, v_ff, v_yy] + pad_temp[v_nn, v_ff // T.int64(8) * T.int64(16) + v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)] * B[v_ff, v_rc, v_ry]
# fmt: on

mod = LegalizeOps()(Conv1d)
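The grouped indexing in the generated PrimFunc is the interesting part: output channel `v_ff` belongs to group `v_ff // 8` (64 output channels over 8 groups), and each group reads its own 16-channel input slice, hence `v_ff // 8 * 16 + v_rc`. A NumPy cross-check of that indexing and of the expected `(2, 64, 13)` shape (illustrative only, not part of the commit):

```python
import numpy as np

def group_conv1d_ncw_ref(x, w, stride=2, pad=1, dil=2, groups=8):
    n, cin, width = x.shape
    cout, cin_g, kw = w.shape                      # cin_g == cin // groups
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad)))   # mirrors pad_temp
    out_w = (width + 2 * pad - (kw - 1) * dil - 1) // stride + 1
    out = np.zeros((n, cout, out_w), dtype=x.dtype)
    cout_g = cout // groups
    for ff in range(cout):
        for rc in range(cin_g):
            ic = ff // cout_g * cin_g + rc         # v_ff // 8 * 16 + v_rc
            for yy in range(out_w):
                for ry in range(kw):
                    out[:, ff, yy] += xp[:, ic, yy * stride + ry * dil] * w[ff, rc, ry]
    return out

x = np.random.rand(2, 128, 28).astype("float32")
w = np.random.rand(64, 16, 3).astype("float32")
assert group_conv1d_ncw_ref(x, w).shape == (2, 64, 13)
```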
@@ -171,7 +171,7 @@ def main(x: R.Tensor(("n", "c", "w"), dtype="float32"), kernel: R.Tensor(("f", "
w = T.int64()
kw = T.int64()
c = T.int64()
- gv = R.call_tir(Expected.conv1d, (x, kernel), out_sinfo=R.Tensor((n, f, w - kw + 1), dtype="float32"))
+ gv = R.call_tir(Expected.conv1d, (x, kernel), out_sinfo=R.Tensor((n, f, w + 1 - kw), dtype="float32"))
return gv

@T.prim_func
@@ -181,7 +181,7 @@ def conv1d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv1
rxplaceholder = T.match_buffer(var_rxplaceholder, (n, c, w))
f, kw = T.int64(), T.int64()
rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (f, c, kw))
- conv1d_ncw = T.match_buffer(var_conv1d_ncw, (n, f, w - kw + T.int64(1)))
+ conv1d_ncw = T.match_buffer(var_conv1d_ncw, (n, f, w + T.int64(1) - kw))
# with T.block("root"):
pad_temp = T.alloc_buffer((n, c, w))
for i0, i1, i2 in T.grid(n, c, w):
@@ -349,7 +349,7 @@ def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c
kh = T.int64()
w = T.int64()
kw = T.int64()
- gv = R.call_tir(Expected.conv2d, (x, kernel), R.Tensor((n, f, ((h - kh) + 1), ((w - kw) + 1)), dtype="float32"))
+ gv = R.call_tir(Expected.conv2d, (x, kernel), R.Tensor((n, f, h + 1 - kh, w + 1 - kw), dtype="float32"))
return gv

@T.prim_func
@@ -364,7 +364,7 @@ def conv2d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv2
w = T.int64()
rxplaceholder = T.match_buffer(var_rxplaceholder, [n, c, h, w], dtype="float32")
rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [f, c, kh, kw], dtype="float32")
- conv2d_nchw = T.match_buffer(var_conv2d_nchw, [n, f, h - kh + T.int64(1), w - kw + T.int64(1)], dtype="float32")
+ conv2d_nchw = T.match_buffer(var_conv2d_nchw, [n, f, h + T.int64(1) - kh, w + T.int64(1) - kw], dtype="float32")
pad_temp = T.alloc_buffer([n, c, h, w], dtype="float32")
for i0, i1, i2, i3 in T.grid(n, c, h, w):
with T.block("pad_temp"):
