Skip to content

Commit

Permalink
Forward all arguments when using nnx.transforms.deprecated.scan as a …
Browse files Browse the repository at this point in the history
…decorator.

PiperOrigin-RevId: 675988367
  • Loading branch information
nimrod-gileadi authored and Flax Authors committed Sep 18, 2024
1 parent e356acb commit 1708b48
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion flax/nnx/transforms/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,7 +1238,19 @@ def scan(
) -> F | tp.Callable[[F], F]:
if isinstance(f, Missing):
return functools.partial(
scan, length=length, reverse=reverse, unroll=unroll
scan,
length=length,
reverse=reverse,
unroll=unroll,
_split_transpose=_split_transpose,
in_axes=in_axes,
in_axes_kwargs=in_axes_kwargs,
out_axes=out_axes,
carry_argnum=carry_argnum,
state_axes=state_axes,
split_rngs=split_rngs,
transform_metadata=transform_metadata,
scan_output=scan_output,
) # type: ignore[return-value]

@functools.wraps(f)
Expand Down

0 comments on commit 1708b48

Please sign in to comment.