Skip to content

Commit

Permalink
Add deprecation warning for FROM_GDA usage since that argument is not…
Browse files Browse the repository at this point in the history
… required anymore.

PiperOrigin-RevId: 519781715
  • Loading branch information
yashk2810 authored and jax authors committed Mar 27, 2023
1 parent 3c3fa04 commit e21aee1
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 19 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ Remember to align the itemized text with the first line of an item within a list
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
* `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`.
* `jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`.
* `jax.experimental.pjit.FROM_GDA` is deprecated. Please pass in sharded
jax.Arrays as input and remove the `in_shardings` argument to pjit since
it is optional.

## jaxlib 0.4.7

Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/multihost_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from jax._src.interpreters import pxla
from jax.interpreters import xla
from jax._src import pjit as pjit_lib
from jax.experimental.pjit import pjit, FROM_GDA
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax._src import distributed
from jax._src import config as config_internal
Expand Down
37 changes: 25 additions & 12 deletions jax/experimental/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from jax._src.pjit import (
AUTO as AUTO,
FROM_GDA as FROM_GDA,
ParsedPartitionSpec as ParsedPartitionSpec,
get_array_mapping as get_array_mapping,
hashable_pytree as hashable_pytree,
Expand All @@ -38,28 +37,42 @@
from jax._src.pjit import (
NamedSharding as _deprecated_NamedSharding,
PartitionSpec as _deprecated_PartitionSpec,
FROM_GDA as _deprecated_FROM_GDA,
)

import typing
if typing.TYPE_CHECKING:
from jax._src.pjit import (
NamedSharding as NamedSharding,
PartitionSpec as PartitionSpec,
FROM_GDA as FROM_GDA,
)
del typing

_deprecations = {
# Added Feb 13, 2023:
"NamedSharding": (
("jax.experimental.pjit.NamedSharding is deprecated. Use "
"jax.sharding.NamedSharding."),
_deprecated_NamedSharding,
),
"PartitionSpec": (
("jax.experimental.pjit.PartitionSpec is deprecated. Use "
"jax.sharding.PartitionSpec."),
_deprecated_PartitionSpec,
),
# Added Feb 13, 2023:
"NamedSharding": (
(
"jax.experimental.pjit.NamedSharding is deprecated. Use "
"jax.sharding.NamedSharding."
),
_deprecated_NamedSharding,
),
"PartitionSpec": (
(
"jax.experimental.pjit.PartitionSpec is deprecated. Use "
"jax.sharding.PartitionSpec."
),
_deprecated_PartitionSpec,
),
"FROM_GDA": (
(
"jax.experimental.pjit.FROM_GDA is deprecated. Please pass in"
" sharded jax.Arrays as input and remove the in_shardings argument"
" to pjit since it is optional."
),
_deprecated_FROM_GDA,
),
}

from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
Expand Down
8 changes: 2 additions & 6 deletions tests/multiprocess_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,9 +420,7 @@ def cb(index):
}

with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
f = pjit.pjit(
lambda x: x, in_shardings=pjit.FROM_GDA, out_shardings=mesh_axes
)
f = pjit.pjit(lambda x: x, out_shardings=mesh_axes)
out = f(gda1)
for s in out.addressable_shards:
device_id = s.device.id
Expand Down Expand Up @@ -471,9 +469,7 @@ def cb(index):
}

with global_mesh:
f = pjit.pjit(
lambda x: x, in_shardings=pjit.FROM_GDA, out_shardings=mesh_axes
)
f = pjit.pjit(lambda x: x, out_shardings=mesh_axes)
out = f(gda1)

for s in out.addressable_shards:
Expand Down

0 comments on commit e21aee1

Please sign in to comment.