Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve documentation for jax.jacobian #23940

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/jax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ Automatic differentiation

grad
value_and_grad
jacobian
jacfwd
jacrev
hessian
Expand Down
8 changes: 7 additions & 1 deletion jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,13 @@ def jacfun(*args, **kwargs):
return jac_tree, aux

return jacfun
jacobian = jacrev


def jacobian(fun: Callable, argnums: int | Sequence[int] = 0,
has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable:
"""Alias of :func:`jax.jacrev`."""
return jacrev(fun, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic, allow_int=allow_int)


_check_input_dtype_jacrev = partial(_check_input_dtype_revderiv, "jacrev")
_check_output_dtype_jacrev = partial(_check_output_dtype_revderiv, "jacrev")
Expand Down