From 21cdf6c1c40ca03d310871de59c8cd5e528c19b9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 15 Sep 2023 13:12:55 -0500 Subject: [PATCH] [Unity] Unwrap tuples during CanonicalizeBindings (#15696) Prior to this commit, `CanonicalizeBindings` would identify trivial bindings of `var_y = var_x` and replace subsequent usage of `var_y` with `var_x`, but bindings of the form `tuple_var = (var_x,); var_y = var_x[0]` would not be canonicalized. This commit updates the `CanonicalizeBindings` pass to identify trivial bindings that occur across a `TupleGetItem` expression, handling them as if they were direct assignments. --- src/relax/transform/canonicalize_bindings.cc | 19 ++++++++++++++ .../test_transform_canonicalize_bindings.py | 25 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 2da72a4f5a..d355c09786 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -36,6 +36,8 @@ class BindingCanonicalizer : public ExprMutator { public: BindingCanonicalizer() {} + using ExprMutator::VisitExpr_; + Expr VisitExpr_(const VarNode* op) override { // remap first Var v = Downcast(ExprMutator::VisitExpr_(op)); @@ -54,6 +56,23 @@ class BindingCanonicalizer : public ExprMutator { return ExprMutator::VisitExpr_(LookupBinding(v).as()); } + Expr VisitExpr_(const TupleGetItemNode* tuple_get_item) override { + if (auto tuple_var = tuple_get_item->tuple.as()) { + if (auto tuple_value = LookupBinding(tuple_var.value())) { + if (auto explicit_tuple = tuple_value.as()) { + CHECK_GE(tuple_get_item->index, 0) + << "Tuple " << tuple_value << " is accessed at index " << tuple_get_item->index + << ", but negative indices are not supported in this context."; + CHECK_LT(tuple_get_item->index, explicit_tuple->fields.size()) + << "Tuple " << tuple_value << " is accessed at index " << tuple_get_item->index + << ", but the tuple size is only " << explicit_tuple->fields.size(); + return VisitExpr(explicit_tuple->fields[tuple_get_item->index]); + } + } + } + return ExprMutator::VisitExpr_(tuple_get_item); + } + void VisitBinding_(const VarBindingNode* binding) override { // Unlike default visitor, we do not permit the checked type to change // if the new value's checked type is different (this preserves user annotations) diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index 5e1d1b881e..91396ccb13 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -254,5 +254,30 @@ def main(x: R.Tensor(("m", "n"))): assert_structural_equal(new_mod, Expected) +def test_unwrap_tuple(): + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor, y: R.Tensor): + tuple_var = (x, y) + w = tuple_var[0] + q = tuple_var[1] + z = R.add(w, q) + return R.add(q, z) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor, y: R.Tensor): + tuple_var = (x, y) + w = x + q = y + z = R.add(x, y) + return R.add(y, z) + + after = relax.transform.CanonicalizeBindings()(Before) + assert_structural_equal(Expected, after) + + if __name__ == "__main__": tvm.testing.main()