[Unity][Doc] Fix Relax part fix_docstring (#15860)
Following #15848, this PR continues to fix the docstrings of the Relax part.
Hzfengsy authored Oct 3, 2023
1 parent 4735a5b commit 063cd7f
Showing 7 changed files with 81 additions and 37 deletions.
4 changes: 2 additions & 2 deletions python/tvm/relax/analysis/analysis.py
@@ -331,8 +331,8 @@ def contains_impure_call(expr: Expr, own_name: Optional[Union[Var, GlobalVar]] =
"""
Check if the given expression (likely a function body) contains any impure calls.
- Parameter
- ---------
+ Parameters
+ ----------
expr : Expr
The expression to be examined. If expr is a function, we check the body.
1 change: 1 addition & 0 deletions python/tvm/relax/frontend/nn/op.py
@@ -661,6 +661,7 @@ def repeat(x: Tensor, repeats: int, axis: Optional[int] = None, name="repeat") -
Examples
--------
.. code-block:: python
np_x = numpy.array([[1, 2], [3, 4]])
x = Tensor.from_const(np_x)
lv1 = repeat(x, repeats=2) # lv1 == [1, 1, 2, 2, 3, 3, 4, 4]
2 changes: 1 addition & 1 deletion python/tvm/relax/op/base.py
@@ -664,7 +664,7 @@ def invoke_pure_closure(
Invoke a closure and indicate to the compiler that it is pure.
Note: This should be used for cases when the user knows that calling the closure
- with these arguments will _in reality_ not cause any side effects.
+ with these arguments will **in reality** not cause any side effects.
If it is used for a call that _does_ result in side effects, then the compiler
may end up removing, reordering, or repeating that call, with no guarantees
made about any side effects from the callee.
9 changes: 9 additions & 0 deletions python/tvm/relax/op/manipulate.py
@@ -350,6 +350,7 @@ def repeat(data: Expr, repeats: int, axis: Optional[int] = None) -> Expr:
Examples
--------
.. code-block:: python
x = R.const([[1, 2], [3, 4]])
lv1 = R.repeat(x, repeats=2) # lv1 == [1, 1, 2, 2, 3, 3, 4, 4]
lv2 = R.repeat(x, repeats=2, axis=1) # lv2 == [[1., 1., 2., 2.],
@@ -435,13 +436,20 @@ def scatter_elements(
specified by `updates` at specific index positions specified by `indices`.
For example, in 2D tensor, the update corresponding to the [i][j] entry is performed
as below:
+ .. code-block::
output[indices[i][j]][j] = updates[i][j] if axis = 0
output[i][indices[i][j]] = updates[i][j] if axis = 1
When the `reduction` is set to some reduction function `f`, the update corresponding to
[i][j] entry is performed as below:
+ .. code-block::
output[indices[i][j]][j] += f(output[indices[i][j]][j], updates[i][j]) if axis = 0
output[i][indices[i][j]] += f(output[i][indices[i][j]], updates[i][j]) if axis = 1
Here `f` is one of: update, add, mul, mean, max, min.
Parameters
@@ -470,6 +478,7 @@ def scatter_elements(
Examples
--------
.. code-block:: python
# inputs
data = [
[0.0, 0.0, 0.0],
43 changes: 30 additions & 13 deletions python/tvm/relax/op/nn/nn.py
@@ -871,9 +871,11 @@ def batch_norm(
.. note::
This operator has two modes:
- Training mode.
- Use the mean and var computed from THIS batch to normalize.
- Update and then return the running mean and running var.
- Inference mode.
- Use the running_mean and running_var parameters to normalize.
- Do not update the running mean and running var. Just return the original value.
@@ -1005,7 +1007,7 @@ def group_norm(
into groups along the channel axis. Then apply layer normalization to each group.
Parameters
----------
----------
data : relax.Expr
Input to which group_norm will be applied.
@@ -1231,22 +1233,37 @@ def attention(
while for 'BottomRight', the mask matrix is as `np.tril(*, k=abs(seq_len - seq_len_kv))`
For example, with seq_len = 4, seq_len_kv = 2,
mask for 'TopLeft':
- [[1, 0],
- [1, 1],
- [1, 1],
- [1, 1]]
+ .. code:: python
+ [[1, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1]]
mask for 'BottomRight':
- [[1, 1],
- [1, 1],
- [1, 1],
- [1, 1]]
+ .. code:: python
+ [[1, 1],
+ [1, 1],
+ [1, 1],
+ [1, 1]]
with seq_len = 2, seq_len_kv = 4,
mask for 'TopLeft':
- [[1, 0, 0, 0],
- [1, 1, 0, 0]]
+ .. code:: python
+ [[1, 0, 0, 0],
+ [1, 1, 0, 0]]
mask for 'BottomRight':
- [[1, 1, 1, 0],
- [1, 1, 1, 1]]
+ .. code:: python
+ [[1, 1, 1, 0],
+ [1, 1, 1, 1]]
Returns
47 changes: 27 additions & 20 deletions python/tvm/relax/transform/transform.py
@@ -65,6 +65,7 @@ def Gradient(
The new function will be like:
.. code-block:: python
@R.function
def main_adjoint(original_parameters):
with R.dataflow():
@@ -161,6 +162,7 @@ def main_adjoint(
The second example is returning multiple values and specifying the target with `target_index`:
.. code-block:: python
@I.ir_module
class Module:
@R.function
@@ -178,6 +180,7 @@ def main(
The module after the Gradient pass will be:
.. code-block:: python
@I.ir_module
class Module:
@R.function
@@ -271,7 +274,8 @@ def CallTIRRewrite() -> tvm.ir.transform.Pass:

def Normalize() -> tvm.ir.transform.Pass:
"""Transforming Relax IR to normal form, i.e., the expressions are normalized(no nesting
- and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are available.
+ and hence the AST is in ANF), and all ``checked_type_`` and ``shape_`` of expressions are
+ available.
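
For reference, a minimal sketch of applying the pass (illustrative only; assumes a Relax-enabled TVM build):

.. code-block:: python

    import tvm
    from tvm import relax
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((2, 2), "float32")) -> R.Tensor((2, 2), "float32"):
            y = R.add(x, x)
            return y

    # The pass returns a new IRModule whose expressions are in ANF.
    mod = relax.transform.Normalize()(Module)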
Returns
-------
@@ -356,7 +360,7 @@ def StaticPlanBlockMemory() -> tvm.ir.transform.Pass:
function signature for clarity.
For example, we can annotate a Relax function with
- `R.func_attr({"tir_var_upper_bound": {"n": 1024}})`.
+ :code:`R.func_attr({"tir_var_upper_bound": {"n": 1024}})`.
It means that the variable named "n" in the function signature has an upper
bound of 1024, and we will use 1024 as its value during memory planning.
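
An illustrative sketch of such an annotation in TVMScript (the function and shapes are made up for the example):

.. code-block:: python

    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor(("n", 4), "float32")) -> R.Tensor(("n", 4), "float32"):
            # Promise that the symbolic size "n" never exceeds 1024,
            # so the memory planner can reserve fixed-size storage.
            R.func_attr({"tir_var_upper_bound": {"n": 1024}})
            y = R.add(x, x)
            return y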
@@ -441,15 +445,10 @@ def BindParams(
Parameters
----------
func_name: str
The function name to be bound
- params : Dict[
- Union[str,relax.Var],
- Union[tvm.runtime.NDArray, np.ndarray],
- ]
+ params : Dict[Union[str,relax.Var],Union[tvm.runtime.NDArray, np.ndarray]]
The map from parameter or parameter name to constant
tensors.
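
A hedged usage sketch (assumes an existing Relax ``IRModule`` named ``mod`` whose ``main`` function has a parameter called ``weight``):

.. code-block:: python

    import numpy as np
    import tvm
    from tvm import relax

    weight = np.zeros((16, 16), dtype="float32")
    # Both np.ndarray and tvm.nd.NDArray values are accepted.
    mod = relax.transform.BindParams("main", {"weight": tvm.nd.array(weight)})(mod)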
@@ -474,15 +473,14 @@ def BindSymbolicVars(
func_name: Optional[str] = None,
) -> tvm.ir.transform.Pass:
"""Bind params of function of the module to constant tensors.
Parameters
----------
binding_map : Mapping[Union[str, tvm.tir.Var], tvm.tir.PrimExpr]
The map from symbolic varname to integer.
- func_name: Optional[str]
- The function name to be bound. If None (default), all
+ func_name : Optional[str]
+ The function name to be bound. If None (default), all
functions within the module will be updated.
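
A hedged usage sketch (assumes an ``IRModule`` ``mod`` whose functions use a symbolic dimension named ``n``):

.. code-block:: python

    import tvm
    from tvm import relax

    # Replace every occurrence of the symbolic variable "n" with the constant 64.
    mod = relax.transform.BindSymbolicVars({"n": tvm.tir.IntImm("int64", 64)})(mod)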
Returns
@@ -686,7 +684,7 @@ def FuseOpsByPattern(
in which they are matched. Higher-priority patterns should come earlier in the list.
In addition to FusionPattern, a tuple can be passed as an item of this list. The pattern
- will be constructed through FusionPattern(*item)
+ will be constructed through :code:`FusionPattern(*item)`
bind_constants : bool
Whether or not to keep bound constants in the grouped function.
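
A hedged sketch of the tuple form (the composite name ``"example.conv2d_relu"`` is made up; assumes an existing ``IRModule`` ``mod``):

.. code-block:: python

    from tvm import relax
    from tvm.relax.dpl import is_op, wildcard

    # conv2d followed by relu, written in the Relax dataflow pattern language.
    conv = is_op("relax.nn.conv2d")(wildcard(), wildcard())
    pattern = is_op("relax.nn.relu")(conv)

    mod = relax.transform.FuseOpsByPattern(
        [("example.conv2d_relu", pattern)],  # expanded via FusionPattern(*item)
        bind_constants=False,
    )(mod)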
@@ -941,6 +939,7 @@ def MetaScheduleTuneIRMod(
op_names: Optional[List[str]] = None,
) -> tvm.ir.transform.Pass:
"""Tune Relax IRModule with MetaSchedule.
Parameters
----------
params: Dict[str, NDArray]
@@ -1065,13 +1064,15 @@ def AlterOpImpl(

def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pass:
"""Automatic layout conversion pass.
Parameters
----------
desired_layouts : Dict[str, List[str]]
The desired layout of conv2d ops is a map from the name of the op to the desired layout
of the feature map, weight and output. For example, if we want to convert the
layout of conv2d from NCHW to NHWC, we can set the desired layout of conv2d to be
{"relax.nn.conv2d": ["NHWC", "OHWI"]}.
``{"relax.nn.conv2d": ["NHWC", "OHWI"]}``.
Returns
-------
ret : tvm.transform.Pass
@@ -1082,21 +1083,23 @@ def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pass:

def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass:
"""Remove dead code in the IRModule.
Currently it removes:
Currently it removes:
1. Unused local VarBindings in a DataflowBlock.
2. Unused DataflowBlocks in a function.
3. Unused Relax functions in the module.
We detect the call chain from the entry function, and remove all unused functions.
+ Parameters
+ ----------
+ entry_functions: Optional[List[str]]
+ The set of entry functions to start from.
Notes
-----
For function-wise DCE, use py:func:`tvm.relax.analysis.remove_all_unused`.
- Parameters
- ----------
- entry_functions: Optional[List[str]]
- The set of entry functions to start from.
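
A hedged usage sketch (assumes an ``IRModule`` ``mod`` with an entry function named ``main``):

.. code-block:: python

    from tvm import relax

    # Bindings, dataflow blocks, and functions unreachable from "main" are dropped.
    mod = relax.transform.DeadCodeElimination(entry_functions=["main"])(mod)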
Returns
-------
ret : tvm.transform.Pass
@@ -1112,6 +1115,7 @@ def ToMixedPrecision(
) -> tvm.ir.transform.Pass:
"""Automatic mixed precision pass. Currently the pass assumes the input module to be fp32
only, and will automatically cast fp32 to fp16 for certain ops.
Parameters
----------
out_dtype : str
@@ -1128,17 +1132,20 @@
return _ffi_api.ToMixedPrecision(out_dtype, fp16_input_names) # type: ignore
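
A hedged usage sketch (assumes an fp32 ``IRModule`` ``mod``):

.. code-block:: python

    from tvm import relax

    # Cast eligible fp32 ops to fp16; out_dtype="float32" is a common choice
    # for keeping accumulation/output dtypes in fp32.
    mod = relax.transform.ToMixedPrecision(out_dtype="float32")(mod)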


- def SplitCallTIRByPattern(patterns, fcodegen) -> tvm.ir.transform.Pass:
+ def SplitCallTIRByPattern(patterns: List[PrimFunc], fcodegen: Callable) -> tvm.ir.transform.Pass:
"""Split a PrimFunc into 2 parts: the first part is a TIR PrimFunc which is
matched with some pattern, and the second part is the rest of the original
PrimFunc. It will call fcodegen to generate the code for the matched pattern
to replace it with an ExternFunc call.
Parameters
----------
patterns : List[PrimFunc]
The list of patterns to match.
fcodegen: Callable[[List[MatchResult]], List[Object]]
The function to generate the code for the matched patterns.
Returns
-------
ret : tvm.transform.Pass
12 changes: 11 additions & 1 deletion python/tvm/te/operation.py
@@ -570,39 +570,49 @@ def create_prim_func(
ops: List[Union[_tensor.Tensor, tvm.tir.Var]], index_dtype_override: Optional[str] = None
) -> tvm.tir.PrimFunc:
"""Create a TensorIR PrimFunc from tensor expression
Parameters
----------
ops : List[Union[_tensor.Tensor, tvm.tir.Var]]
The source expression.
Example
-------
We define a matmul kernel using the following code:
.. code-block:: python
import tvm
from tvm import te
from tvm.te import create_prim_func
import tvm.script
A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")
k = te.reduce_axis((0, 128), "k")
C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C")
func = create_prim_func([A, B, C])
print(func.script())
If we want to use the TensorIR schedule to do transformations on such a kernel,
we need to use `create_prim_func([A, B, C])` to create a schedulable PrimFunc.
The generated function looks like:
.. code-block:: python
@T.prim_func
def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
C = T.match_buffer(c, (128, 128))
- for i, j, k in T.grip(128, 128, 128):
+ for i, j, k in T.grid(128, 128, 128):
with T.block():
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0
C[vi, vj] += A[vi, vk] * B[vj, vk]
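
A short follow-up sketch of scheduling the resulting PrimFunc (illustrative; it assumes the generated compute block is named "C" and uses an arbitrary split factor):

.. code-block:: python

    sch = tvm.tir.Schedule(func)
    i, j, k = sch.get_loops(sch.get_block("C"))
    i_outer, i_inner = sch.split(i, factors=[None, 32])
    print(sch.mod.script())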
Returns
-------
func : tir.PrimFunc
