Skip to content

Commit

Permalink
improve scan docs
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jul 26, 2023
1 parent 719217b commit 93d48c1
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,8 +834,8 @@ def scan(
is broadcasted they are typically initialized inside the loop body but
independent of the loop variables.
The loop body should have the signature
``(scope, body, carry, *xs) -> (carry, ys)``, where ``xs`` and ``ys``
The ``target`` should have the signature
``(module, carry, *xs) -> (carry, ys)``, where ``xs`` and ``ys``
are the scan values that go in and out of the loop.
Example::
Expand Down Expand Up @@ -951,8 +951,8 @@ def scan(
methods: If `target` is a `Module`, the methods of `Module` to scan over.
Returns:
The scan function with the signature ``(scope, carry, *xxs) -> (carry,
yys)``, where ``xxs`` and ``yys`` are the scan values that go in and out of
The scan function with the signature ``(module, carry, *xs) -> (carry,
ys)``, where ``xs`` and ``ys`` are the scan values that go in and out of
the loop.
"""
return lift_transform(
Expand Down

0 comments on commit 93d48c1

Please sign in to comment.