diff --git a/CHANGELOG.md b/CHANGELOG.md index f5c3929ea356..fe76ebae5094 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,9 @@ Remember to align the itemized text with the first line of an item within a list For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` * `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`. * `jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`. + * `jax.experimental.pjit.FROM_GDA` is deprecated. Please pass in sharded + jax.Arrays as input and remove the `in_shardings` argument to pjit since + it is optional. ## jaxlib 0.4.7 diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 0b35b9e4a2a9..7c5b0e58d065 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -28,7 +28,7 @@ from jax._src.interpreters import pxla from jax.interpreters import xla from jax._src import pjit as pjit_lib -from jax.experimental.pjit import pjit, FROM_GDA +from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P from jax._src import distributed from jax._src import config as config_internal diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index 30e0f41230da..9e8e1ad19f5b 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -16,7 +16,6 @@ from jax._src.pjit import ( AUTO as AUTO, - FROM_GDA as FROM_GDA, ParsedPartitionSpec as ParsedPartitionSpec, get_array_mapping as get_array_mapping, hashable_pytree as hashable_pytree, @@ -38,6 +37,7 @@ from jax._src.pjit import ( NamedSharding as _deprecated_NamedSharding, PartitionSpec as _deprecated_PartitionSpec, + FROM_GDA as _deprecated_FROM_GDA, ) import typing @@ -45,21 +45,34 @@ from jax._src.pjit import ( NamedSharding as NamedSharding, PartitionSpec as PartitionSpec, + FROM_GDA as FROM_GDA, ) del typing _deprecations = { - # Added Feb 13, 2023: - "NamedSharding": ( - ("jax.experimental.pjit.NamedSharding is deprecated. Use " - "jax.sharding.NamedSharding."), - _deprecated_NamedSharding, - ), - "PartitionSpec": ( - ("jax.experimental.pjit.PartitionSpec is deprecated. Use " - "jax.sharding.PartitionSpec."), - _deprecated_PartitionSpec, - ), + # Added Feb 13, 2023: + "NamedSharding": ( + ( + "jax.experimental.pjit.NamedSharding is deprecated. Use " + "jax.sharding.NamedSharding." + ), + _deprecated_NamedSharding, + ), + "PartitionSpec": ( + ( + "jax.experimental.pjit.PartitionSpec is deprecated. Use " + "jax.sharding.PartitionSpec." + ), + _deprecated_PartitionSpec, + ), + "FROM_GDA": ( + ( + "jax.experimental.pjit.FROM_GDA is deprecated. Please pass in" + " sharded jax.Arrays as input and remove the in_shardings argument" + " to pjit since it is optional." + ), + _deprecated_FROM_GDA, + ), } from jax._src.deprecations import deprecation_getattr as _deprecation_getattr diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py index cf9d69a08d75..d816af052aac 100644 --- a/tests/multiprocess_gpu_test.py +++ b/tests/multiprocess_gpu_test.py @@ -420,9 +420,7 @@ def cb(index): } with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names): - f = pjit.pjit( - lambda x: x, in_shardings=pjit.FROM_GDA, out_shardings=mesh_axes - ) + f = pjit.pjit(lambda x: x, out_shardings=mesh_axes) out = f(gda1) for s in out.addressable_shards: device_id = s.device.id @@ -471,9 +469,7 @@ def cb(index): } with global_mesh: - f = pjit.pjit( - lambda x: x, in_shardings=pjit.FROM_GDA, out_shardings=mesh_axes - ) + f = pjit.pjit(lambda x: x, out_shardings=mesh_axes) out = f(gda1) for s in out.addressable_shards: