Skip to content

Commit

Permalink
[Unity] Unwrap tuples during CanonicalizeBindings (#15696)
Browse files Browse the repository at this point in the history
Prior to this commit, `CanonicalizeBindings` would identify trivial
bindings of `var_y = var_x` and replace subsequent usage of `var_y`
with `var_x`, but bindings of the form `tuple_var = (var_x,); var_y =
var_x[0]` would not be canonicalized.

This commit updates the `CanonicalizeBindings` pass to identify
trivial bindings that occur across a `TupleGetItem` expression,
handling them as if they were direct assignments.
  • Loading branch information
Lunderberg authored Sep 15, 2023
1 parent afb2e42 commit 21cdf6c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/relax/transform/canonicalize_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class BindingCanonicalizer : public ExprMutator {
public:
BindingCanonicalizer() {}

using ExprMutator::VisitExpr_;

Expr VisitExpr_(const VarNode* op) override {
// remap first
Var v = Downcast<Var>(ExprMutator::VisitExpr_(op));
Expand All @@ -54,6 +56,23 @@ class BindingCanonicalizer : public ExprMutator {
return ExprMutator::VisitExpr_(LookupBinding(v).as<VarNode>());
}

Expr VisitExpr_(const TupleGetItemNode* tuple_get_item) override {
if (auto tuple_var = tuple_get_item->tuple.as<Var>()) {
if (auto tuple_value = LookupBinding(tuple_var.value())) {
if (auto explicit_tuple = tuple_value.as<TupleNode>()) {
CHECK_GE(tuple_get_item->index, 0)
<< "Tuple " << tuple_value << " is accessed at index " << tuple_get_item->index
<< ", but negative indices are not supported in this context.";
CHECK_LT(tuple_get_item->index, explicit_tuple->fields.size())
<< "Tuple " << tuple_value << " is accessed at index " << tuple_get_item->index
<< ", but the tuple size is only " << explicit_tuple->fields.size();
return VisitExpr(explicit_tuple->fields[tuple_get_item->index]);
}
}
}
return ExprMutator::VisitExpr_(tuple_get_item);
}

void VisitBinding_(const VarBindingNode* binding) override {
// Unlike default visitor, we do not permit the checked type to change
// if the new value's checked type is different (this preserves user annotations)
Expand Down
25 changes: 25 additions & 0 deletions tests/python/relax/test_transform_canonicalize_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,5 +254,30 @@ def main(x: R.Tensor(("m", "n"))):
assert_structural_equal(new_mod, Expected)


def test_unwrap_tuple():
@tvm.script.ir_module
class Before:
@R.function
def main(x: R.Tensor, y: R.Tensor):
tuple_var = (x, y)
w = tuple_var[0]
q = tuple_var[1]
z = R.add(w, q)
return R.add(q, z)

@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor, y: R.Tensor):
tuple_var = (x, y)
w = x
q = y
z = R.add(x, y)
return R.add(y, z)

after = relax.transform.CanonicalizeBindings()(Before)
assert_structural_equal(Expected, after)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 21cdf6c

Please sign in to comment.