Skip to content

Commit

Permalink
Merge pull request #3231 from google:fix-scan-docs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 551343861
  • Loading branch information
Flax Authors committed Jul 26, 2023
2 parents 719217b + 93d48c1 commit f103233
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,8 +834,8 @@ def scan(
is broadcasted they are typically initialized inside the loop body but
independent of the loop variables.
The loop body should have the signature
``(scope, body, carry, *xs) -> (carry, ys)``, where ``xs`` and ``ys``
The ``target`` should have the signature
``(module, carry, *xs) -> (carry, ys)``, where ``xs`` and ``ys``
are the scan values that go in and out of the loop.
Example::
Expand Down Expand Up @@ -921,8 +921,8 @@ def scan(
which will in addition checkpoint each layer in the scan loop.
Args:
target: a ``Module`` or a function taking a ``Module``
as its first argument.
target: a ``Module`` or a function taking a ``Module`` as its first
argument.
variable_axes: the variable collections that are scanned over.
variable_broadcast: Specifies the broadcasted variable collections. A
broadcasted variable should not depend on any computation that cannot be
Expand All @@ -941,8 +941,8 @@ def scan(
length: Specifies the number of loop iterations. This only needs to be
specified if it cannot be derived from the scan arguments.
reverse: If true, scan from end to start in reverse order.
unroll: how many scan iterations to unroll within a single
iteration of a loop (default: 1).
unroll: how many scan iterations to unroll within a single iteration of a
loop (default: 1).
data_transform: optional function to transform raw functional-core variable
and rng groups inside lifted scan body_fn, intended for inline SPMD
annotations.
Expand All @@ -951,8 +951,8 @@ def scan(
methods: If `target` is a `Module`, the methods of `Module` to scan over.
Returns:
The scan function with the signature ``(scope, carry, *xxs) -> (carry,
yys)``, where ``xxs`` and ``yys`` are the scan values that go in and out of
The scan function with the signature ``(module, carry, *xs) -> (carry,
ys)``, where ``xs`` and ``ys`` are the scan values that go in and out of
the loop.
"""
return lift_transform(
Expand Down

0 comments on commit f103233

Please sign in to comment.