I want to use `jax.grad` on an async function:

```python
import jax

async def f(x):
    return x + 1

y = await f(1.0)
df = jax.grad(f)
dy = await df(1.0)
```

When I run this, I get an error.

JAX version: 0.4.4

Can someone point me in the right direction? Is this even possible with JAX? Thank you!
Answered by tylerflex on Feb 22, 2023
Answered my own question, finally: the key was to wrap the async function in a synchronous one, `f_async = lambda x: asyncio.run(f(x))` (with `import asyncio`), and pass that to `jax.grad`.
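A minimal, self-contained sketch of the workaround, assuming `f` only does JAX-traceable operations on `x`. The names `f_sync` and `main` are illustrative and not from the thread. One caveat worth checking in your own setup: `asyncio.run` refuses to start while an event loop is already running in the same thread (for example inside another coroutine, or in a Jupyter kernel that runs its own loop). The second half of the sketch therefore runs the gradient in a worker thread via `asyncio.to_thread` (Python 3.9+), which also keeps the `await` style from the question.

```python
import asyncio

import jax


async def f(x):
    return x + 1


# jax.grad needs a function that returns an array, but calling the coroutine
# function f only creates a coroutine object. Driving it to completion with
# asyncio.run gives a plain synchronous function that JAX can trace.
f_sync = lambda x: asyncio.run(f(x))
df = jax.grad(f_sync)

# Plain synchronous use, with no event loop running yet.
print(f_sync(1.0), df(1.0))  # the gradient of x + 1 is 1.0


async def main():
    # Inside a running event loop, calling df directly would make asyncio.run
    # raise RuntimeError, so run df in a worker thread (no loop running there)
    # and await the result.
    y = await f(1.0)
    dy = await asyncio.to_thread(df, 1.0)
    return y, dy


y, dy = asyncio.run(main())
```

Running the gradient in a thread is only needed when you are already inside an event loop; in an ordinary script, calling `df(1.0)` directly is enough.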
Answer selected by tylerflex