Interesting while-loop use case: jacrev OK, JIT not => scan #14732
Status: Unanswered
jecampagne asked this question in Q&A
Replies: 1 comment, 2 replies
-
Reply: Thanks for sharing! This is similar in spirit to the "early exit scan" or "bounded while loop" proposed here: #13062
-
Hi,
here is what I think is an interesting, simple exercise for working out a while-loop transformation (the idea is adapted from this discussion).
For instance, one can write it with a while loop:
and
gives
Interestingly, the Jacobian can be computed in both forward and reverse mode, even though the loop is a plain Python `while` and not `jax.lax.while_loop`:
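Since the original code is not shown, here is a minimal sketch of this kind of function, assuming a hypothetical toy update (double the vector until its squared norm reaches a threshold). Because the loop is a plain Python `while`, its condition is evaluated on concrete values, so both `jax.jacfwd` and `jax.jacrev` simply trace through however many iterations actually run:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy function (not the original author's code): double y
# until its squared norm reaches `threshold`. The number of iterations
# depends on the *value* of x, and the loop is a plain Python `while`.
def f_python(x, threshold=10.0):
    y = x
    while jnp.sum(y ** 2) < threshold:  # Python bool() on a concrete value
        y = 2.0 * y
    return y

x = jnp.array([0.5, 1.0])
print(f_python(x))              # two doublings: [2. 4.]
print(jax.jacrev(f_python)(x))  # reverse mode: 4 * identity
print(jax.jacfwd(f_python)(x))  # forward mode: same Jacobian
```

For this input the loop runs exactly twice, so the Jacobian is `2 * 2 = 4` times the identity.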
However, `jax.jit` crashes with the following error:
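Under `jax.jit`, the input is an abstract tracer with no concrete value, so Python cannot decide whether to run another iteration. A sketch reproducing this with a hypothetical toy loop (double `y` until its squared norm reaches a threshold); the exact message depends on the JAX version, and recent versions raise `TracerBoolConversionError`, a subclass of `jax.errors.ConcretizationTypeError`:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy function written as a plain Python while loop.
def f_python(x, threshold=10.0):
    y = x
    while jnp.sum(y ** 2) < threshold:  # needs a concrete True/False
        y = 2.0 * y
    return y

x = jnp.array([0.5, 1.0])

try:
    jax.jit(f_python)(x)
    jit_failed = False
except jax.errors.ConcretizationTypeError as err:
    # Under jit, `jnp.sum(y ** 2) < threshold` is an abstract tracer,
    # so Python's `while` cannot convert it to a bool.
    jit_failed = True
    print(type(err).__name__)
```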
So I modified the code as follows, which shows that one has to think differently in JAX to get its full power:
This gives the same results as the non-jitted version:
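One way to make such a loop jittable (not necessarily what the original `jfunc_bis` does, since its code is not shown) is to bound it: run `jax.lax.scan` for a fixed number of steps and freeze the state with `jnp.where` once the stopping condition fails. The toy update and `max_iters=50` are assumptions; if the real loop needed more than `max_iters` steps, the result would be silently truncated.

```python
import jax
import jax.numpy as jnp

# Bounded, jit-friendly version of a hypothetical toy loop: always run
# `max_iters` scan steps, but only update y while the condition holds.
def f_scan(x, threshold=10.0, max_iters=50):
    def step(y, _):
        active = jnp.sum(y ** 2) < threshold    # traced boolean, no Python bool()
        y_next = jnp.where(active, 2.0 * y, y)  # frozen once the condition fails
        return y_next, None
    y, _ = jax.lax.scan(step, x, None, length=max_iters)
    return y

x = jnp.array([0.5, 1.0])
print(jax.jit(f_scan)(x))              # [2. 4.]
print(jax.jit(jax.jacrev(f_scan))(x))  # 4 * identity
```

The trade-off is that all `max_iters` steps are always executed, even when the loop would have stopped early; a static trip count is what lets `scan` be both compiled and reverse-differentiated.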
Takeaway:
Maybe my `jfunc_bis` is not the best approach, so if someone has a better solution (not for this particular example, but as a general pattern for other use cases), I would be glad to hear it.
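As a possible general pattern, in the spirit of the "bounded while loop" mentioned in the reply (#13062), here is a sketch of a hypothetical helper `bounded_while_loop` (not a JAX API). It takes the same `cond_fun`/`body_fun`/`init_val` arguments as `jax.lax.while_loop`, plus a static `max_iters`. It assumes `body_fun` preserves the shapes and dtypes of the carry (a `scan` requirement), and it always executes `max_iters` steps.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper (not part of JAX): a while loop with a static upper
# bound on the number of iterations, built on lax.scan so that it works
# under jit and reverse-mode differentiation.
def bounded_while_loop(cond_fun, body_fun, init_val, max_iters):
    def step(carry, _):
        active = cond_fun(carry)
        new_carry = body_fun(carry)
        # Keep the old carry (leaf by leaf) once cond_fun is False; since the
        # carry then stops changing, cond_fun stays False afterwards.
        carry = jax.tree_util.tree_map(
            lambda new, old: jnp.where(active, new, old), new_carry, carry
        )
        return carry, None
    final, _ = jax.lax.scan(step, init_val, None, length=max_iters)
    return final

# Usage on a hypothetical toy loop: double y until ||y||^2 >= 10.
def f_bounded(x):
    return bounded_while_loop(
        cond_fun=lambda y: jnp.sum(y ** 2) < 10.0,
        body_fun=lambda y: 2.0 * y,
        init_val=x,
        max_iters=50,
    )

x = jnp.array([0.5, 1.0])
print(jax.jit(jax.jacrev(f_bounded))(x))  # 4 * identity
```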
PS: In fact, at the start of the exercise I was trying to write a `jax.lax.while_loop` version of the code in order to make `jacrev` crash. With that version, `jacrev` does crash, as expected.
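For comparison, a sketch of a hypothetical toy loop (double `y` until its squared norm reaches a threshold) written with `jax.lax.while_loop`: it compiles under `jax.jit` and supports forward mode (`jax.jacfwd`), but reverse mode (`jax.jacrev`) raises a `ValueError`, because a `while_loop` with a data-dependent trip count cannot be transposed. The exact message may vary across JAX versions.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loop expressed with jax.lax.while_loop.
def f_while(x, threshold=10.0):
    cond_fun = lambda y: jnp.sum(y ** 2) < threshold
    body_fun = lambda y: 2.0 * y
    return jax.lax.while_loop(cond_fun, body_fun, x)

x = jnp.array([0.5, 1.0])
print(jax.jit(f_while)(x))     # jit works: [2. 4.]
print(jax.jacfwd(f_while)(x))  # forward mode works: 4 * identity

try:
    jax.jacrev(f_while)(x)
    rev_failed = False
except ValueError as err:
    # Reverse-mode differentiation is not supported for lax.while_loop.
    rev_failed = True
    print(err)
```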