diff --git a/chapter_auto_program_optimization/index.md b/chapter_auto_program_optimization/index.md index e485440..d907d57 100644 --- a/chapter_auto_program_optimization/index.md +++ b/chapter_auto_program_optimization/index.md @@ -22,8 +22,6 @@ from tvm.ir.module import IRModule from tvm.script import tir as T, relax as R import numpy as np from tvm import relax -# This is needed for deferring annotation parsing in TVMScript -from __future__ import annotations ``` ```{.python .input n=1} @@ -48,9 +46,9 @@ def code2html(code): class MyModule: @T.prim_func def main( - A: T.Buffer[(128, 128), "float32"], - B: T.Buffer[(128, 128), "float32"], - C: T.Buffer[(128, 128), "float32"], + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32"), ): T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i, j, k in T.grid(128, 128, 128): @@ -295,31 +293,31 @@ print(sch.trace) ```{.python .input n=25} from tvm import meta_schedule as ms -sch_tuned = ms.tune_tir( +database = ms.tune_tir( mod=MyModule, target="llvm --num-cores=1", - config=ms.TuneConfig( - max_trials_global=64, - num_trials_per_iter=64, - ), + max_trials_global=64, + num_trials_per_iter=64, space=ms.space_generator.ScheduleFn(stochastic_schedule_mm), work_dir="./tune_tmp", task_name="main" ) + +sch = ms.tir_integration.compile_tir(database, MyModule, "llvm --num-cores=1") ``` `tune_tir` 函数返回在调优过程中找到的优化后的调度。 ```{.python .input n=26} -print(sch_tuned.trace) +sch.trace.show() ``` ```{.python .input n=27} -IPython.display.HTML(code2html(sch_tuned.mod.script())) +IPython.display.HTML(code2html(sch.mod.script())) ``` ```{.python .input n=28} -lib = tvm.build(sch_tuned.mod, target="llvm") +lib = tvm.build(sch.mod, target="llvm") f_timer_after = lib.time_evaluator("main", tvm.cpu()) print("Time cost of MyModule after tuning: %.3f ms" % (f_timer_after(a_nd, b_nd, c_nd).mean * 1000)) ``` @@ -331,20 +329,19 @@ print("Time cost of MyModule after tuning: %.3f ms" % (f_timer_after(a_nd, b_nd, 在底层,Meta-Schedule 分析每个 TensorIR block 的数据访问和循环模式,并提出对程序的随机变换方式。我们不会在本章中讨论这些通用的变换,但要注意它们也只是随机转换加上代码分析而已。我们可以使用上一节中学到的相同机制来增强自动调度。我们将在以后的章节中触及这个主题。 ```{.python .input n=29} -sch_tuned = ms.tune_tir( +database = ms.tune_tir( mod=MyModule, target="llvm --num-cores=1", - config=ms.TuneConfig( - max_trials_global=64, - num_trials_per_iter=64, - ), + max_trials_global=64, + num_trials_per_iter=64, work_dir="./tune_tmp", task_name="main", ) +sch = ms.tir_integration.compile_tir(database, MyModule, "llvm --num-cores=1") ``` ```{.python .input n=30} -lib = tvm.build(sch_tuned.mod, target="llvm") +lib = tvm.build(sch.mod, target="llvm") f_timer_after = lib.time_evaluator("main", tvm.cpu()) print("Time cost of MyModule after tuning: %.3f ms" % (f_timer_after(a_nd, b_nd, c_nd).mean * 1000)) ``` @@ -356,11 +353,11 @@ print("Time cost of MyModule after tuning: %.3f ms" % (f_timer_after(a_nd, b_nd, - 并行化和循环展开 ```{.python .input n=31} -sch_tuned.trace +sch.trace.show() ``` ```{.python .input n=32} -IPython.display.HTML(code2html(sch_tuned.mod.script())) +IPython.display.HTML(code2html(sch.mod.script())) ``` ### 章节检查点 @@ -433,10 +430,10 @@ nd_params = {k: tvm.nd.array(v) for k, v in mlp_params.items()} @tvm.script.ir_module class MyModuleMixture: @T.prim_func - def linear0(X: T.Buffer[(1, 784), "float32"], - W: T.Buffer[(128, 784), "float32"], - B: T.Buffer[(128,), "float32"], - Z: T.Buffer[(1, 128), "float32"]): + def linear0(X: T.Buffer((1, 784), "float32"), + W: T.Buffer((128, 784), "float32"), + B: T.Buffer((128,), "float32"), + Z: T.Buffer((1, 128), "float32")): T.func_attr({"global_symbol": "linear0", "tir.noalias": True}) Y = T.alloc_buffer((1, 128), "float32") for i, j, k in T.grid(1, 128, 784): @@ -452,15 +449,15 @@ class MyModuleMixture: Z[vi, vj] = Y[vi, vj] + B[vj] @R.function - def main(x: Tensor((1, 784), "float32"), - w0: Tensor((128, 784), "float32"), - b0: Tensor((128,), "float32"), - w1: Tensor((10, 128), "float32"), - b1: Tensor((10,), "float32")): + def main(x: R.Tensor((1, 784), "float32"), + w0: R.Tensor((128, 784), "float32"), + b0: R.Tensor((128,), "float32"), + w1: R.Tensor((10, 128), "float32"), + b1: R.Tensor((10,), "float32")): with R.dataflow(): - lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype="float32") - lv1 = R.call_tir("env.relu", (lv0,), (1, 128), dtype="float32") - out = R.call_tir("env.linear", (lv1, w1, b1), (1, 10), dtype="float32") + lv0 = R.call_dps_packed("linear0", (x, w0, b0), R.Tensor((1, 128), dtype="float32")) + lv1 = R.call_dps_packed("env.relu", (lv0,), R.Tensor((1, 128), dtype="float32")) + out = R.call_dps_packed("env.linear", (lv1, w1, b1), R.Tensor((1, 10), dtype="float32")) R.output(out) return out ``` @@ -493,7 +490,7 @@ MyModuleWithParams = relax.transform.BindParams("main", nd_params)(MyModuleMixtu ``` ```{.python .input n=40} -ex = relax.vm.build(MyModuleWithParams, target="llvm") +ex = relax.build(MyModuleWithParams, target="llvm") vm = relax.VirtualMachine(ex, tvm.cpu()) nd_res = vm["main"](data_nd) @@ -522,23 +519,22 @@ IPython.display.HTML(code2html(mod_linear.script())) ``` ```{.python .input n=43} -sch_tuned_linear = ms.tune_tir( +database = ms.tune_tir( mod=mod_linear, target="llvm --num-cores=1", - config=ms.TuneConfig( - max_trials_global=64, - num_trials_per_iter=64, - ), + max_trials_global=64, + num_trials_per_iter=64, work_dir="./tune_tmp", task_name="main", ) +sch = ms.tir_integration.compile_tir(database, mod_linear, "llvm --num-cores=1") ``` 现在我们需要在调优后用新函数替换原来的 `linear0`。我们可以通过首先获得一个 `global_var`(一个指向 IRModule 中函数的 `pointer` 引用),然后调用 `update_func` 来用新的函数替换原本的函数。 ```{.python .input n=44} MyModuleWithParams2 = relax.transform.BindParams("main", nd_params)(MyModuleMixture) -new_func = sch_tuned_linear.mod["main"].with_attr("global_symbol", "linear0") +new_func = sch.mod["main"].with_attr("global_symbol", "linear0") gv = MyModuleWithParams2.get_global_var("linear0") MyModuleWithParams2.update_func(gv, new_func) IPython.display.HTML(code2html(MyModuleWithParams2.script())) @@ -547,7 +543,7 @@ IPython.display.HTML(code2html(MyModuleWithParams2.script())) 我们可以发现上面代码中的 `linear0` 已经被替换了。 ```{.python .input n=45} -ex = relax.vm.build(MyModuleWithParams2, target="llvm") +ex = relax.build(MyModuleWithParams2, target="llvm") vm = relax.VirtualMachine(ex, tvm.cpu()) nd_res = vm["main"](data_nd) diff --git a/chapter_end_to_end/index.md b/chapter_end_to_end/index.md index 1b7bc84..4fdd246 100644 --- a/chapter_end_to_end/index.md +++ b/chapter_end_to_end/index.md @@ -21,8 +21,6 @@ from tvm.ir.module import IRModule from tvm.script import tir as T, relax as R import numpy as np from tvm import relax -# This is needed for deferring annotation parsing in TVMScript -from __future__ import annotations import IPython ``` @@ -175,63 +173,48 @@ print("Low-level Numpy MLP Prediction:", class_names[pred_kind[0]]) @tvm.script.ir_module class MyModule: @T.prim_func - def relu0(X: T.Buffer[(1, 128), "float32"], - Y: T.Buffer[(1, 128), "float32"]): - # function attr dict - T.func_attr({"global_symbol": "relu0", "tir.noalias": True}) - for i, j in T.grid(1, 128): + def relu0(x: T.handle, y: T.handle): + n = T.int64() + X = T.match_buffer(x, (1, n), "float32") + Y = T.match_buffer(y, (1, n), "float32") + for i, j in T.grid(1, n): with T.block("Y"): vi, vj = T.axis.remap("SS", [i, j]) Y[vi, vj] = T.max(X[vi, vj], T.float32(0)) @T.prim_func - def linear0(X: T.Buffer[(1, 784), "float32"], - W: T.Buffer[(128, 784), "float32"], - B: T.Buffer[(128,), "float32"], - Z: T.Buffer[(1, 128), "float32"]): - T.func_attr({"global_symbol": "linear0", "tir.noalias": True}) - Y = T.alloc_buffer((1, 128), "float32") - for i, j, k in T.grid(1, 128, 784): + def linear0(x: T.handle, + w: T.handle, + b: T.handle, + z: T.handle): + m, n, k = T.int64(), T.int64(), T.int64() + X = T.match_buffer(x, (1, m), "float32") + W = T.match_buffer(w, (n, m), "float32") + B = T.match_buffer(b, (n, ), "float32") + Z = T.match_buffer(z, (1, n), "float32") + Y = T.alloc_buffer((1, n), "float32") + for i, j, k in T.grid(1, n, m): with T.block("Y"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): Y[vi, vj] = T.float32(0) Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk] - - for i, j in T.grid(1, 128): - with T.block("Z"): - vi, vj = T.axis.remap("SS", [i, j]) - Z[vi, vj] = Y[vi, vj] + B[vj] - - @T.prim_func - def linear1(X: T.Buffer[(1, 128), "float32"], - W: T.Buffer[(10, 128), "float32"], - B: T.Buffer[(10,), "float32"], - Z: T.Buffer[(1, 10), "float32"]): - T.func_attr({"global_symbol": "linear1", "tir.noalias": True}) - Y = T.alloc_buffer((1, 10), "float32") - for i, j, k in T.grid(1, 10, 128): - with T.block("Y"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - Y[vi, vj] = T.float32(0) - Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk] - - for i, j in T.grid(1, 10): + for i, j in T.grid(1, n): with T.block("Z"): vi, vj = T.axis.remap("SS", [i, j]) Z[vi, vj] = Y[vi, vj] + B[vj] @R.function - def main(x: Tensor((1, 784), "float32"), - w0: Tensor((128, 784), "float32"), - b0: Tensor((128,), "float32"), - w1: Tensor((10, 128), "float32"), - b1: Tensor((10,), "float32")): + def main(x: R.Tensor((1, "m"), "float32"), + w0: R.Tensor(("n", "m"), "float32"), + b0: R.Tensor(("n", ), "float32"), + w1: R.Tensor(("k", "n"), "float32"), + b1: R.Tensor(("k", ), "float32")): + m, n, k = T.int64(), T.int64(), T.int64() with R.dataflow(): - lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype="float32") - lv1 = R.call_tir(relu0, (lv0,), (1, 128), dtype="float32") - out = R.call_tir(linear1, (lv1, w1, b1), (1, 10), dtype="float32") + lv0 = R.call_dps_packed("linear0", (x, w0, b0), R.Tensor((1, n), "float32")) + lv1 = R.call_dps_packed("relu0", (lv0, ), R.Tensor((1, n), "float32")) + out = R.call_dps_packed("linear0", (lv1, w1, b1), R.Tensor((1, k), "float32")) R.output(out) return out ``` @@ -254,28 +237,28 @@ class MyModule: 我们在之前的课程中已经看到了这种可视化。 图本身可以看作是一种抽象,在机器学习框架中通常称为**计算图 (computational graph)**。 -### `call_tir` +### `call_dps_packed` -您可能已经注意到的一件事是,计算图中的每个操作步骤都包含一个`R.call_tir`操作。 这是引入元张量函数的过程: +您可能已经注意到的一件事是,计算图中的每个操作步骤都包含一个`R.call_dps_packed`操作。 这是引入元张量函数的过程: ```python -lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype="float32") +lv0 = R.call_dps_packed("linear0", (x, w0, b0), R.Tensor((1, n), dtype="float32")) ``` -为了解释 `R.call_tir` 的含义,让我们回顾一下操作的等效底层 NumPy 实现,如下所示: +为了解释 `R.call_dps_packed` 的含义,让我们回顾一下操作的等效底层 NumPy 实现,如下所示: ```{.python .input n=8} -def lnumpy_call_tir(prim_func, inputs, shape, dtype): +def lnumpy_call_dps_packed(prim_func, inputs, shape, dtype): res = np.empty(shape, dtype=dtype) prim_func(*inputs, res) return res ``` -具体来说,`call_tir` 接受一个元函数 (`prim_func`) 的输入列表,并分配一个输出张量`res`,然后将输入和输出传递给`prim_func`。 执行 `prim_func` 后,结果会填充到 `res` 中,然后我们可以返回结果。 +具体来说,`call_dps_packed` 接受一个元函数 (`prim_func`) 的输入列表,并分配一个输出张量`res`,然后将输入和输出传递给`prim_func`。 执行 `prim_func` 后,结果会填充到 `res` 中,然后我们可以返回结果。 -请注意,`lnumpy_call_tir` 只是一个参考实现,以显示 `R.call_tir` 的含义。 在实际应用中,可以有不同的底层方法来优化执行。 例如,我们可能会选择提前分配所有输出内存,然后运行,我们将在以后的课程中介绍。 +请注意,`lnumpy_call_dps_packed` 只是一个参考实现,以显示 `R.call_dps_packed` 的含义。 在实际应用中,可以有不同的底层方法来优化执行。 例如,我们可能会选择提前分配所有输出内存,然后运行,我们将在以后的课程中介绍。 -一个很自然的问题:为什么我们需要 `call_tir`? 这是因为我们的元张量函数采用以下调用约定: +一个很自然的问题:为什么我们需要 `call_dps_packed`? 这是因为我们的元张量函数采用以下调用约定: ```python def low_level_prim_func(in0, in1, ..., out): @@ -311,22 +294,22 @@ def lnumpy_mlp(data, w0, b0, w1, b1): 当然,我们仍然可以通过引入输入边和输出边来概括图定义,这会使抽象的定义和变换复杂化。 -所以回到`call_tir`,这里的关键思想是我们想要隐藏可能的分配或对函数的显式写入。 用更正式的术语来说,我们希望函数是 **pure** 或 **side-effect free**。(译者注:“pure”和“side-effect”是 PL 中的术语,译者不确定中文的准确名称,故不进行翻译。欢迎社区中的专业人士参与完善) +所以回到`call_dps_packed`,这里的关键思想是我们想要隐藏可能的分配或对函数的显式写入。 用更正式的术语来说,我们希望函数是 **pure** 或 **side-effect free**。(译者注:“pure”和“side-effect”是 PL 中的术语,译者不确定中文的准确名称,故不进行翻译。欢迎社区中的专业人士参与完善) 如果一个函数只从其输入中读取并通过其输出返回结果,它不会改变程序的其他部分(例如递增全局计数器),那么它是**pure**或**side-effect free**的。 -**call_tir** 使我们能够隐藏调用低层元函数细节,并将它们应用到计算图中。 +**call_dps_packed** 使我们能够隐藏调用低层元函数细节,并将它们应用到计算图中。 -我们还可以在底层 NumPy 中看到 `call_tir` 的作用。 现在我们已经定义了 `lnumpy_call_tir`,我们可以将底层 NumPy 代码重写为: +我们还可以在底层 NumPy 中看到 `call_dps_packed` 的作用。 现在我们已经定义了 `lnumpy_call_dps_packed`,我们可以将底层 NumPy 代码重写为: ```{.python .input n=9} -def lnumpy_mlp_with_call_tir(data, w0, b0, w1, b1): - lv0 = lnumpy_call_tir(lnumpy_linear0, (data, w0, b0), (1, 128), dtype="float32") - lv1 = lnumpy_call_tir(lnumpy_relu0, (lv0, ), (1, 128), dtype="float32") - out = lnumpy_call_tir(lnumpy_linear1, (lv1, w1, b1), (1, 10), dtype="float32") +def lnumpy_mlp_with_call_dps_packed(data, w0, b0, w1, b1): + lv0 = lnumpy_call_dps_packed(lnumpy_linear0, (data, w0, b0), (1, 128), dtype="float32") + lv1 = lnumpy_call_dps_packed(lnumpy_relu0, (lv0, ), (1, 128), dtype="float32") + out = lnumpy_call_dps_packed(lnumpy_linear1, (lv1, w1, b1), (1, 10), dtype="float32") return out -result = lnumpy_mlp_with_call_tir( +result = lnumpy_mlp_with_call_dps_packed( img.reshape(1, 784), mlp_params["w0"], mlp_params["b0"], @@ -337,7 +320,7 @@ pred_kind = np.argmax(result, axis=1) print("Low-level Numpy with CallTIR Prediction:", class_names[pred_kind[0]]) ``` -实际上,最底层的实现会有显式的内存分配,所以`call_tir`主要是为了让我们在生成实际实现之前继续做一些高层的转换。 +实际上,最底层的实现会有显式的内存分配,所以`call_dps_packed`主要是为了让我们在生成实际实现之前继续做一些高层的转换。 ### Dataflow Block @@ -345,9 +328,9 @@ Relax 函数中的另一个重要元素是 `R.dataflow()` 范围标注: ```python with R.dataflow(): - lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype="float32") - lv1 = R.call_tir(relu0, (lv0,), (1, 128), dtype="float32") - out = R.call_tir(linear1, (lv1, w1, b1), (1, 10), dtype="float32") + lv0 = R.call_dps_packed("linear0", (x, w0, b0), R.Tensor((1, n), "float32")) + lv1 = R.call_dps_packed("relu0", (lv0, ), R.Tensor((1, n), "float32")) + out = R.call_dps_packed("linear0", (lv1, w1, b1), R.Tensor((1, k), "float32")) R.output(out) ``` @@ -357,21 +340,20 @@ with R.dataflow(): ```python @R.function -def main(x: Tensor((1, 784), "float32"), - w0: Tensor((128, 784), "float32"), - b0: Tensor((128,), "float32"), - w1: Tensor((10, 128), "float32"), - b1: Tensor((10,), "float32")): +def main(x: R.Tensor((1, "m"), "float32"), + w0: R.Tensor(("n", "m"), "float32"), + b0: R.Tensor(("n", ), "float32"), + w1: R.Tensor(("k", "n"), "float32"), + b1: R.Tensor(("k", ), "float32")): + m, n, k = T.int64(), T.int64(), T.int64() with R.dataflow(): - lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype="float32") - gv0 = R.call_tir(relu0, (lv0,), (1, 128), dtype="float32") + lv0 = R.call_dps_packed("linear0", (x, w0, b0), R.Tensor((1, n), "float32")) + gv0 = R.call_dps_packed("relu0", (lv0, ), R.Tensor((1, n), "float32")) R.output(gv0) - gv1 = R.alloc_tensor((1, 128), dtype="float32") - with R.dataflow(): - out = R.call_tir(linear1, (gv0, gv1, b0), (1, 128), dtype="float32") + out = R.call_dps_packed("linear0", (gv0, w1, b1), R.Tensor((1, k), "float32")) R.output(out) return out ``` @@ -383,7 +365,7 @@ def main(x: Tensor((1, 784), "float32"), 到目前为止,我们已经完成了一个 Relax 程序的示例,并涵盖了大部分元素,包括: - 计算图 -- `call_tir` +- `call_dps_packed` - Dataflow block 这些元素应该让我们能够开始端到端的模型执行和编译。 当我们在后面的章节中遇到新概念时,我们还将介绍它们。 @@ -396,10 +378,10 @@ def main(x: Tensor((1, 784), "float32"), IPython.display.Code(MyModule.script(), language="python") ``` -我们调用 `relax.vm.build` 来构建这个函数。 注意:Relax 仍在开发中,因此某些 API 可能会更改。 不过,我们的主要目标是熟悉端到端模型的整体 MLC 流程(构造、转换、构建)。 +我们调用 `relax.build` 来构建这个函数。 注意:Relax 仍在开发中,因此某些 API 可能会更改。 不过,我们的主要目标是熟悉端到端模型的整体 MLC 流程(构造、转换、构建)。 ```{.python .input n=11} -ex = relax.vm.build(MyModule, target="llvm") +ex = relax.build(MyModule, target="llvm") type(ex) ``` @@ -444,24 +426,25 @@ print("MyModule Prediction:", class_names[pred_kind[0]]) @tvm.script.ir_module class MyModuleWithExternCall: @R.function - def main(x: Tensor((1, 784), "float32"), - w0: Tensor((128, 784), "float32"), - b0: Tensor((128,), "float32"), - w1: Tensor((10, 128), "float32"), - b1: Tensor((10,), "float32")): + def main(x: R.Tensor((1, "m"), "float32"), + w0: R.Tensor(("n", "m"), "float32"), + b0: R.Tensor(("n", ), "float32"), + w1: R.Tensor(("k", "n"), "float32"), + b1: R.Tensor(("k", ), "float32")): # block 0 + m, n, k = T.int64(), T.int64(), T.int64() with R.dataflow(): - lv0 = R.call_tir("env.linear", (x, w0, b0), (1, 128), dtype="float32") - lv1 = R.call_tir("env.relu", (lv0,), (1, 128), dtype="float32") - out = R.call_tir("env.linear", (lv1, w1, b1), (1, 10), dtype="float32") + lv0 = R.call_dps_packed("env.linear", (x, w0, b0), R.Tensor((1, n), "float32")) + lv1 = R.call_dps_packed("env.relu", (lv0, ), R.Tensor((1, n), "float32")) + out = R.call_dps_packed("env.linear", (lv1, w1, b1), R.Tensor((1, k), "float32")) R.output(out) return out ``` -请注意,我们现在直接在 `call_tir` 中传入字符串: +请注意,我们现在直接在 `call_dps_packed` 中传入字符串: ```python -R.call_tir("env.linear", (x, w0, b0), (1, 128), dtype="float32") +R.call_dps_packed("env.linear", (x, w0, b0), R.Tensor((1, n), "float32")) ``` 这些字符串是我们期望在模型执行期间的运行时函数 (runtime function) 的名称。 @@ -502,7 +485,7 @@ def lnumpy_relu(x: tvm.nd.NDArray, 现在我们可以构建并运行`MyModuleWithExternCall`,我们可以验证模型得到了相同的结果。 ```{.python .input n=18} -ex = relax.vm.build(MyModuleWithExternCall, target="llvm") +ex = relax.build(MyModuleWithExternCall, target="llvm") vm = relax.VirtualMachine(ex, tvm.cpu()) nd_res = vm["main"](data_nd, @@ -524,34 +507,38 @@ print("MyModuleWithExternCall Prediction:", class_names[pred_kind[0]]) @tvm.script.ir_module class MyModuleMixture: @T.prim_func - def linear0(X: T.Buffer[(1, 784), "float32"], - W: T.Buffer[(128, 784), "float32"], - B: T.Buffer[(128,), "float32"], - Z: T.Buffer[(1, 128), "float32"]): - T.func_attr({"global_symbol": "linear0", "tir.noalias": True}) - Y = T.alloc_buffer((1, 128), "float32") - for i, j, k in T.grid(1, 128, 784): + def linear0(x: T.handle, + w: T.handle, + b: T.handle, + z: T.handle): + m, n, k = T.int64(), T.int64(), T.int64() + X = T.match_buffer(x, (1, m), "float32") + W = T.match_buffer(w, (n, m), "float32") + B = T.match_buffer(b, (n, ), "float32") + Z = T.match_buffer(z, (1, n), "float32") + Y = T.alloc_buffer((1, n), "float32") + for i, j, k in T.grid(1, n, m): with T.block("Y"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): Y[vi, vj] = T.float32(0) Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk] - - for i, j in T.grid(1, 128): + for i, j in T.grid(1, n): with T.block("Z"): vi, vj = T.axis.remap("SS", [i, j]) - Z[vi, vj] = Y[vi, vj] + B[vj] + Z[vi, vj] = Y[vi, vj] + B[vj] @R.function - def main(x: Tensor((1, 784), "float32"), - w0: Tensor((128, 784), "float32"), - b0: Tensor((128,), "float32"), - w1: Tensor((10, 128), "float32"), - b1: Tensor((10,), "float32")): + def main(x: R.Tensor((1, "m"), "float32"), + w0: R.Tensor(("n", "m"), "float32"), + b0: R.Tensor(("n", ), "float32"), + w1: R.Tensor(("k", "n"), "float32"), + b1: R.Tensor(("k", ), "float32")): + m, n, k = T.int64(), T.int64(), T.int64() with R.dataflow(): - lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype="float32") - lv1 = R.call_tir("env.relu", (lv0,), (1, 128), dtype="float32") - out = R.call_tir("env.linear", (lv1, w1, b1), (1, 10), dtype="float32") + lv0 = R.call_dps_packed("linear0", (x, w0, b0), R.Tensor((1, n), "float32")) + lv1 = R.call_dps_packed("env.relu", (lv0, ), R.Tensor((1, n), "float32")) + out = R.call_dps_packed("env.linear", (lv1, w1, b1), R.Tensor((1, k), "float32")) R.output(out) return out ``` @@ -559,7 +546,7 @@ class MyModuleMixture: 上面的代码块显示了一个示例,其中 `linear0` 仍然在 `TensorIR` 中实现,而其余函数被重定向到库函数。 我们可以构建并运行以验证结果。 ```{.python .input n=20} -ex = relax.vm.build(MyModuleMixture, target="llvm") +ex = relax.build(MyModuleMixture, target="llvm") vm = relax.VirtualMachine(ex, tvm.cpu()) nd_res = vm["main"](data_nd, @@ -584,7 +571,7 @@ IPython.display.Code(MyModuleWithParams.script(), language="python") 在上面的脚本中,`meta[relay.Constant][0]` (译者注:目前 `Relax` 的常量表达依然继承自 `Relay` ,未来该 API 可能会更改) 对应于一个存储常量的隐式字典(它没有显示为脚本的一部分,但仍然是 IRModule 的一部分)。 如果我们构建转换后的 IRModule,我们现在可以通过传入输入数据来调用该函数。 ```{.python .input n=22} -ex = relax.vm.build(MyModuleWithParams, target="llvm") +ex = relax.build(MyModuleWithParams, target="llvm") vm = relax.VirtualMachine(ex, tvm.cpu()) nd_res = vm["main"](data_nd) @@ -614,6 +601,6 @@ print("MyModuleWithParams Prediction:", class_names[pred_kind[0]]) - 计算图抽象有助于将元张量函数拼接在一起以进行端到端执行。 - Relax 抽象的关键要素包括 - - call_tir 构造,将目标传递规范的元函数嵌入到计算图中 + - call_dps_packed 构造,将目标传递规范的元函数嵌入到计算图中 - Dataflow block - 计算图允许调用环境库函数和 `TensorIR` 函数。 diff --git a/chapter_gpu_acceleration/part1.md b/chapter_gpu_acceleration/part1.md index d6ba5cf..1b2b712 100644 --- a/chapter_gpu_acceleration/part1.md +++ b/chapter_gpu_acceleration/part1.md @@ -22,9 +22,6 @@ from tvm.ir.module import IRModule from tvm.script import tir as T, relax as R from tvm import relax import numpy as np - -# This is needed for deferring annotation parsing in TVMScript -from __future__ import annotations ``` ### GPU 体系结构 @@ -43,9 +40,9 @@ from __future__ import annotations @tvm.script.ir_module class MyModuleVecAdd: @T.prim_func - def main(A: T.Buffer[(1024,), "float32"], - B: T.Buffer[(1024,), "float32"], - C: T.Buffer[(1024,), "float32"]) -> None: + def main(A: T.Buffer((1024,), "float32"), + B: T.Buffer((1024,), "float32"), + C: T.Buffer((1024,), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i in T.grid(1024): with T.block("C"): @@ -176,9 +173,9 @@ print(rt_mod.imported_modules[0].get_source()) @tvm.script.ir_module class MyModuleMatmul: @T.prim_func - def main(A: T.Buffer[(1024, 1024), "float32"], - B: T.Buffer[(1024, 1024), "float32"], - C: T.Buffer[(1024, 1024), "float32"]) -> None: + def main(A: T.Buffer((1024, 1024), "float32"), + B: T.Buffer((1024, 1024), "float32"), + C: T.Buffer((1024, 1024), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i, j, k in T.grid(1024, 1024, 1024): with T.block("C"): @@ -315,21 +312,20 @@ print("GEMM-Blocking: %f GFLOPS" % (num_flop / evaluator(A_nd, B_nd, C_nd).mean ```python from tvm import meta_schedule as ms -sch_tuned = ms.tune_tir( +database = ms.tune_tir( mod=MyModuleMatmul, target="nvidia/tesla-p100", - config=ms.TuneConfig( - max_trials_global=64, - num_trials_per_iter=64, - ), + max_trials_global=64, + num_trials_per_iter=64, work_dir="./tune_tmp", task_name="main" ) -sch_tuned.mod.show() +sch = ms.tir_integration.compile_tir(database, MyModuleMatmul, "nvidia/tesla-p100") +sch.mod.show() ``` ```python -rt_mod = tvm.build(sch_tuned.mod, target="nvidia/tesla-p100") +rt_mod = tvm.build(sch.mod, target="nvidia/tesla-p100") dev = tvm.cuda(0) evaluator = rt_mod.time_evaluator("main", dev, number=10) diff --git a/chapter_gpu_acceleration/part2.md b/chapter_gpu_acceleration/part2.md index 640d28e..ad0d2e8 100644 --- a/chapter_gpu_acceleration/part2.md +++ b/chapter_gpu_acceleration/part2.md @@ -12,9 +12,6 @@ from tvm.ir.module import IRModule from tvm.script import tir as T, relax as R from tvm import relax import numpy as np - -# This is needed for deferring annotation parsing in TVMScript -from __future__ import annotations ``` ### 硬件专业化趋势 @@ -89,9 +86,9 @@ np.testing.assert_allclose(c_np, c_tmm, rtol=1e-5) class MatmulBlockModule: @T.prim_func def main( - A: T.Buffer[(1024, 1024), "float32"], - B: T.Buffer[(1024, 1024), "float32"], - C: T.Buffer[(1024, 1024), "float32"], + A: T.Buffer((1024, 1024), "float32"), + B: T.Buffer((1024, 1024), "float32"), + C: T.Buffer((1024, 1024), "float32"), ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i0, j0, k0 in T.grid(64, 64, 64): @@ -163,9 +160,9 @@ sch.mod.show() class MatmulModule: @T.prim_func def main( - A: T.Buffer[(1024, 1024), "float32"], - B: T.Buffer[(1024, 1024), "float32"], - C: T.Buffer[(1024, 1024), "float32"], + A: T.Buffer((1024, 1024), "float32"), + B: T.Buffer((1024, 1024), "float32"), + C: T.Buffer((1024, 1024), "float32"), ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i, j, k in T.grid(1024, 1024, 1024): @@ -238,9 +235,9 @@ def tmm16_desc(a: T.handle, b: T.handle, c: T.handle) -> None: @T.prim_func def tmm16_impl(a: T.handle, b: T.handle, c: T.handle) -> None: - sa = T.var("int32") - sb = T.var("int32") - sc = T.var("int32") + sa = T.int32() + sb = T.int32() + sc = T.int32() A = T.match_buffer(a, (16, 16), "float32", offset_factor=16, strides=[sa, 1], scope="global.A_reg") B = T.match_buffer(b, (16, 16), "float32", offset_factor=16, strides=[sb, 1], scope="global.B_reg") C = T.match_buffer(c, (16, 16), "float32", offset_factor=16, strides=[sc, 1], scope="global.accumulator") @@ -312,7 +309,9 @@ sch.annotate(i, "pragma_import_llvm", tmm_kernel()) 然后我们可以去执行下面的代码块,它将张量化的计算重定向到自定义的 `tmm_kernel`。 -```{.python .input} +``` + + a_nd = tvm.nd.array(a_np) b_nd = tvm.nd.array(b_np) diff --git a/chapter_graph_optimization/index.md b/chapter_graph_optimization/index.md index 0332ada..673ebdf 100644 --- a/chapter_graph_optimization/index.md +++ b/chapter_graph_optimization/index.md @@ -11,9 +11,6 @@ 首先,让我们导入必要的依赖项。 ```{.python .input} -# This is needed for deferring annotation parsing in TVMScript -from __future__ import annotations - import tvm from tvm.ir.module import IRModule from tvm.script import tir as T, relax as R @@ -29,15 +26,15 @@ import numpy as np @tvm.script.ir_module class MyModule: @R.function - def main(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): - with relax.dataflow(): - lv0 = relax.multiply(x, y) - gv0 = relax.add(lv0, y) - relax.output(gv0) + def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): + with R.dataflow(): + lv0 = relax.op.multiply(x, y) + gv0 = relax.op.add(lv0, y) + R.output(gv0) return gv0 ``` -`MyModule` 包含一个带有两个图层 op 的 relax 函数,其中包含 `relax.multiply` 和`relax.add`。我们的目标是找到这两个运算符并将它们替换为一个 `relax.ewise_fma` 运算符的调用。 +`MyModule` 包含一个带有两个图层 op 的 relax 函数,其中包含 `relax.op.multiply` 和`relax.op.add`。我们的目标是找到这两个运算符并将它们替换为一个 `relax.op.ewise_fma` 运算符的调用。 在我们研究如何准确地做到这一点之前,让我们首先检查构成 `MyModule` 的数据结构。 每个 `IRModule` 都包含一组函数,函数体由一组称为抽象语法树(AST)的数据结构组成。 @@ -45,7 +42,7 @@ class MyModule: relax_func = MyModule["main"] ``` -每个函数都由一个 `relax.Function` 节点表示。 +每个函数都由一个 `relax.expr.Function` 节点表示。 ```{.python .input} type(relax_func) @@ -77,8 +74,8 @@ dataflow_block = func_body.blocks[0] 在我们的特定情况下,我们有一个数据流块,其中包含两个 Binding 。绑定对应于以下代码: ```python -lv0 = relax.multiply(x, y) -gv0 = relax.add(lv0, y) +lv0 = relax.op.multiply(x, y) +gv0 = relax.op.add(lv0, y) ``` ```{.python .input} @@ -153,23 +150,22 @@ import pickle as pkl mlp_params = pkl.load(open("fasionmnist_mlp_params.pkl", "rb")) ``` -以下代码重新构建了我们在过去章节中使用的 FashionMNIST MLP 模型。 为了简化过程,我们直接使用高级运算符构建模型,例如 `relax.op.add` 和 `relax.op.dense`。 +以下代码重新构建了我们在过去章节中使用的 FashionMNIST MLP 模型。 为了简化过程,我们直接使用高级运算符构建模型,例如 `relax.op.add` 和 `relax.op.matmul`。 ```{.python .input} def create_model(): bb = relax.BlockBuilder() - x = relax.Var("x", (1, 784), relax.DynTensorType(2, "float32")) + x = relax.Var("x", relax.TensorStructInfo((1, 784), "float32")) w0 = relax.const(mlp_params["w0"], "float32") b0 = relax.const(mlp_params["b0"], "float32") w1 = relax.const(mlp_params["w1"], "float32") b1 = relax.const(mlp_params["b1"], "float32") - with bb.function("main", [x]): with bb.dataflow(): - lv0 = bb.emit(relax.op.dense(x, w0)) + lv0 = bb.emit(relax.op.matmul(x, relax.op.permute_dims(w0))) lv1 = bb.emit(relax.op.add(lv0, b0)) - lv2 = bb.emit(relax.op.relu(lv1)) - lv3 = bb.emit(relax.op.dense(lv2, w1)) + lv2 = bb.emit(relax.op.nn.relu(lv1)) + lv3 = bb.emit(relax.op.matmul(lv2, relax.op.permute_dims(w1))) lv4 = bb.emit(relax.op.add(lv3, b1)) gv = bb.emit_output(lv4) bb.emit_func_output(gv) @@ -180,21 +176,21 @@ MLPModel = create_model() MLPModel.show() ``` -我们的目标是“融合” `dense` 和 `add` 算子到一起。 以下代码通过以下步骤实现: +我们的目标是“融合” `matmul` 和 `add` 算子到一起。 以下代码通过以下步骤实现: -- 识别 `dense` 和 `add` 算子。 -- 生成另一个调用 `dense` 和 `add` 算子的子函数。 -- 将 `dense` 和 `add` 替换为融合后的子函数。 +- 识别 `matmul` 和 `add` 算子。 +- 生成另一个调用 `matmul` 和 `add` 算子的子函数。 +- 将 `matmul` 和 `add` 替换为融合后的子函数。 ```{.python .input} @relax.expr_functor.mutator -class DenseAddFusor(relax.PyExprMutator): +class MatmulAddFusor(relax.PyExprMutator): def __init__(self, mod: IRModule) -> None: super().__init__() self.mod_ = mod # cache pre-defined ops self.add_op = tvm.ir.Op.get("relax.add") - self.dense_op = tvm.ir.Op.get("relax.nn.dense") + self.matmul_op = tvm.ir.Op.get("relax.matmul") self.counter = 0 def transform(self) -> IRModule: @@ -202,7 +198,7 @@ class DenseAddFusor(relax.PyExprMutator): if not isinstance(func, relax.Function): continue # avoid already fused primitive functions - if "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0: + if func.attrs is not None and "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0: continue updated_func = self.visit_expr(func) updated_func = relax.analysis.remove_all_unused(updated_func) @@ -218,7 +214,7 @@ class DenseAddFusor(relax.PyExprMutator): return False return node.op == op - # pattern match dense => add + # pattern match matmul => add if not match_call(call, self.add_op): return call @@ -226,7 +222,7 @@ class DenseAddFusor(relax.PyExprMutator): if value is None: return call - if not match_call(value, self.dense_op): + if not match_call(value, self.matmul_op): return call x = value.args[0] @@ -234,17 +230,17 @@ class DenseAddFusor(relax.PyExprMutator): b = call.args[1] # construct a new fused primitive function - param_x = relax.Var("x", x.shape_, x._checked_type_) - param_w = relax.Var("w", w.shape_, w._checked_type_) - param_b = relax.Var("b", b.shape_, b._checked_type_) + param_x = relax.Var("x" ,relax.TensorStructInfo(x.struct_info.shape, x.struct_info.dtype)) + param_w = relax.Var("w" ,relax.TensorStructInfo(w.struct_info.shape, w.struct_info.dtype)) + param_b = relax.Var("b" ,relax.TensorStructInfo(b.struct_info.shape, b.struct_info.dtype)) bb = relax.BlockBuilder() - fn_name = "fused_dense_add%d" % (self.counter) + fn_name = "fused_matmul_add%d" % (self.counter) self.counter += 1 with bb.function(fn_name, [param_x, param_w, param_b]): with bb.dataflow(): - lv0 = bb.emit(relax.op.nn.dense(param_x, param_w)) + lv0 = bb.emit(relax.op.matmul(param_x, param_w)) gv = bb.emit_output(relax.op.add(lv0, param_b)) bb.emit_func_output(gv) @@ -255,11 +251,11 @@ class DenseAddFusor(relax.PyExprMutator): # construct call into the fused function return relax.Call(global_var, [x, w, b], None, None) -@tvm.ir.transform.module_pass(opt_level=2, name="DeseAddFuse") +@tvm.ir.transform.module_pass(opt_level=2, name="MatmulAddFuse") class FuseDenseAddPass: """The wrapper for the LowerTensorIR pass.""" def transform_module(self, mod, ctx): - return DenseAddFusor(mod).transform() + return MatmulAddFusor(mod).transform() MLPFused = FuseDenseAddPass()(MLPModel) @@ -268,7 +264,7 @@ MLPFused.show() ### 为什么要创建子函数 -在上面的例子中,我们创建了两个前缀为 `fuse_dense_add` 的子函数。 这些子函数包含有融合后算子的计算信息。 这种重写的替代方法是简单地为融合运算符创建一个单独的原始操作(如`ewise_fma`)。 但是,当我们尝试融合更多运算符时,可能存在指数级数量的组合。 将融合操作分组在一起的子函数为后续的 pass 保留了原始信息,进而便于分析,无需为每个融合 pattern 引入专用的高级运算符。 +在上面的例子中,我们创建了两个前缀为 `fuse_matmul_add` 的子函数。 这些子函数包含有融合后算子的计算信息。 这种重写的替代方法是简单地为融合运算符创建一个单独的原始操作(如`ewise_fma`)。 但是,当我们尝试融合更多运算符时,可能存在指数级数量的组合。 将融合操作分组在一起的子函数为后续的 pass 保留了原始信息,进而便于分析,无需为每个融合 pattern 引入专用的高级运算符。 ## 映射到 TensorIR Calls @@ -304,9 +300,9 @@ class LowerToTensorIR(relax.PyExprMutator): return self.builder_.get() -def map_dense(bb, call): +def map_matmul(bb, call): x, w = call.args - return bb.call_te(topi.nn.dense, x, w) + return bb.call_te(topi.nn.matmul, x, w) def map_add(bb, call): a, b = call.args @@ -315,11 +311,14 @@ def map_add(bb, call): def map_relu(bb, call): return bb.call_te(topi.nn.relu, call.args[0]) +def map_transpose(bb, call): + return bb.call_te(topi.transpose, call.args[0], ) op_map = { - "relax.nn.dense": map_dense, + "relax.matmul": map_matmul, "relax.add": map_add, - "relax.nn.relu": map_relu + "relax.nn.relu": map_relu, + "relax.permute_dims": map_transpose } @tvm.ir.transform.module_pass(opt_level=0, name="LowerToTensorIR") @@ -333,7 +332,7 @@ MLPModelTIR = LowerToTensorIRPass()(MLPFused) MLPModelTIR.show() ``` -请注意,在上面的代码中。 `fused_dense_add0` 和 `fused_dense_add1` 仍然是上层 relax 函数,它们调用相应的 TensorIR `dense` 和 `add` 函数。 我们可以将它们变成一个单一的 TensorIR 函数,然后可以用于后续优化和代码生成阶段。 +请注意,在上面的代码中。 `fused_matmul_add0` 和 `fused_matmul_add1` 仍然是上层 relax 函数,它们调用相应的 TensorIR `matmul` 和 `add` 函数。 我们可以将它们变成一个单一的 TensorIR 函数,然后可以用于后续优化和代码生成阶段。 ```{.python .input} MLPModelFinal = relax.transform.FuseTIR()(MLPModelTIR) @@ -376,7 +375,7 @@ print("Class:", class_names[label[0]]) ``` ```{.python .output} -ex = relax.vm.build(MLPModelFinal, target="llvm") +ex = relax.build(MLPModelFinal, target="llvm") vm = relax.VirtualMachine(ex, tvm.cpu()) data_nd = tvm.nd.array(img.reshape(1, 784)) diff --git a/chapter_integration/index.md b/chapter_integration/index.md index 41435e7..abfb898 100644 --- a/chapter_integration/index.md +++ b/chapter_integration/index.md @@ -16,9 +16,6 @@ from tvm.ir.module import IRModule from tvm.script import tir as T, relax as R from tvm import relax import numpy as np - -# This is needed for deferring annotation parsing in TVMScript -from __future__ import annotations ``` ```{.python .input} @@ -133,8 +130,8 @@ te.create_prim_func([A, B, C, D]).show() 让我们首先创建一个 block builder,它可以帮助我们逐步构建一个 `relax.Function`。 ```{.python .input} -A = relax.Var("A", (128, 128), relax.DynTensorType(2, "float32")) -B = relax.Var("B", (128, 128), relax.DynTensorType(2, "float32")) +A = relax.Var("A", relax.TensorStructInfo((128, 128), "float32")) +B = relax.Var("B", relax.TensorStructInfo((128, 128), "float32")) ``` 我们通过创建 block builder 和一系列元张量函数来构造 Relax 函数。 @@ -180,7 +177,7 @@ isinstance(C, relax.Var) Relax 函数中的每一行都是由 `emit_te` 调用生成的。 例如, ```python -lv = R.call_tir(te_matmul, (A, B), (128, 128), dtype="float32") +lv = R.call_dps_packed(te_matmul, (A, B), (128, 128), dtype="float32") ``` 是由如下代码所生成。 @@ -194,7 +191,7 @@ C = bb.emit_te(te_matmul, A, B). - 为 A 和 B 创建一个输入 `te.placeholder`。 - 通过 `te_matmul` 函数运行它们。 - 调用 `te.create_prim_func` 来创建一个 TensorIR 函数。 -- 通过 `call_tir` 生成对函数的调用。 +- 通过 `call_dps_packed` 生成对函数的调用。 我们可以发现,上面 BlockBuilder 构造后的结果是一个有两个中间值的计算图,一个节点对应 `te_matmul` 操作,另一个节点对应 `te_relu`。 @@ -274,9 +271,8 @@ fx_module.graph.print_tabular() ```{.python .input} def map_param(param: nn.Parameter): - ndim = len(param.data.shape) return relax.const( - param.data.cpu().numpy(), relax.DynTensorType(ndim, "float32") + param.data.cpu().numpy(), relax.TensorStructInfo(param.data.shape, "float32") ) def fetch_attr(fx_mod, target: str): @@ -306,7 +302,7 @@ def from_fx(fx_mod, input_shapes, call_function_map, call_module_map): shape = input_shapes[input_index] input_index += 1 input_var = relax.Var( - node.target, shape, relax.DynTensorType(len(shape), "float32") + node.target, relax.TensorStructInfo(shape, "float32") ) fn_inputs.append(input_var) node_map[node] = input_var @@ -460,7 +456,7 @@ MLPModule.show() ``` ```{.python .input} -ex = relax.vm.build(MLPModule, target="llvm") +ex = relax.build(MLPModule, target="llvm") vm = relax.VirtualMachine(ex, tvm.cpu()) data_nd = tvm.nd.array(img.reshape(1, 784)) @@ -477,15 +473,13 @@ print("MLPModule Prediction:", class_names[pred_kind[0]]) ```{.python .input} def map_nn_relu_op(bb, node_map, node, nn_mod): A = node_map[node.args[0]] - return bb.emit(relax.op.relu(A)) + return bb.emit(relax.op.nn.relu(A)) def map_nn_linear_op(bb, node_map, node, nn_mod): x = node_map[node.args[0]] w = map_param(nn_mod.weight) - if nn_mod.bias is not None: - b = map_param(nn_mod.bias) - y = bb.emit(relax.op.dense(x, w)) - return bb.emit(relax.op.add(y, b)) + b = map_param(nn_mod.bias) + return bb.emit(relax.op.linear(x, w, b)) MLPModuleHighLevel = from_fx( fx.symbolic_trace(mlp_model), @@ -503,7 +497,7 @@ MLPModuleHighLevel.show() 上面展示了我们使用哪些内置的算子将模型导入为 IRModule 后的结果。这些内置算子是 **比 TensorIR 函数更高级别的抽象**。我们可以有不同的机会将这些原始算子进一步转换为库函数或 TensorIR 函数。 -在大多数情况下,在有高级算子支持的情况下,转换为高级内置函数会很有帮助。但是,有很多情况下我们找不到对应的高级内置算子或者想直接指定 TensorIR 函数。 在这些情况下,我们可以自定义翻译逻辑或变换从而生成 `call_tir` 或调用库函数。 通常,我们可以结合高级操作、TensorIR 和库抽象来获得最佳结果。 我们将在后续章节中讨论权衡取舍。 +在大多数情况下,在有高级算子支持的情况下,转换为高级内置函数会很有帮助。但是,有很多情况下我们找不到对应的高级内置算子或者想直接指定 TensorIR 函数。 在这些情况下,我们可以自定义翻译逻辑或变换从而生成 `call_dps_packed` 或调用库函数。 通常,我们可以结合高级操作、TensorIR 和库抽象来获得最佳结果。 我们将在后续章节中讨论权衡取舍。 ## 讨论 diff --git a/chapter_tensor_program/case_study.md b/chapter_tensor_program/case_study.md index 0c5edab..9c73668 100644 --- a/chapter_tensor_program/case_study.md +++ b/chapter_tensor_program/case_study.md @@ -98,9 +98,9 @@ np.testing.assert_allclose(c_mm_relu, c_np, rtol=1e-5) @tvm.script.ir_module class MyModule: @T.prim_func - def mm_relu(A: T.Buffer[(128, 128), "float32"], - B: T.Buffer[(128, 128), "float32"], - C: T.Buffer[(128, 128), "float32"]): + def mm_relu(A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32")): T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True}) Y = T.alloc_buffer((128, 128), dtype="float32") for i, j, k in T.grid(128, 128, 128): @@ -130,9 +130,9 @@ class MyModule: ```python # TensorIR -def mm_relu(A: T.Buffer[(128, 128), "float32"], - B: T.Buffer[(128, 128), "float32"], - C: T.Buffer[(128, 128), "float32"]): +def mm_relu(A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32")): ... # numpy def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray): @@ -258,9 +258,9 @@ vk = T.axis.reduce(range_of_k, k) @tvm.script.ir_module class MyModuleWithAxisRemapSugar: @T.prim_func - def mm_relu(A: T.Buffer[(128, 128), "float32"], - B: T.Buffer[(128, 128), "float32"], - C: T.Buffer[(128, 128), "float32"]): + def mm_relu(A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32")): T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True}) Y = T.alloc_buffer((128, 128), dtype="float32") for i, j, k in T.grid(128, 128, 128): @@ -305,9 +305,9 @@ type(MyModule["mm_relu"]) @tvm.script.ir_module class MyModuleWithTwoFunctions: @T.prim_func - def mm(A: T.Buffer[(128, 128), "float32"], - B: T.Buffer[(128, 128), "float32"], - Y: T.Buffer[(128, 128), "float32"]): + def mm(A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + Y: T.Buffer((128, 128), "float32")): T.func_attr({"global_symbol": "mm", "tir.noalias": True}) for i, j, k in T.grid(128, 128, 128): with T.block("Y"): @@ -317,8 +317,8 @@ class MyModuleWithTwoFunctions: Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] @T.prim_func - def relu(A: T.Buffer[(128, 128), "float32"], - B: T.Buffer[(128, 128), "float32"]): + def relu(A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32")): T.func_attr({"global_symbol": "relu", "tir.noalias": True}) for i, j in T.grid(128, 128): with T.block("B"): diff --git a/chapter_tensor_program/tensorir_exercises.md b/chapter_tensor_program/tensorir_exercises.md index 991a9b5..c22983e 100644 --- a/chapter_tensor_program/tensorir_exercises.md +++ b/chapter_tensor_program/tensorir_exercises.md @@ -50,9 +50,9 @@ c_lnumpy @tvm.script.ir_module class MyAdd: @T.prim_func - def add(A: T.Buffer[(4, 4), "int64"], - B: T.Buffer[(4, 4), "int64"], - C: T.Buffer[(4, 4), "int64"]): + def add(A: T.Buffer((4, 4), "int64"), + B: T.Buffer((4, 4), "int64"), + C: T.Buffer((4, 4), "int64")): T.func_attr({"global_symbol": "add"}) for i, j in T.grid(4, 4): with T.block("C"): @@ -164,9 +164,9 @@ np.testing.assert_allclose(conv_tvm.numpy(), conv_torch, rtol=1e-5) @tvm.script.ir_module class MyAdd: @T.prim_func - def add(A: T.Buffer[(4, 4), "int64"], - B: T.Buffer[(4, 4), "int64"], - C: T.Buffer[(4, 4), "int64"]): + def add(A: T.Buffer((4, 4), "int64"), + B: T.Buffer((4, 4), "int64"), + C: T.Buffer((4, 4), "int64")): T.func_attr({"global_symbol": "add"}) for i, j in T.grid(4, 4): with T.block("C"): @@ -228,7 +228,7 @@ IPython.display.Code(sch.mod.script(), language="python") @tvm.script.ir_module class TargetModule: @T.prim_func - def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None: + def bmm_relu(A: T.Buffer((16, 128, 128), "float32"), B: T.Buffer((16, 128, 128), "float32"), C: T.Buffer((16, 128, 128), "float32")) -> None: T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True}) Y = T.alloc_buffer([16, 128, 128], dtype="float32") for i0 in T.parallel(16):