diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc
index fa726b82af..9955b5f483 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -85,16 +85,16 @@ class CodeGenRunner : ExprMutator {
       return Call(call_op, new_args, tvm::Attrs(), {ret_struct_info});
     };
 
+    auto ret_sinfo = GetStructInfo(call);
     if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) {
-      return create_call_dps_packed(it->second.first, it->second.second);
+      return create_call_dps_packed(it->second, ret_sinfo);
     } else {
       // TODO(@sunggg): Is there any better way to get this func?
       Function func = Downcast<Function>(builder_->GetContextIRModule()->Lookup(gvar));
       Expr new_func = VisitExpr(func);
 
       if (new_func->IsInstance<ExternFuncNode>()) {
-        auto ret_sinfo = GetStructInfo(call);
-        extern_funcs_[gvar_node] = {new_func, ret_sinfo};
+        extern_funcs_[gvar_node] = new_func;
         // Remove the global symbol and codegen attributes from the function so that it can be
         // removed from the module.
         static const runtime::PackedFunc* RemoveFuncAttrFunc =
@@ -173,8 +173,8 @@ class CodeGenRunner : ExprMutator {
   /*! \brief The names of all constants in the original module. */
   Map<Constant, String> constant_names;
 
-  /*! \brief Extern funcs and their return struct infos for each global variable. */
-  std::unordered_map<const GlobalVarNode*, std::pair<Expr, StructInfo>> extern_funcs_;
+  /*! \brief Extern funcs for each global variable. */
+  std::unordered_map<const GlobalVarNode*, Expr> extern_funcs_;
 };
 
 }  // namespace relax
diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py
index 77756dc664..d103291388 100644
--- a/tests/python/relax/test_transform_codegen_pass.py
+++ b/tests/python/relax/test_transform_codegen_pass.py
@@ -21,7 +21,7 @@ import tvm.testing
 from tvm import relax, tir
 import numpy as np
-from tvm.script import relax as R
+from tvm.script import relax as R, ir as I, tir as T
 from tvm.relax.testing import transform
 import tempfile
 from tvm.relax.transform.tuning_api import Trace
@@ -248,6 +248,81 @@ def test_multiple_calls_same_extern():
     tvm.ir.assert_structural_equal(mod["main"], Conv2dx2_after["main"])
 
 
+def test_dynamic_shape():
+    import tvm.relax.backend.contrib.cublas
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((1, 4096), dtype="float16"),
+            w1: R.Tensor((4096, "r1"), dtype="float16"),
+            w2: R.Tensor((4096, "r2"), dtype="float16"),
+        ) -> R.Tuple(R.Tensor((1, "r1"), dtype="float16"), R.Tensor((1, "r2"), dtype="float16")):
+            r1 = T.int64()
+            r2 = T.int64()
+            cls = Before
+            with R.dataflow():
+                lv: R.Tensor((1, r1), dtype="float16") = cls.fused_relax_matmul_cublas(x, w1)
+                lv1: R.Tensor((1, r2), dtype="float16") = cls.fused_relax_matmul_cublas(x, w2)
+                gv: R.Tuple(
+                    R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16")
+                ) = (lv, lv1)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def fused_relax_matmul_cublas(
+            x: R.Tensor((1, 4096), dtype="float16"), w1: R.Tensor((4096, "r1"), dtype="float16")
+        ) -> R.Tensor((1, "r1"), dtype="float16"):
+            r1 = T.int64()
+            R.func_attr({"Codegen": "cublas"})
+
+            @R.function
+            def gv(
+                x_1: R.Tensor((1, 4096), dtype="float16"),
+                w1_1: R.Tensor((4096, r1), dtype="float16"),
+            ) -> R.Tensor((1, r1), dtype="float16"):
+                R.func_attr({"Composite": "cublas.matmul"})
+                with R.dataflow():
+                    gv_1: R.Tensor((1, r1), dtype="float16") = R.matmul(x_1, w1_1, out_dtype="void")
+                    R.output(gv_1)
+                return gv_1
+
+            gv1: R.Tensor((1, r1), dtype="float16") = gv(x, w1)
+            return gv1
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 4096), dtype="float16"),
+            w1: R.Tensor((4096, "r1"), dtype="float16"),
+            w2: R.Tensor((4096, "r2"), dtype="float16"),
+        ) -> R.Tuple(R.Tensor((1, "r1"), dtype="float16"), R.Tensor((1, "r2"), dtype="float16")):
+            r1 = T.int64()
+            r2 = T.int64()
+            with R.dataflow():
+                lv = R.call_dps_packed(
+                    "fused_relax_matmul_cublas",
+                    (x, w1),
+                    out_sinfo=R.Tensor((1, r1), dtype="float16"),
+                )
+                lv1 = R.call_dps_packed(
+                    "fused_relax_matmul_cublas",
+                    (x, w2),
+                    out_sinfo=R.Tensor((1, r2), dtype="float16"),
+                )
+                gv: R.Tuple(
+                    R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16")
+                ) = (lv, lv1)
+                R.output(gv)
+            return gv
+
+    after = relax.transform.RunCodegen()(Before)
+    tvm.ir.assert_structural_equal(after["main"], Expected["main"])
+
+
 # TODO(@sunggg): test with more complex patterns (e.g., multiple annots, mixed codegens, different ops, const binding)
 
 if __name__ == "__main__":