jax symbolic autodiff on sin cos ... #15850
Answered by jakevdp
jakubMitura14 asked this question in Q&A
Hello, I was reading through the autodiff cookbook, but one thing is not clear to me: are simple functions such as the trig functions differentiated symbolically, or using numerical approximations?
For example, during reverse-mode autodiff of sin, will JAX get the derivative as just cos?
Answered by jakevdp, May 3, 2023
JAX does symbolic automatic differentiation at the primitive level. You can see this by printing the jaxpr for the transformed function:

    import jax
    import jax.numpy as jnp

    def a(x):
      return jnp.sin(x)

    x = jnp.float32(1.0)

    print(jax.make_jaxpr(a)(x))
    # { lambda ; a:f32[]. let b:f32[] = sin a in (b,) }

    print(jax.make_jaxpr(jax.grad(a))(x))
    # { lambda ; a:f32[]. let
    #     _:f32[] = sin a
    #     b:f32[] = cos a
    #     c:f32[] = mul 1.0 b
    #   in (c,) }
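As a rough sanity check of the point above, the sketch below compares `jax.grad` of `sin` against the closed-form `cos`, and against a central finite difference for contrast. The names `f`, `g`, and the step size `h` are illustrative choices, not part of the original answer; the composed function `g` is included only to suggest how the per-primitive rules chain together.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x)

def g(x):
    # Composition of two primitives; the chain rule combines their rules.
    return jnp.sin(jnp.cos(x))

x = jnp.float32(1.0)

# Derivative produced by JAX's per-primitive differentiation rule for sin.
autodiff = jax.grad(f)(x)

# Closed-form derivative, for comparison.
analytic = jnp.cos(x)

# Central finite difference: a numerical approximation, shown only for contrast.
# h is an arbitrary illustrative step size.
h = 1e-3
finite_diff = (f(x + h) - f(x - h)) / (2 * h)

# Hand-derived derivative of g: d/dx sin(cos(x)) = -sin(x) * cos(cos(x)).
g_autodiff = jax.grad(g)(x)
g_analytic = -jnp.sin(x) * jnp.cos(jnp.cos(x))

print("grad(sin):        ", autodiff)
print("cos:              ", analytic)
print("finite difference:", finite_diff)
print("grad(sin(cos)):   ", g_autodiff, "vs analytic", g_analytic)
```

The autodiff result should agree with `cos(x)` to floating-point precision, while the finite-difference estimate typically carries a small truncation and rounding error that depends on `h`.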