[Unity] Fix BYOC codegen for dynamic shapes (#15750)
vinx13 authored Sep 14, 2023
1 parent bd51b5b commit 25d6c45
Showing 2 changed files with 81 additions and 6 deletions.
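In short: `CodeGenRunner` previously cached the return `StructInfo` of the first call site together with the generated extern function, then reused that cached value for every later call to the same extern. With dynamic shapes, different call sites of one extern can have different output struct info (e.g. `(1, r1)` vs `(1, r2)` in the test below), so the fix derives `ret_sinfo` from each call via `GetStructInfo(call)` and caches only the extern `Expr`.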
10 changes: 5 additions & 5 deletions src/relax/transform/run_codegen.cc
@@ -85,16 +85,16 @@ class CodeGenRunner : ExprMutator {
       return Call(call_op, new_args, tvm::Attrs(), {ret_struct_info});
     };
 
+    auto ret_sinfo = GetStructInfo(call);
     if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) {
-      return create_call_dps_packed(it->second.first, it->second.second);
+      return create_call_dps_packed(it->second, ret_sinfo);
     } else {
       // TODO(@sunggg): Is there any better way to get this func?
       Function func = Downcast<Function>(builder_->GetContextIRModule()->Lookup(gvar));
       Expr new_func = VisitExpr(func);
 
       if (new_func->IsInstance<ExternFuncNode>()) {
-        auto ret_sinfo = GetStructInfo(call);
-        extern_funcs_[gvar_node] = {new_func, ret_sinfo};
+        extern_funcs_[gvar_node] = new_func;
         // Remove the global symbol and codegen attributes from the function so that it can be
         // removed from the module.
         static const runtime::PackedFunc* RemoveFuncAttrFunc =
@@ -173,8 +173,8 @@ class CodeGenRunner : ExprMutator {
 
   /*! \brief The names of all constants in the original module. */
   Map<Constant, String> constant_names;
-  /*! \brief Extern funcs and their return struct infos for each global variable. */
-  std::unordered_map<const GlobalVarNode*, std::pair<Expr, StructInfo>> extern_funcs_;
+  /*! \brief Extern funcs for each global variable. */
+  std::unordered_map<const GlobalVarNode*, Expr> extern_funcs_;
 };
 
 }  // namespace relax
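To make the invariant concrete, here is a minimal runnable sketch (not part of the commit; the extern name "my_ext" is hypothetical) of a module in which one extern is called with two different dynamic output shapes, so no single cached StructInfo could describe both call sites:

# Minimal sketch, assuming a hypothetical extern "my_ext": each call site of the
# same extern carries its own out_sinfo, so the return StructInfo must be taken
# per call, not cached once per extern.
from tvm import relax
from tvm.script import ir as I, relax as R, tir as T


@I.ir_module
class PerCallSinfo:
    @R.function
    def main(
        x: R.Tensor((1, 4096), dtype="float16"),
        w1: R.Tensor((4096, "r1"), dtype="float16"),
        w2: R.Tensor((4096, "r2"), dtype="float16"),
    ) -> R.Tuple(R.Tensor((1, "r1"), dtype="float16"), R.Tensor((1, "r2"), dtype="float16")):
        r1 = T.int64()
        r2 = T.int64()
        with R.dataflow():
            # Same extern, two distinct dynamic output shapes.
            a = R.call_dps_packed("my_ext", (x, w1), out_sinfo=R.Tensor((1, r1), dtype="float16"))
            b = R.call_dps_packed("my_ext", (x, w2), out_sinfo=R.Tensor((1, r2), dtype="float16"))
            gv: R.Tuple(
                R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16")
            ) = (a, b)
            R.output(gv)
        return gv


# The module parses and type-checks: the two calls legitimately disagree on shape.
assert relax.analysis.well_formed(PerCallSinfo)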
77 changes: 76 additions & 1 deletion tests/python/relax/test_transform_codegen_pass.py
@@ -21,7 +21,7 @@
 import tvm.testing
 from tvm import relax, tir
 import numpy as np
-from tvm.script import relax as R
+from tvm.script import relax as R, ir as I, tir as T
 from tvm.relax.testing import transform
 import tempfile
 from tvm.relax.transform.tuning_api import Trace
@@ -248,6 +248,81 @@ def test_multiple_calls_same_extern():
     tvm.ir.assert_structural_equal(mod["main"], Conv2dx2_after["main"])
 
 
+def test_dynamic_shape():
+    import tvm.relax.backend.contrib.cublas
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((1, 4096), dtype="float16"),
+            w1: R.Tensor((4096, "r1"), dtype="float16"),
+            w2: R.Tensor((4096, "r2"), dtype="float16"),
+        ) -> R.Tuple(R.Tensor((1, "r1"), dtype="float16"), R.Tensor((1, "r2"), dtype="float16")):
+            r1 = T.int64()
+            r2 = T.int64()
+            cls = Before
+            with R.dataflow():
+                lv: R.Tensor((1, r1), dtype="float16") = cls.fused_relax_matmul_cublas(x, w1)
+                lv1: R.Tensor((1, r2), dtype="float16") = cls.fused_relax_matmul_cublas(x, w2)
+                gv: R.Tuple(
+                    R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16")
+                ) = (lv, lv1)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def fused_relax_matmul_cublas(
+            x: R.Tensor((1, 4096), dtype="float16"), w1: R.Tensor((4096, "r1"), dtype="float16")
+        ) -> R.Tensor((1, "r1"), dtype="float16"):
+            r1 = T.int64()
+            R.func_attr({"Codegen": "cublas"})
+
+            @R.function
+            def gv(
+                x_1: R.Tensor((1, 4096), dtype="float16"),
+                w1_1: R.Tensor((4096, r1), dtype="float16"),
+            ) -> R.Tensor((1, r1), dtype="float16"):
+                R.func_attr({"Composite": "cublas.matmul"})
+                with R.dataflow():
+                    gv_1: R.Tensor((1, r1), dtype="float16") = R.matmul(x_1, w1_1, out_dtype="void")
+                    R.output(gv_1)
+                return gv_1
+
+            gv1: R.Tensor((1, r1), dtype="float16") = gv(x, w1)
+            return gv1
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 4096), dtype="float16"),
+            w1: R.Tensor((4096, "r1"), dtype="float16"),
+            w2: R.Tensor((4096, "r2"), dtype="float16"),
+        ) -> R.Tuple(R.Tensor((1, "r1"), dtype="float16"), R.Tensor((1, "r2"), dtype="float16")):
+            r1 = T.int64()
+            r2 = T.int64()
+            with R.dataflow():
+                lv = R.call_dps_packed(
+                    "fused_relax_matmul_cublas",
+                    (x, w1),
+                    out_sinfo=R.Tensor((1, r1), dtype="float16"),
+                )
+                lv1 = R.call_dps_packed(
+                    "fused_relax_matmul_cublas",
+                    (x, w2),
+                    out_sinfo=R.Tensor((1, r2), dtype="float16"),
+                )
+                gv: R.Tuple(
+                    R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16")
+                ) = (lv, lv1)
+                R.output(gv)
+            return gv
+
+    after = relax.transform.RunCodegen()(Before)
+    tvm.ir.assert_structural_equal(after["main"], Expected["main"])
+
+
 # TODO(@sunggg): test with more complex patterns (e.g., multiple annots, mixed codegens, different ops, const binding)
 
 if __name__ == "__main__":
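Note that the two lowered `R.call_dps_packed` calls in `Expected` keep distinct `out_sinfo` values, `(1, r1)` and `(1, r2)`; that per-call-site information is exactly what the old `std::pair<Expr, StructInfo>` cache collapsed into a single value. (The test imports `tvm.relax.backend.contrib.cublas` to register the cuBLAS patterns, so running it presumably requires a TVM build with cuBLAS support.)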
