From 1708b48bb40c023ce3a3c69edeb90edcbd4f7e9a Mon Sep 17 00:00:00 2001 From: Nimrod Gileadi Date: Wed, 18 Sep 2024 07:44:45 -0700 Subject: [PATCH] Forward all arguments when using nnx.transforms.deprecated.scan as a decorator. PiperOrigin-RevId: 675988367 --- flax/nnx/transforms/deprecated.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/flax/nnx/transforms/deprecated.py b/flax/nnx/transforms/deprecated.py index 99dc5a806d..f0191fc020 100644 --- a/flax/nnx/transforms/deprecated.py +++ b/flax/nnx/transforms/deprecated.py @@ -1238,7 +1238,19 @@ def scan( ) -> F | tp.Callable[[F], F]: if isinstance(f, Missing): return functools.partial( - scan, length=length, reverse=reverse, unroll=unroll + scan, + length=length, + reverse=reverse, + unroll=unroll, + _split_transpose=_split_transpose, + in_axes=in_axes, + in_axes_kwargs=in_axes_kwargs, + out_axes=out_axes, + carry_argnum=carry_argnum, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + scan_output=scan_output, ) # type: ignore[return-value] @functools.wraps(f)