Hi @mattjj and team.

Idea

Provide a very small example that shows how to do a common task with `jax.grad` (from `jax.api`), like `grad`ing w.r.t. the 1st and 2nd parameters (using those words), for better clarity and UX, especially since a bunch of external folks are probably learning JAX and its magic autodiff.
More details
For example, you can write:
```python
grad(my_function, argnums=(0, 1))
```
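A minimal sketch of what such a call returns, assuming a hypothetical two-parameter toy function `my_function` (not taken from the JAX docs): with a tuple `argnums`, `jax.grad` returns a tuple holding one gradient per selected positional argument, in the same order.

```python
import jax

# Hypothetical toy function with two positional parameters, x and y.
def my_function(x, y):
    return x ** 2 * y  # df/dx = 2*x*y, df/dy = x**2

# argnums=(0, 1): gradient w.r.t. the 1st (index 0) and 2nd (index 1) parameters.
grad_fn = jax.grad(my_function, argnums=(0, 1))

# With a tuple of argnums, grad returns a tuple with one gradient per argument.
dx, dy = grad_fn(3.0, 4.0)
print(float(dx), float(dy))  # 24.0 9.0
```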
The docstring already explains it like this:
```python
def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
         has_aux: bool = False, holomorphic: bool = False,
         allow_int: bool = False) -> Callable:
  """...

  Args:
    ...
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default 0).
    ...

  For example:

  >>> import jax
  >>>
  >>> grad_tanh = jax.grad(jax.numpy.tanh)
  >>> print(grad_tanh(0.2))
  0.961043
  """
```
Searching for "positional argument to differentiate" in the docs search bar wasn't the first thing that came to my mind, although I'm not a researcher/programmer, so I may be wrong.
Proposal
Maybe we can add a second small example to the docstring explaining how to take the gradient with respect to... parameter (<- keywords for SEO), using a toy example like `grad(my_function, argnums=(0, 1))`. It would be really quick and wouldn't take up much space.
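A sketch of what such a second docstring example could look like, written as a doctest so its output can be checked. The function name `loss`, its parameters, and the wording are placeholders, not text from the JAX docs; the small harness below just runs the examples with Python's standard `doctest` module.

```python
import doctest

def proposed_grad_example():
    """Take the gradient w.r.t. the 1st and 2nd parameters with ``argnums``:

    >>> import jax
    >>> def loss(w, b, x):
    ...     return (w * x + b) ** 2
    ...
    >>> dw, db = jax.grad(loss, argnums=(0, 1))(1.0, 0.5, 2.0)
    >>> print(float(dw), float(db))
    10.0 5.0
    """

# Collect the examples in the docstring above and run them as doctests.
finder = doctest.DocTestFinder()
runner = doctest.DocTestRunner()
for test in finder.find(proposed_grad_example):
    runner.run(test)
results = runner.summarize(verbose=False)
print(results.failed, results.attempted)
```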
```python
from jax import lax
from jax import api

...

def square_add_lax(a, b):
  """A square-add function using the newly defined multiply-add."""
  return multiply_add_lax(a, a, b)

...

# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
```
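The snippet above depends on `multiply_add_lax`, which is elided (`...`) there. A runnable sketch, assuming a plain stand-in that computes `x * y + z` with the existing `lax.mul` and `lax.add` ops instead of a custom primitive, and using the public `jax.grad` rather than importing `jax.api`, which recent JAX releases no longer expose as a public module. It also shows the 1st-and-2nd-parameter case.

```python
import jax
from jax import lax

# Stand-in for the elided multiply_add_lax (an assumption): x * y + z
# built from existing lax ops, not a newly defined primitive.
def multiply_add_lax(x, y, z):
    return lax.add(lax.mul(x, y), z)

def square_add_lax(a, b):
    """A square-add function: a * a + b."""
    return multiply_add_lax(a, a, b)

# Differentiate w.r.t. the first argument: d/da (a*a + b) = 2*a.
print("grad w.r.t. a =", jax.grad(square_add_lax, argnums=0)(2.0, 10.0))

# Differentiate w.r.t. the first and second arguments at once.
da, db = jax.grad(square_add_lax, argnums=(0, 1))(2.0, 10.0)
print("grad w.r.t. (a, b) =", float(da), float(db))  # 4.0 1.0
```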
WDYT? Cheers ☕️
cc @dynamicwebpaige for UX
There's also an example in the "How JAX primitives work" tutorial (the `square_add_lax` snippet above), but you'd have to search for it.

Source of the `grad` docstring above: https://jax.readthedocs.io/en/latest/_modules/jax/api.html#grad