
Lifted nn.scan unroll seems not to work, compared with jax.lax.scan #2198

Answered by jheek
luweizheng asked this question in General

The scan wrapper in linen/transforms.py passes the default unroll value (1) to the underlying scan instead of the unroll value you passed in. PR #2213 will fix it.
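One way to confirm whether an unroll value actually reaches the scan is to trace the function with jax.make_jaxpr and read the unroll parameter recorded on the scan equation. The sketch below does this for a plain jax.lax.scan. The helper names cumsum_scan and scan_unroll_param are hypothetical, and it assumes the scan primitive stores its unroll factor under params["unroll"], as recent JAX versions do. Applied to a traced nn.scan module, the same idea should show whether the wrapper forwarded the value, although there the scan equation may be nested inside another equation's sub-jaxpr and need a recursive search.

```python
import jax
import jax.numpy as jnp


def cumsum_scan(xs, unroll=1):
    # Hypothetical example body: the carry is the running sum,
    # and each step also emits the running sum as its output.
    def step(carry, x):
        carry = carry + x
        return carry, carry

    _, ys = jax.lax.scan(step, jnp.zeros((), xs.dtype), xs, unroll=unroll)
    return ys


def scan_unroll_param(fn, *args):
    # Trace fn and return the `unroll` parameter of the first top-level
    # scan equation. Assumption: the scan primitive records it in params.
    # If a wrapper silently dropped the argument, this would report 1.
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr
    for eqn in jaxpr.eqns:
        if eqn.primitive.name == "scan":
            return eqn.params["unroll"]
    raise ValueError("no top-level scan equation found")


xs = jnp.arange(8.0)
unrolled = scan_unroll_param(lambda x: cumsum_scan(x, unroll=4), xs)
default = scan_unroll_param(lambda x: cumsum_scan(x), xs)
print(unrolled, default)
```

Unrolling only changes how the loop is compiled, not what it computes, so cumsum_scan(xs, unroll=4) and cumsum_scan(xs) should return the same values. That is why a dropped unroll argument is easy to miss by comparing outputs alone, and why inspecting the traced parameter (or timing the compiled function) is a more direct check.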

Answer selected by luweizheng