diff --git a/docs/developer.md b/docs/developer.md index af2e451a22ef..53b6f0cf0f45 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -697,7 +697,7 @@ using [jupytext](https://jupytext.readthedocs.io/) by running `jupytext --sync` notebooks; for example: ``` -pip install jupytext==1.16.0 +pip install jupytext==1.16.4 jupytext --sync docs/notebooks/thinking_in_jax.ipynb ``` diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index 60bf4d41a7a6..a4b6f2e0ced2 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -360,7 +360,7 @@ "\n", "## 2. Semi-automated sharding with constraints\n", "\n", - "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of (func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n", + "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n", "\n", "For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:" ] diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index ef4dc2d3288d..c273e23c771e 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -133,7 +133,7 @@ The result is partially replicated: that is, the first two elements of the array ## 2. Semi-automated sharding with constraints -If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of (func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed. +If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed. For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices: