Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed func ref in shared-computation #23643

Merged
merged 1 commit into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ using [jupytext](https://jupytext.readthedocs.io/) by running `jupytext --sync`
notebooks; for example:

```
pip install jupytext==1.16.0
pip install jupytext==1.16.4
jupytext --sync docs/notebooks/thinking_in_jax.ipynb
```

Expand Down
2 changes: 1 addition & 1 deletion docs/sharded-computation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@
"\n",
"## 2. Semi-automated sharding with constraints\n",
"\n",
"If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of (func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n",
"If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n",
"\n",
"For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:"
]
Expand Down
2 changes: 1 addition & 1 deletion docs/sharded-computation.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ The result is partially replicated: that is, the first two elements of the array

## 2. Semi-automated sharding with constraints

If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of (func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.
If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.

For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:

Expand Down