Lifted nn.scan: unroll does not seem to work, compared with jax.lax.scan
#2198
-
Hi, I am doing AI4Science research on solving differential equations. I mainly use I implement I also checked the source code, and it seems that in axes_scan.py
Replies: 1 comment 2 replies
-
The scan wrapper in linen/transforms.py passes the default (1) instead of the unroll value that was passed in. PR #2213 will fix it.
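In case it helps anyone checking whether an unroll value actually reaches the scan primitive, here is a minimal sketch using plain jax.lax.scan (not the Flax wrapper). The toy `step` function and the helper name `scan_unroll_param` are made up for illustration, and reading `params["unroll"]` off the traced scan equation is an assumption about how JAX records the setting, based on the versions I am aware of.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # Hypothetical toy step function: a running sum over the inputs.
    carry = carry + x
    return carry, carry


def scan_unroll_param(unroll):
    # Trace a jax.lax.scan call and return the `unroll` parameter
    # recorded on the resulting scan equation in the jaxpr.
    def run(xs):
        return jax.lax.scan(step, jnp.zeros(()), xs, unroll=unroll)

    jaxpr = jax.make_jaxpr(run)(jnp.arange(8.0))
    scan_eqns = [e for e in jaxpr.jaxpr.eqns if e.primitive.name == "scan"]
    return scan_eqns[0].params["unroll"]


if __name__ == "__main__":
    # If the value is forwarded correctly, these print the values passed in.
    print(scan_unroll_param(1))
    print(scan_unroll_param(4))
```

The same kind of check could presumably be applied to a function built with nn.scan: if the traced scan equation always reports 1 regardless of the unroll argument, the value is being dropped somewhere in the wrapper.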