Skip to content

Commit

Permalink
Merge pull request #18795 from gnecula:test_export_grad
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 587730171
  • Loading branch information
jax authors committed Dec 4, 2023
2 parents 1d95e79 + 8a2d4a0 commit d91c13e
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 16 deletions.
4 changes: 2 additions & 2 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1816,7 +1816,7 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
s._original_sharding, '_parsed_pspec'):
parsed_pspec = s._original_sharding._parsed_pspec
else:
if resource_env is not None:
if resource_env is not None and not resource_env.physical_mesh.empty:
parsed_pspec = parse_flatten_op_sharding(
s._hlo_sharding, resource_env.physical_mesh)[0]
else:
Expand All @@ -1838,7 +1838,7 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
s._original_sharding, '_parsed_pspec'):
parsed_pspec = s._original_sharding._parsed_pspec
else:
if resource_env is not None:
if resource_env is not None and not resource_env.physical_mesh.empty:
parsed_pspec = parse_flatten_op_sharding(
s._hlo_sharding, resource_env.physical_mesh)[0]
else:
Expand Down
25 changes: 20 additions & 5 deletions jax/experimental/jax2tf/tests/sharding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import contextlib
from functools import partial
import logging
import math
import os
import re
from typing import Any
Expand All @@ -40,6 +41,7 @@
from jax.experimental import pjit
from jax.experimental.maps import xmap
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
import jax.numpy as jnp
Expand Down Expand Up @@ -382,25 +384,38 @@ def f_jax(x): # x: f32[10, 20], optionally some axes as polymorphic
for in_shardings in ("missing", None, "P")
for out_shardings in ("missing", None, "P")
])
@jtu.with_mesh([("x", 2)])
def test_grad_pjit(self, in_shardings="P", out_shardings=None):
if not config.jax2tf_default_native_serialization.value:
self.skipTest("TODO: failure in non-native serialization")
local_devices = list(jax.local_devices())
size = 2
if len(local_devices) < size:
raise unittest.SkipTest(f"Test requires {size} local devices")
mesh_devices = np.array(local_devices[:size]).reshape((2,))
mesh = jax.sharding.Mesh(mesh_devices, ("x",))
def f_jax(x): # x: f32[10,20] -> f32[20,10]
return jnp.sin(x.T)

pjit_kwargs = {}
if in_shardings != "missing":
pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None)
pjit_kwargs["in_shardings"] = (
NamedSharding(mesh, P(None, "x")) if in_shardings == "P" else None)
if out_shardings != "missing":
pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None)
pjit_kwargs["out_shardings"] = (
NamedSharding(mesh, P("x", None)) if out_shardings == "P" else None)
f_jax = pjit.pjit(f_jax, **pjit_kwargs)
x_shape = (10, 20)
x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)

def f_grad_tf(x_v, res_ct):
with tf.GradientTape(persistent=True) as tape:
tape.watch(x_v)
res_tf = jax2tf.convert(f_jax)(x_v)
return tape.gradient(res_tf, x_v, output_gradients=res_ct)
with tf.GradientTape() as tape2:
tape2.watch(x_v)
res_tf = jax2tf.convert(f_jax)(x_v)
dy_dx = tape.gradient(res_tf, x_v, output_gradients=res_ct)
d2y_dx2 = tape.gradient(dy_dx, x_v)
return d2y_dx2

# Annotation count for the primal input and the grad output
count_in_P = self.GEQ(2) if in_shardings == "P" else 0
Expand Down
67 changes: 58 additions & 9 deletions tests/export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from jax import tree_util
from jax.experimental.export import export
from jax.experimental import pjit
from jax.sharding import NamedSharding
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

Expand Down Expand Up @@ -755,30 +756,50 @@ def f_jax(b): # b: f32[16 // DEVICES, 4]
)(a)

@jtu.parameterized_filterable(
one_containing="in_shardings_None_out_shardings_P_with_mesh_False",
kwargs=[
dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
in_shardings=in_shardings, out_shardings=out_shardings)
dict(in_shardings=in_shardings, out_shardings=out_shardings,
with_mesh=with_mesh)
for in_shardings in ("missing", None, "P")
for out_shardings in ("missing", None, "P")
for with_mesh in (True, False)
])
def test_grad_with_sharding(self, in_shardings="P", out_shardings=None):
def test_grad_with_sharding(self, in_shardings="P", out_shardings=None,
with_mesh=False):
if len(jax.devices()) < 2:
self.skipTest("Test requires at least 2 devices")
x_shape = (10, 20)
x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
def f_jax(x): # x: f32[10,20] -> f32[20,10]
return jnp.sin(x.T)

mesh = Mesh(jax.devices()[:2], "d")
pjit_kwargs = {}
# Use NamedShardings if we don't have a mesh_context
if with_mesh:
sharding_None_d = P(None, "d")
sharding_d_None = P("d", None)
else:
sharding_None_d = NamedSharding(mesh, P(None, "d"))
sharding_d_None = NamedSharding(mesh, P("d", None))

if in_shardings != "missing":
pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None)
pjit_kwargs["in_shardings"] = (
sharding_None_d if in_shardings == "P" else None)
if out_shardings != "missing":
pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None)
f_jax = pjit.pjit(f_jax, **pjit_kwargs)
pjit_kwargs["out_shardings"] = (
sharding_d_None if out_shardings == "P" else None)
f_jax_pjit = pjit.pjit(f_jax, **pjit_kwargs)

with contextlib.ExitStack() as stack:
if with_mesh:
stack.enter_context(mesh)
# Serialize higher-order gradiends
exp = export.export(f_jax_pjit)(x)

with Mesh(jax.devices()[:2], "x"):
exp = export.export(f_jax)(x)
exp_vjp = exp.vjp()
# Try 2nd order grad as well
exp_vjp2 = exp_vjp.vjp()

vjp_module_str = str(exp_vjp.mlir_module())

Expand Down Expand Up @@ -812,13 +833,41 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10]

# Custom calls for the primal output shape all match primal_out_sharding
primal_out_calls = re.findall(
r"custom_call @Sharding.* {mhlo.sharding = \"(.+)\".*:.*tensor<20x10xf32>",
r"custom_call @Sharding.*mhlo.sharding = \"(.+)\".*:.*tensor<20x10xf32>",
vjp_module_str)
self.assertTrue(
all(s == primal_out_sharding for s in primal_out_calls),
primal_in_calls
)

# Call the exported gradient functions. In order to set the device context
# we replicate the inputs. If we don't use a mesh context and there are
# no shardings on inputs or outputs, then we have serialized for one
# device.
if in_shardings != "P" and out_shardings != "P" and not with_mesh:
self.assertEqual(exp_vjp.nr_devices, 1)
self.assertEqual(exp_vjp2.nr_devices, 1)
call_mesh = Mesh(jax.devices()[:1], "e")
else:
self.assertEqual(exp_vjp.nr_devices, 2)
self.assertEqual(exp_vjp2.nr_devices, 2)
call_mesh = Mesh(jax.devices()[:2], "e")

g1 = pjit.pjit(export.call_exported(exp_vjp),
in_shardings=(NamedSharding(call_mesh, None),
NamedSharding(call_mesh, None)))(x, x.T)
_, f_jax_vjp = jax.vjp(f_jax, x)
xbar = f_jax_vjp(x.T)
self.assertAllClose(xbar, g1)

g2 = pjit.pjit(export.call_exported(exp_vjp2),
in_shardings=(NamedSharding(call_mesh, None),
NamedSharding(call_mesh, None),
NamedSharding(call_mesh, None)))(x, x.T, x)
_, f_jax_vjp2 = jax.vjp(f_jax_vjp, x.T)
xbar2, = f_jax_vjp2((x,))
self.assertAllClose(xbar2, g2[1])

def test_multi_platform(self):
x = np.arange(8, dtype=np.float32)
exp = export.export(_testing_multi_platform_func,
Expand Down

0 comments on commit d91c13e

Please sign in to comment.