diff --git a/flax/nnx/transforms/deprecated.py b/flax/nnx/transforms/deprecated.py index 99dc5a806..f0191fc02 100644 --- a/flax/nnx/transforms/deprecated.py +++ b/flax/nnx/transforms/deprecated.py @@ -1238,7 +1238,19 @@ def scan( ) -> F | tp.Callable[[F], F]: if isinstance(f, Missing): return functools.partial( - scan, length=length, reverse=reverse, unroll=unroll + scan, + length=length, + reverse=reverse, + unroll=unroll, + _split_transpose=_split_transpose, + in_axes=in_axes, + in_axes_kwargs=in_axes_kwargs, + out_axes=out_axes, + carry_argnum=carry_argnum, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + scan_output=scan_output, ) # type: ignore[return-value] @functools.wraps(f)