diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index 51fbf177ee..7f4cf1c775 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -834,8 +834,8 @@ def scan( is broadcasted they are typically initialized inside the loop body but independent of the loop variables. - The loop body should have the signature - ``(scope, body, carry, *xs) -> (carry, ys)``, where ``xs`` and ``ys`` + The ``target`` should have the signature + ``(module, carry, *xs) -> (carry, ys)``, where ``xs`` and ``ys`` are the scan values that go in and out of the loop. Example:: @@ -921,8 +921,8 @@ def scan( which will in addition checkpoint each layer in the scan loop. Args: - target: a ``Module`` or a function taking a ``Module`` - as its first argument. + target: a ``Module`` or a function taking a ``Module`` as its first + argument. variable_axes: the variable collections that are scanned over. variable_broadcast: Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be @@ -941,8 +941,8 @@ def scan( length: Specifies the number of loop iterations. This only needs to be specified if it cannot be derived from the scan arguments. reverse: If true, scan from end to start in reverse order. - unroll: how many scan iterations to unroll within a single - iteration of a loop (default: 1). + unroll: how many scan iterations to unroll within a single iteration of a + loop (default: 1). data_transform: optional function to transform raw functional-core variable and rng groups inside lifted scan body_fn, intended for inline SPMD annotations. @@ -951,8 +951,8 @@ def scan( methods: If `target` is a `Module`, the methods of `Module` to scan over. Returns: - The scan function with the signature ``(scope, carry, *xxs) -> (carry, - yys)``, where ``xxs`` and ``yys`` are the scan values that go in and out of + The scan function with the signature ``(module, carry, *xs) -> (carry, + ys)``, where ``xs`` and ``ys`` are the scan values that go in and out of the loop. """ return lift_transform(