Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Aug 28, 2024
1 parent 9015af1 commit 6b9b191
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 41 deletions.
24 changes: 11 additions & 13 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,14 +327,16 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
f""" {str(ty_tree_).replace("'", "")}""")
raise TypeError(m)
# TODO(mattjj): compare primals' tangent types to tangent objects' types
primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False)
for x in primals_out]
tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False)
if type(t) is not SymbolicZero else t.aval.strip_weak_type()
for t in tangents_out]
if primal_avals_out != tangent_avals_out:
if len(primal_avals_out) == 1:
(av1,), (av2,) = primal_avals_out, tangent_avals_out
expected_tangent_avals_out = [
core.primal_aval_to_tangent_aval(raise_to_shaped(core.get_aval(x), weak_type=False))
for x in primals_out]
tangent_avals_out = [
raise_to_shaped(core.get_aval(t), weak_type=False)
if type(t) is not SymbolicZero else t.aval.strip_weak_type()
for t in tangents_out]
if expected_tangent_avals_out != tangent_avals_out:
if len(expected_tangent_avals_out) == 1:
(av1,), (av2,) = expected_tangent_avals_out, tangent_avals_out
msg = ("Custom JVP rule must produce primal and tangent outputs with "
"equal shapes and dtypes, but got {} and {} respectively.")
raise TypeError(msg.format(av1.str_short(), av2.str_short()))
Expand All @@ -343,7 +345,7 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
"equal shapes and dtypes, but got:\n{}")
disagreements = (
f" primal {av1.str_short()} for tangent {av2.str_short()}"
for av1, av2 in zip(primal_avals_out, tangent_avals_out) if av1 != av2)
for av1, av2 in zip(expected_tangent_avals_out, tangent_avals_out) if av1 != av2)
raise TypeError(msg.format('\n'.join(disagreements)))
yield primals_out + tangents_out, (out_tree, primal_avals)

Expand Down Expand Up @@ -814,14 +816,10 @@ def _custom_vjp_call_jaxpr_jvp(
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
args_dot = map(ad.instantiate_zeros, args_dot)
# Cast float0 to zeros with the primal dtype because custom vjp rules don't
# currently handle float0s
args_dot = map(ad.replace_float0s, args, args_dot)
tangents_out = ad.custom_lin_p.bind(
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(lax.tie_p.bind, primals_out, tangents_out)
tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
return primals_out, tangents_out
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp

Expand Down
4 changes: 2 additions & 2 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,9 +782,9 @@ def check_user_dtype_supported(dtype, fun_name=None):
int2,
int4,
uint2,
uint4,
uint4
]
if np_dtype.kind not in "biufc" and not is_custom_dtype:
if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0:
msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
msg += f" in {fun_name}" if fun_name else ""
raise TypeError(msg)
Expand Down
19 changes: 2 additions & 17 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,19 +164,6 @@ def unpair_pval(pval):
aval_1, aval_2 = aval
return (aval_1, const_1), (aval_2, const_2)

def replace_float0s(primal, tangent):
if dtype(tangent) == float0:
return zeros_like_jaxval(primal)
else:
return tangent

def recast_to_float0(primal, tangent):
if core.primal_dtype_to_tangent_dtype(dtype(primal)) == float0:
return Zero.from_primal_value(primal)
else:
return tangent


# NOTE: The FIXMEs below are caused by primal/tangent mixups (type
# errors if you will)
def backward_pass(jaxpr: core.Jaxpr, transform_stack,
Expand Down Expand Up @@ -369,7 +356,6 @@ def process_custom_jvp_call(self, prim, fun, f_jvp, tracers, *, symbolic_zeros):
with core.set_current_trace(self.parent_trace):
if not symbolic_zeros:
tangents_in = map(instantiate_zeros, tangents_in)
tangents_in = map(replace_float0s, primals_in, tangents_in)
else:
tangents_in = map(replace_internal_symbolic_zeros, tangents_in)
with core.set_current_trace(self):
Expand All @@ -396,7 +382,7 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
primals_out = map(self.primal_part, primals_out)
res = map(self.primal_part, res)
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
avals_out = [Zero.from_primal_value(x) for x in primals_out]
# TODO(frostig,mattjj): avoid instantiating zeros when we don't have to!
with core.set_current_trace(self.parent_trace):
tangents_in = map(instantiate_zeros, tangents_in)
Expand All @@ -405,7 +391,6 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(self.primal_part, tangents_out)
tangents_out = map(recast_to_float0, primals_out, tangents_out)
return map(partial(JVPTracer, self), primals_out, tangents_out)

def process_custom_transpose(self, prim, call, tracers, **params):
Expand Down Expand Up @@ -673,7 +658,7 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
nonzeros)
tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
tangent_avals = [core.primal_aval_to_tangent_aval(aval) for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
Expand Down
2 changes: 0 additions & 2 deletions jax/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@
primitive_jvps as primitive_jvps,
primitive_transposes as primitive_transposes,
rearrange_binders as rearrange_binders,
recast_to_float0 as recast_to_float0,
reducing_transposes as reducing_transposes,
replace_float0s as replace_float0s,
standard_jvp as standard_jvp,
standard_jvp2 as standard_jvp2,
traceable as traceable,
Expand Down
16 changes: 9 additions & 7 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7700,36 +7700,38 @@ def g_jvp(primals, tangents):
self.assertAllClose(tangents, 2 * jnp.arange(3., dtype='float32'))

def test_float0(self):
scalar_float0 = jnp.zeros((), dtype=float0)
@jax.custom_jvp
def f(x, y):
return x, y
def f_jvp(primals, _):
# we need a defined (non-float0) tangent to trigger the rule
return primals, (2., 1)
return primals, (2., scalar_float0)
f.defjvp(f_jvp)

primals = (2., 3)
tangents = (np.ones(()), np.zeros((), float0),)
expected_tangents = (2., np.zeros((), float0))
tangents = (np.ones(()), scalar_float0)
expected_tangents = (2., scalar_float0)
self.assertAllClose(api.jvp(f, primals, tangents),
(primals, expected_tangents))

def test_float0_initial_style(self):
scalar_float0 = jnp.zeros((), dtype=float0)
@jax.custom_jvp
def f(x, y):
return x, y
def f_jvp(primals, _):
x, y = primals
return (x, y), (2., 1)
return (x, y), (2., scalar_float0)
f.defjvp(f_jvp)

def foo(x, y):
out, _ = lax.scan(lambda c, _: (f(*c), None), (x, y), None, length=1)
return out

primals = (2., 3)
tangents = (np.ones(()), np.zeros((), float0),)
expected_tangents = (2., np.zeros((), float0))
tangents = (np.ones(()), scalar_float0)
expected_tangents = (2., scalar_float0)

self.assertAllClose(api.jvp(foo, primals, tangents),
(primals, expected_tangents))

Expand Down

0 comments on commit 6b9b191

Please sign in to comment.