jax symbolic autodiff on sin cos ... #15850

Answered by jakevdp
jakubMitura14 asked this question in Q&A

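JAX does symbolic automatic differentiation at the primitive level: each primitive, such as sin, has its own derivative rule, so differentiating sin emits a cos operation rather than a numerical approximation. You can see this by printing the jaxpr for the transformed function: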

import jax
import jax.numpy as jnp

def a(x):
    return jnp.sin(x)

x = jnp.float32(1.0)

print(jax.make_jaxpr(a)(x))
# { lambda ; a:f32[]. let b:f32[] = sin a in (b,) }

print(jax.make_jaxpr(jax.grad(a))(x))
# { lambda ; a:f32[]. let
#     _:f32[] = sin a
#     b:f32[] = cos a
#     c:f32[] = mul 1.0 b
#   in (c,) }
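
As a quick sanity check on the jaxprs above, the sketch below compares jax.grad of the same sin function against the closed-form derivatives cos(x) and -sin(x), and lists the primitive names in the differentiated jaxpr. The names f, df, d2f and prims are illustrative and not from the original answer, and the exact primitive list may vary between JAX versions.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x)

x = jnp.float32(1.0)

df = jax.grad(f)              # first derivative, expected to equal cos(x)
d2f = jax.grad(jax.grad(f))   # second derivative, expected to equal -sin(x)

# Compare against the closed-form derivatives (up to float32 tolerance).
grad_matches_cos = bool(jnp.allclose(df(x), jnp.cos(x)))
second_matches_neg_sin = bool(jnp.allclose(d2f(x), -jnp.sin(x)))
print(grad_matches_cos, second_matches_neg_sin)

# Names of the primitives recorded in the jaxpr of the gradient function.
prims = [eqn.primitive.name for eqn in jax.make_jaxpr(df)(x).jaxpr.eqns]
print(prims)
```

Because grad is applied primitive by primitive, it composes with itself, which is why the second derivative can be obtained simply by nesting jax.grad.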

Answer selected by jakubMitura14