Skip to content

Commit

Permalink
Clean up and fix primal type to tangent type mapping
Browse files Browse the repository at this point in the history
This is part of the ["stackless"](jax-ml#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
  • Loading branch information
dougalm authored and rajasekharporeddy committed Sep 20, 2024
1 parent 7e43f0c commit 0e08f9f
Show file tree
Hide file tree
Showing 32 changed files with 130 additions and 127 deletions.
2 changes: 1 addition & 1 deletion jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
prevent_cse=prevent_cse, differentiated=differentiated, policy=policy)
out_primals, out_tangents_ = split_list(outs, [len(jaxpr.outvars)])
out_tangents_ = iter(out_tangents_)
out_tangents = [next(out_tangents_) if nz else ad_util.Zero.from_value(p)
out_tangents = [next(out_tangents_) if nz else ad_util.Zero.from_primal_value(p)
for p, nz in zip(out_primals, out_nz)]
return out_primals, out_tangents
ad.primitive_jvps[remat_p] = remat_jvp
Expand Down
9 changes: 7 additions & 2 deletions jax/_src/ad_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def __init__(self, aval: core.AbstractValue):
def __repr__(self) -> str:
return f'Zero({self.aval})'
@staticmethod
def from_value(val: Any) -> Zero:
return Zero(raise_to_shaped(get_aval(val)))
def from_primal_value(val: Any) -> Zero:
return Zero(raise_to_shaped(get_aval(val)).to_tangent_aval())

register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval))

Expand All @@ -82,6 +82,7 @@ def _stop_gradient_impl(x: T) -> T:
stop_gradient_p.def_abstract_eval(lambda x: x)


# User-facing version of `Zero`
class SymbolicZero:
def __init__(self, aval: core.AbstractValue) -> None:
self.aval = aval
Expand All @@ -108,6 +109,10 @@ def __getattr__(self, name):
else:
return attr

@staticmethod
def from_primal_value(val: Any) -> SymbolicZero:
return SymbolicZero(get_aval(val).to_tangent_aval())

JaxTypeOrTracer = Any

def replace_internal_symbolic_zeros(
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1826,7 +1826,7 @@ def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args):
def fun(*tangents):
tangent_avals = list(map(core.get_aval, tangents))
for primal_aval, tangent_aval in zip(primal_avals, tangent_avals):
if not core.typecompat(primal_aval.at_least_vspace(), tangent_aval):
if not core.typecompat(primal_aval.to_tangent_aval(), tangent_aval):
raise ValueError("linearized function called on tangent values inconsistent with "
"the original primal values: "
f"got {tangent_aval} for primal aval {primal_aval}")
Expand Down Expand Up @@ -1869,7 +1869,7 @@ def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_):
f"got {in_tree}, but expected to match {in_tree_expected}")
for arg, aval in zip(args, out_primal_avals):
ct_aval = shaped_abstractify(arg)
ct_aval_expected = aval.at_least_vspace()
ct_aval_expected = aval.to_tangent_aval()
if (not core.typecompat(ct_aval, ct_aval_expected) and
not _temporary_dtype_exception(ct_aval, ct_aval_expected)):
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/checkify.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ def jvp(*xs):
out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
nz_out_tangents_ = iter(nz_out_tangents)
out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace())
out_tangents = [SymbolicZero(core.get_aval(p).to_tangent_aval())
if z else next(nz_out_tangents_)
for p, z in zip(out_primals, out_zeros)]
assert next(nz_out_tangents_, None) is None
Expand Down
20 changes: 15 additions & 5 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,9 +1414,13 @@ def definitely_equal(x, y):
class AbstractValue:
__slots__: list[str] = []

def at_least_vspace(self):
def to_tangent_aval(self):
raise NotImplementedError("must override")

# TODO(dougalm): deprecate this alias
def at_least_vspace(self):
return self.to_tangent_aval()

def __repr__(self):
try:
kv_pairs = (f'{k}={v}' for k, v in self.__dict__.items())
Expand Down Expand Up @@ -1524,6 +1528,12 @@ def get_aval(x):
else:
return concrete_aval(x)

def get_type(x):
aval = get_aval(x)
if isinstance(aval, ConcreteArray):
return raise_to_shaped(aval)
else:
return aval

def concretization_function_error(fun, suggest_astype=False):
fname = getattr(fun, "__name__", fun)
Expand Down Expand Up @@ -1647,7 +1657,7 @@ def __repr__(self):
_oct = concretization_function_error(oct)
_index = concretization_function_error(operator.index)

def at_least_vspace(self) -> AbstractValue:
def to_tangent_aval(self) -> AbstractValue:
return UnshapedArray(primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)

Expand Down Expand Up @@ -1786,7 +1796,7 @@ def __hash__(self):
return hash((self.shape, self.dtype, self.weak_type,
getattr(self, 'sharding', None)))

def at_least_vspace(self):
def to_tangent_aval(self):
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)

Expand Down Expand Up @@ -1945,7 +1955,7 @@ def join(self, other):
else:
raise TypeError(self, other)

def at_least_vspace(self):
def to_tangent_aval(self):
return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)

Expand Down Expand Up @@ -2076,7 +2086,7 @@ def join(self, other):
else:
assert False, f"Cannot join {self} with {other}"
def str_short(self, short_dtypes=False): return 'Tok'
def at_least_vspace(self): return self
def to_tangent_aval(self): return self
abstract_token: AbstractToken = AbstractToken()

# Singleton shaped array used by all abstract tokens when shape/dtype is needed.
Expand Down
49 changes: 24 additions & 25 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _sum_tangents(_, x, *xs):
return reduce(ad.add_tangents, xs, x)

def _zeros_like_pytree(x):
return tree_map(Zero.from_value, x)
return tree_map(Zero.from_primal_value, x)

_stop_gradient = partial(
tree_map,
Expand Down Expand Up @@ -327,24 +327,27 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
"shapes/dtypes of:\n"
f""" {str(ty_tree_).replace("'", "")}""")
raise TypeError(m)
# TODO(mattjj): compare primals' tangent types to tangent objects' types
primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False)
for x in primals_out]
primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out]
expected_tangent_avals_out = [
raise_to_shaped(core.get_aval(x), weak_type=False).to_tangent_aval()
for x in primals_out]
tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False)
if type(t) is not SymbolicZero else t.aval.strip_weak_type()
for t in tangents_out]
if primal_avals_out != tangent_avals_out:
if len(primal_avals_out) == 1:
(av1,), (av2,) = primal_avals_out, tangent_avals_out
if expected_tangent_avals_out != tangent_avals_out:
if len(expected_tangent_avals_out) == 1:
(av_p,), (av_et,), (av_t,) = primal_avals_out, expected_tangent_avals_out, tangent_avals_out
msg = ("Custom JVP rule must produce primal and tangent outputs with "
"equal shapes and dtypes, but got {} and {} respectively.")
raise TypeError(msg.format(av1.str_short(), av2.str_short()))
"corresponding shapes and dtypes. Expected {} (tangent type of {}) but got {}.")
raise TypeError(msg.format(av_et.str_short(), av_p.str_short(), av_t.str_short()))
else:
msg = ("Custom JVP rule must produce primal and tangent outputs with "
"equal shapes and dtypes, but got:\n{}")
"corresponding shapes and dtypes, but got:\n{}")
disagreements = (
f" primal {av1.str_short()} for tangent {av2.str_short()}"
for av1, av2 in zip(primal_avals_out, tangent_avals_out) if av1 != av2)
f" primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}"
for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out)
if av_et != av_t)

raise TypeError(msg.format('\n'.join(disagreements)))
yield primals_out + tangents_out, (out_tree, primal_avals)

Expand Down Expand Up @@ -392,7 +395,7 @@ def jvp(*xs):
out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
nz_out_tangents_ = iter(nz_out_tangents)
out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace())
out_tangents = [SymbolicZero(core.get_aval(p).to_tangent_aval())
if z else next(nz_out_tangents_)
for p, z in zip(out_primals, out_zeros)]
assert next(nz_out_tangents_, None) is None
Expand Down Expand Up @@ -780,10 +783,10 @@ def append(x, d):
raise TypeError(msg.format(in_tree2, in_tree)) from None
results = []
for kp, a, ct in zip(keypaths, in_avals, cts_in_flat):
if ct is zero or a != a.at_least_vspace():
results.append(Zero(a.at_least_vspace()))
if ct is zero or a != a.to_tangent_aval():
results.append(Zero(a.to_tangent_aval()))
elif type(ct) is SymbolicZero:
if not core.typecompat(a.at_least_vspace(), a_ := ct.aval):
if not core.typecompat(a.to_tangent_aval(), a_ := ct.aval):
msg = ("Custom VJP bwd rule produced a SymbolicZero with a shape/dtype "
"that does not match the corresponding input tangent shape/dtype: "
f"at output{keystr(kp)} the SymbolicZero had shape/dtype "
Expand All @@ -794,7 +797,7 @@ def append(x, d):
raise ValueError(msg)
results.append(Zero(ct.aval))
else:
if (not core.typecompat(a.at_least_vspace(), a_ := core.get_aval(ct))
if (not core.typecompat(a.to_tangent_aval(), a_ := core.get_aval(ct))
and not (_temporary_dtype_exception(a, a_) or
_temporary_shape_exception(a, a_))):
msg = ("Custom VJP bwd rule must produce an output with the same "
Expand Down Expand Up @@ -908,16 +911,12 @@ def _custom_vjp_call_jaxpr_jvp(
_, res_tree = out_trees()
res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out]
args_dot = map(ad.instantiate_zeros, args_dot)
# Cast float0 to zeros with the primal dtype because custom vjp rules don't
# currently handle float0s
args_dot = map(ad.replace_float0s, args, args_dot)
tangents_out = ad.custom_lin_p.bind(
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(lax.tie_p.bind, primals_out, tangents_out)
tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
return primals_out, tangents_out
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp

Expand Down Expand Up @@ -1039,7 +1038,7 @@ def fwd(*args, **kwargs):
ans, rule = fun(*args, **kwargs)
ans_flat, out_tree = tree_flatten((ans,))
rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat]
ans_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
return ans, Residuals(jaxpr, in_tree(), out_tree, consts)

Expand Down Expand Up @@ -1153,7 +1152,7 @@ def _maybe_perturbed(x: Any) -> bool:
elif isinstance(x, pe.DynamicJaxprTracer):
# If x is a DynamicJaxprTracer then we're staging out; differentiation could
# happen later, but some types always have trivial tangents.
vspace = x.aval.at_least_vspace()
vspace = x.aval.to_tangent_aval()
return not (vspace is core.abstract_token or
getattr(vspace, 'dtype', None) == dtypes.float0)
elif not isinstance(x, ad.JVPTracer):
Expand Down Expand Up @@ -1425,7 +1424,7 @@ def custom_vjp_by_custom_transpose(fun, fwd, bwd):
@fun.defjvp
def jvp(primals, tangents):
outs, residuals = fwd(*primals)
tan_out_types = tree_map(lambda o: core.get_aval(o).at_least_vspace(), outs)
tan_out_types = tree_map(lambda o: core.get_aval(o).to_tangent_aval(), outs)
tan_fn = custom_transpose(partial(disallow_jvp, out_avals=tan_out_types))
tan_fn.def_transpose(bwd)
return outs, tan_fn(tan_out_types, residuals, tangents)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ def check_user_dtype_supported(dtype, fun_name=None):
uint2,
uint4,
]
if np_dtype.kind not in "biufc" and not is_custom_dtype:
if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0:
msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
msg += f" in {fun_name}" if fun_name else ""
raise TypeError(msg)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,7 @@ def flattened_primal_fun_jax(*args_flat):

vjp_in_avals = list(
itertools.chain(in_avals,
map(lambda a: a.at_least_vspace(), out_avals)))
map(lambda a: a.to_tangent_aval(), out_avals)))

if apply_jit:
assert device_assignment is not None
Expand Down
Loading

0 comments on commit 0e08f9f

Please sign in to comment.