diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index f5a4a1d4ce..6ed319ed85 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -331,8 +331,8 @@ def contains_impure_call(expr: Expr, own_name: Optional[Union[Var, GlobalVar]] = """ Check if the given expression (likely a function body) contains any impure calls. - Parameter - --------- + Parameters + ---------- expr : Expr The expression to be examined. If expr is a function, we check the body. diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 0f4c5f6483..8eafa80802 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -661,6 +661,7 @@ def repeat(x: Tensor, repeats: int, axis: Optional[int] = None, name="repeat") - Examples -------- .. code-block:: python + np_x = numpy.array([[1, 2], [3, 4]]) x = Tensor.from_const(np_x) lv1 = repeat(x, repeats=2) # lv1 == [1, 1, 2, 2, 3, 3, 4, 4] diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 627722ff88..603148d0cf 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -664,7 +664,7 @@ def invoke_pure_closure( Invoke a closure and indicate to the compiler that it is pure. Note: This should be used for cases when the user knows that calling the closure - with these arguments will _in reality_ not cause any side effects. + with these arguments will **in reality** not cause any side effects. If it is used for a call that _does_ result in side effects, then the compiler may end up removing, reordering, or repeating that call, with no guarantees made about any side effects from the callee. diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index a6f1c580a4..3dfa371e42 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -350,6 +350,7 @@ def repeat(data: Expr, repeats: int, axis: Optional[int] = None) -> Expr: Examples -------- .. code-block:: python + x = R.const([[1, 2], [3, 4]]) lv1 = R.repeat(x, repeats=2) # lv1 == [1, 1, 2, 2, 3, 3, 4, 4] lv2 = R.repeat(x, repeats=2, axis=1) # lv2 == [[1., 1., 2., 2.], @@ -435,13 +436,20 @@ def scatter_elements( specified by `updates` at specific index positions specified by `indices`. For example, in 2D tensor, the update corresponding to the [i][j] entry is performed as below: + + .. code-block:: + output[indices[i][j]][j] = updates[i][j] if axis = 0 output[i][indices[i][j]] = updates[i][j] if axis = 1 When the `reduction` is set to some reduction function `f`, the update corresponding to [i][j] entry is performed as below: + + .. code-block:: + output[indices[i][j]][j] += f(output[indices[i][j]][j], updates[i][j]) if axis = 0 output[i][indices[i][j]] += f(output[i][indices[i][j]], updates[i][j]) if axis = 1 + Where `f` is update, add, mul, mean, max, min. Parameters @@ -470,6 +478,7 @@ def scatter_elements( Examples -------- .. code-block:: python + # inputs data = [ [0.0, 0.0, 0.0], diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 89334591bd..3eddd5f591 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -871,9 +871,11 @@ def batch_norm( .. note:: This operator has two modes: + - Training mode. - Use the mean and var computed from THIS batch to normalize. - Update and then return the running mean and running var. + - Inference mode. - Use the running_mean and running_var parameters to normalize. - Do not update the running mean and running var. Just return the original value. @@ -1005,7 +1007,7 @@ def group_norm( into groups along the channel axis. Then apply layer normalization to each group. Parameters - ---------- + ---------- data : relax.Expr Input to which group_norm will be applied. @@ -1231,22 +1233,37 @@ def attention( while for 'BottomRight', the mask matrix is as `np.tril(*, k=abs(seq_len - seq_len_kv))` For example, with seq_len = 4, seq_len_kv = 2, mask for 'TopLeft': - [[1, 0], - [1, 1], - [1, 1], - [1, 1]] + + .. code:: python + + [[1, 0], + [1, 1], + [1, 1], + [1, 1]] + mask for 'BottomRight': - [[1, 1], - [1, 1], - [1, 1], - [1, 1]] + + .. code:: python + + [[1, 1], + [1, 1], + [1, 1], + [1, 1]] + with seq_len = 2, seq_len_kv = 4, mask for 'TopLeft': - [[1, 0, 0, 0], - [1, 1, 0, 0]] + + .. code:: python + + [[1, 0, 0, 0], + [1, 1, 0, 0]] + mask for 'BottomRight': - [[1, 1, 1, 0], - [1, 1, 1, 1]] + + .. code:: python + + [[1, 1, 1, 0], + [1, 1, 1, 1]] Returns diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 2b7a788e32..1676ba18c1 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -65,6 +65,7 @@ def Gradient( The new function will be like: .. code-block:: python + @R.function def main_adjoint(original_parameters): with R.dataflow(): @@ -161,6 +162,7 @@ def main_adjoint( The second example is returning multiple values and specifying the target with `target_index`: .. code-block:: python + @I.ir_module class Module: @R.function @@ -178,6 +180,7 @@ def main( The module after the Gradient pass will be: .. code-block:: python + @I.ir_module class Module: @R.function @@ -271,7 +274,8 @@ def CallTIRRewrite() -> tvm.ir.transform.Pass: def Normalize() -> tvm.ir.transform.Pass: """Transforming Relax IR to normal form, i.e., the expressions are normalized(no nesting - and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are available. + and hence the AST is in ANF), and all ``checked_type_`` and ``shape_`` of expressions are + available. Returns ------- @@ -356,7 +360,7 @@ def StaticPlanBlockMemory() -> tvm.ir.transform.Pass: function signature for clarity. For example, we can annotate a Relax function with - `R.func_attr({"tir_var_upper_bound": {"n": 1024}})`. + :code:`R.func_attr({"tir_var_upper_bound": {"n": 1024}})`. It means the maximum value of variable that names "n" in the function signature will have upper bound 1024. And we will use 1024 as its value during memory planning. @@ -441,15 +445,10 @@ def BindParams( Parameters ---------- - func_name: str The function name to be bound - params : Dict[ - Union[str,relax.Var], - Union[tvm.runtime.NDArray, np.ndarray], - ] - + params : Dict[Union[str,relax.Var],Union[tvm.runtime.NDArray, np.ndarray]] The map from parameter or parameter name name to constant tensors. @@ -474,15 +473,14 @@ def BindSymbolicVars( func_name: Optional[str] = None, ) -> tvm.ir.transform.Pass: """Bind params of function of the module to constant tensors. + Parameters ---------- binding_map : Mapping[Union[str, tvm.tir.Var], tvm.tir.PrimExpr] - The map from symbolic varname to integer. - func_name: Optional[str] - - The function name to be bound. If None (default), all + func_name : Optional[str] + The function name to be bound. If None (default), all functions within the module will be updated. Returns @@ -686,7 +684,7 @@ def FuseOpsByPattern( in which they are matched. Higher-priority patterns should come earlier in the list. In addition to FusionPattern, a tuple can be passed as item of this list. The pattern - will be constructed through FusionPattern(*item) + will be constructed through :code:`FusionPattern(*item)` bind_constants : bool Whether or not to keep bound constants in the grouped function. @@ -941,6 +939,7 @@ def MetaScheduleTuneIRMod( op_names: Optional[List[str]] = None, ) -> tvm.ir.transform.Pass: """Tune Relax IRModule with MetaSchedule. + Parameters ---------- params: Dict[str, NDArray] @@ -1065,13 +1064,15 @@ def AlterOpImpl( def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pass: """Automatic layout conversion pass. + Parameters ---------- desired_layouts : Dict[str, List[str]] The desired layout of conv2d ops is a map from the name of the op to the desired layout of the desired feature map, weight and output. For example, if we want to convert the layout of conv2d from NCHW to NHWC, we can set the desired layout of conv2d to be - {"relax.nn.conv2d": ["NHWC", "OHWI"]}. + ``{"relax.nn.conv2d": ["NHWC", "OHWI"]}``. + Returns ------- ret : tvm.transform.Pass @@ -1082,21 +1083,23 @@ def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pas def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass: """Remove dead code in the IRModule. - Currently it removes: + Currently it removes: + 1. Unused local VarBindings in a DataflowBlock. 2. Unused DataflowBlocks in a function. 3. Unused Relax functions in the module. We detect the call chain from the entry function, and remove all unused functions. - Parameters - ---------- - entry_functions: Optional[List[str]] - The set of entry functions to start from. Notes ----- For function-wise DCE, use py:func:`tvm.relax.analysis.remove_all_unused`. + Parameters + ---------- + entry_functions: Optional[List[str]] + The set of entry functions to start from. + Returns ------- ret : tvm.transform.Pass @@ -1112,6 +1115,7 @@ def ToMixedPrecision( ) -> tvm.ir.transform.Pass: """Automatic mixed precision pass. Currently the pass assumes the input module to be fp32 only, and will automatically cast fp32 to fp16 for certain ops. + Parameters ---------- out_dtype : str @@ -1128,17 +1132,20 @@ def ToMixedPrecision( return _ffi_api.ToMixedPrecision(out_dtype, fp16_input_names) # type: ignore -def SplitCallTIRByPattern(patterns, fcodegen) -> tvm.ir.transform.Pass: +def SplitCallTIRByPattern(patterns: List[PrimFunc], fcodegen: Callable) -> tvm.ir.transform.Pass: """Split a PrimFunc into 2 parts: the first part is a TIR PrimFunc which is matched with some pattern, and the second part is the rest of the original PrimFunc. It will call fcodegen to generate the code for the matched pattern to replace it with a ExternFunc call. + Parameters ---------- patterns : List[PrimFunc] The list of patterns to match. + fcodegen: Callable[[List[MatchResult]], List[Object]] The function to generate the code for the matched patterns. + Returns ------- ret : tvm.transform.Pass diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 4edbe247f5..5547ef82d7 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -570,39 +570,49 @@ def create_prim_func( ops: List[Union[_tensor.Tensor, tvm.tir.Var]], index_dtype_override: Optional[str] = None ) -> tvm.tir.PrimFunc: """Create a TensorIR PrimFunc from tensor expression + Parameters ---------- ops : List[Union[_tensor.Tensor, tvm.tir.Var]] The source expression. + Example ------- We define a matmul kernel using following code: + .. code-block:: python + import tvm from tvm import te from tvm.te import create_prim_func import tvm.script + A = te.placeholder((128, 128), name="A") B = te.placeholder((128, 128), name="B") k = te.reduce_axis((0, 128), "k") C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C") func = create_prim_func([A, B, C]) print(func.script()) + If we want to use TensorIR schedule to do transformations on such kernel, we need to use `create_prim_func([A, B, C])` to create a schedulable PrimFunc. The generated function looks like: + .. code-block:: python + @T.prim_func def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - for i, j, k in T.grip(128, 128, 128): + + for i, j, k in T.grid(128, 128, 128): with T.block(): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 C[vi, vj] += A[vi, vk] * B[vj, vk] + Returns ------- func : tir.PrimFunc