From 93d48c19fc8f42ab078d0389b0d150c89e6f43eb Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Wed, 26 Jul 2023 22:02:08 +0000 Subject: [PATCH] improve scan docs --- flax/linen/transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index 51fbf177ee..465d44482f 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -834,8 +834,8 @@ def scan( is broadcasted they are typically initialized inside the loop body but independent of the loop variables. - The loop body should have the signature - ``(scope, body, carry, *xs) -> (carry, ys)``, where ``xs`` and ``ys`` + The ``target`` should have the signature + ``(module, carry, *xs) -> (carry, ys)``, where ``xs`` and ``ys`` are the scan values that go in and out of the loop. Example:: @@ -951,8 +951,8 @@ def scan( methods: If `target` is a `Module`, the methods of `Module` to scan over. Returns: - The scan function with the signature ``(scope, carry, *xxs) -> (carry, - yys)``, where ``xxs`` and ``yys`` are the scan values that go in and out of + The scan function with the signature ``(module, carry, *xs) -> (carry, + ys)``, where ``xs`` and ``ys`` are the scan values that go in and out of the loop. """ return lift_transform(