[UNITY] Handle pattern with remove_pad (#15843)
* [UNITY] Handle pattern with remove_pad

After PR #15679 added support for padding reversal in AlterOpImpl,
a new pattern can appear in the graph:
  transform_layout -> remove_pad -> transform_layout
This change handles that pattern; see the sketch below.

* Fix LINT errors
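
For context, here is a minimal sketch of the pattern this change matches, written with the Relax dataflow pattern language used in the diff below. The variable names are illustrative only and are not taken from the commit.

# A minimal sketch of the matched pattern, assuming the tvm.relax.dpl API
# available on the unity branch; names here are illustrative, not the
# commit's attribute names.
from tvm.relax.dpl import is_op, wildcard, TuplePattern

inp = wildcard()
inner = is_op("relax.layout_transform")(inp)                 # first transform_layout
gv = wildcard()                                              # callee, e.g. a remove_pad PrimFunc
call = is_op("relax.call_tir")(gv, TuplePattern([inner]))    # remove_pad call_tir
outer = is_op("relax.layout_transform")(call)                # transform_layout undoing the first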
abhikran-quic authored Oct 3, 2023
1 parent 8dda1cb commit 4735a5b
Showing 2 changed files with 161 additions and 5 deletions.
23 changes: 18 additions & 5 deletions python/tvm/relax/transform/optimize_layout_transform.py
@@ -16,10 +16,11 @@
# under the License.
# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
"""Relax Optimize Layout Transform pass."""
from tvm.ir import structural_equal
from tvm.ir.module import IRModule
from tvm.ir.transform import PassContext
from tvm.relax import Expr, Function
from tvm.relax.dpl import is_op, rewrite_call, wildcard
from tvm.relax.dpl import is_op, rewrite_call, wildcard, TuplePattern
from . import function_pass


@@ -35,7 +36,12 @@ def __init__(self):
pattern_transform_layout = is_op("relax.layout_transform")(self.input)
pattern_1 = is_op("relax.layout_transform")(pattern_transform_layout)

self.pattern = pattern_1
self.gv_ = wildcard()
args = TuplePattern([pattern_transform_layout])
pattern_2 = is_op("relax.call_tir")(self.gv_, args)
self.pattern_2 = is_op("relax.layout_transform")(pattern_2)

self.pattern = pattern_1 | self.pattern_2

def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) -> IRModule:
"""
@@ -54,6 +60,7 @@ def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) -> IRModule:
Relax pass context
"""

self.mod = mod
updated_func = func
for _, func in mod.functions.items():
# Skip non-relax functions
@@ -65,9 +72,15 @@ def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) -> IRModule:

def rewriter(expr, matches):
arg1 = matches[self.pattern]
arg2 = matches[self.input]
if list(arg1.struct_info.shape) == list(arg2.struct_info.shape):
return arg2
if self.pattern_2 not in matches.keys():
arg2 = matches[self.input]
else:
arg2 = matches[self.gv_]
if "remove_pad" == self.mod[arg2].attrs["operator_name"]:
arg2 = matches[self.input]
if hasattr(arg1.struct_info, "shape") and hasattr(arg2.struct_info, "shape"):
if structural_equal(arg1.struct_info.shape, arg2.struct_info.shape):
return arg2
return expr

updated_func = rewrite_call(self.pattern, rewriter, func)
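
For reference, a hedged usage sketch of running the updated pass over a module. The OptimizeLayoutTransform name is assumed from this file's module name and may differ in other branches.

# Usage sketch (assumption: the pass is exposed as
# relax.transform.OptimizeLayoutTransform, matching this file's name).
import tvm
from tvm import relax

def fold_redundant_layout_transforms(mod: tvm.IRModule) -> tvm.IRModule:
    # Removes layout_transform pairs that cancel out, including the
    # transform_layout -> remove_pad -> transform_layout chain handled here.
    return relax.transform.OptimizeLayoutTransform()(mod)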
143 changes: 143 additions & 0 deletions tests/python/relax/test_optimize_layout_transform.py
@@ -273,5 +273,148 @@ def main(
_run_pass_compare_output(Before, Expected)


def test_transform_layout_tir_remove_pad_transform_layout():
@I.ir_module
class Before:
@T.prim_func(private=True)
def relax_relu_replacement(
arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")
):
T.func_attr({"operator_name": "relax.relu"})
# with T.block("root"):
for ax0 in range(16):
with T.block("T_add"):
v_ax0 = T.axis.spatial(16, ax0)
T.reads(arg0[v_ax0])
T.writes(output[v_ax0])
output[v_ax0] = T.max(arg0[v_ax0], T.float32(0))

@T.prim_func(private=True)
def remove_pad(var_input: T.handle, var_output: T.handle):
T.func_attr({"operator_name": "remove_pad", "tir.noalias": T.bool(True)})
p0 = T.int64()
input = T.match_buffer(var_input, (p0,))
i0 = T.int64()
output = T.match_buffer(var_output, (i0,))
# with T.block("root"):
for ax0 in range(i0):
with T.block("output"):
v_ax0 = T.axis.spatial(i0, ax0)
T.reads(input[v_ax0])
T.writes(output[v_ax0])
output[v_ax0] = input[v_ax0]

@R.function
def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"):
with R.dataflow():
lv: R.Tensor((16,), dtype="float32") = R.layout_transform(
x,
index_map=T.index_map(lambda i: (i % 16,)),
pad_value=None,
axis_separators=[],
)
lv1 = R.call_tir(
Before.relax_relu_replacement,
(lv,),
out_sinfo=R.Tensor((16,), dtype="float32"),
)
lv2: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv1,
index_map=T.index_map(lambda axis0: (axis0,)),
pad_value=None,
axis_separators=[],
)
lv_1 = R.call_tir(
Before.remove_pad, (lv2,), out_sinfo=R.Tensor((14,), dtype="float32")
)
lv3: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv_1,
index_map=T.index_map(lambda i: (i % 16,)),
pad_value=None,
axis_separators=[],
)
lv4 = R.call_tir(
Before.relax_relu_replacement,
(lv3,),
out_sinfo=R.Tensor((16,), dtype="float32"),
)
lv5: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv4,
index_map=T.index_map(lambda axis0: (axis0,)),
pad_value=None,
axis_separators=[],
)
lv_2 = R.call_tir(
Before.remove_pad, (lv5,), out_sinfo=R.Tensor((14,), dtype="float32")
)
gv: R.Tensor((14,), dtype="float32") = lv_2
R.output(gv)
return gv

@I.ir_module
class Expected:
@T.prim_func(private=True)
def relax_relu_replacement(
arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")
):
T.func_attr({"operator_name": "relax.relu"})
# with T.block("root"):
for ax0 in range(16):
with T.block("T_add"):
v_ax0 = T.axis.spatial(16, ax0)
T.reads(arg0[v_ax0])
T.writes(output[v_ax0])
output[v_ax0] = T.max(arg0[v_ax0], T.float32(0))

@T.prim_func(private=True)
def remove_pad(var_input: T.handle, var_output: T.handle):
T.func_attr({"operator_name": "remove_pad", "tir.noalias": T.bool(True)})
p0 = T.int64()
input = T.match_buffer(var_input, (p0,))
i0 = T.int64()
output = T.match_buffer(var_output, (i0,))
# with T.block("root"):
for ax0 in range(i0):
with T.block("output"):
v_ax0 = T.axis.spatial(i0, ax0)
T.reads(input[v_ax0])
T.writes(output[v_ax0])
output[v_ax0] = input[v_ax0]

@R.function
def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"):
with R.dataflow():
lv: R.Tensor((16,), dtype="float32") = R.layout_transform(
x,
index_map=T.index_map(lambda i: (i % 16,)),
pad_value=None,
axis_separators=[],
)
lv1 = R.call_tir(
Expected.relax_relu_replacement,
(lv,),
out_sinfo=R.Tensor((16,), dtype="float32"),
)
lv4 = R.call_tir(
Expected.relax_relu_replacement,
(lv1,),
out_sinfo=R.Tensor((16,), dtype="float32"),
)
lv5: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv4,
index_map=T.index_map(lambda axis0: (axis0,)),
pad_value=None,
axis_separators=[],
)
lv_2 = R.call_tir(
Expected.remove_pad, (lv5,), out_sinfo=R.Tensor((14,), dtype="float32")
)
gv: R.Tensor((14,), dtype="float32") = lv_2
R.output(gv)
return gv

_run_pass_compare_output(Before, Expected)


if __name__ == "__main__":
tvm.testing.main()
