From 785c2c8d5aab926064d0c004e8eccd4e6f6fb54a Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 17 Sep 2024 09:54:15 -0700 Subject: [PATCH] Clean up and fix primal type to tangent type mapping This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types. Changes: 1. Rename `at_least_vspace` to `to_tangent_aval` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal type that's already a vector space itself. 2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion. 3. Add `to_tangent_aval` calls in various other places where they're missing. 4. Remove non-support for float0 in custom derivatives? 5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.) 
PiperOrigin-RevId: 675606346 --- jax/_src/ad_checkpoint.py | 2 +- jax/_src/ad_util.py | 9 ++++- jax/_src/api.py | 4 +- jax/_src/checkify.py | 2 +- jax/_src/core.py | 20 ++++++--- jax/_src/custom_derivatives.py | 49 +++++++++++------------ jax/_src/dtypes.py | 2 +- jax/_src/export/_export.py | 2 +- jax/_src/interpreters/ad.py | 47 ++++++++-------------- jax/_src/interpreters/partial_eval.py | 3 +- jax/_src/lax/ann.py | 4 +- jax/_src/lax/control_flow/conditionals.py | 2 +- jax/_src/lax/control_flow/for_loop.py | 2 +- jax/_src/lax/control_flow/loops.py | 4 +- jax/_src/lax/control_flow/solves.py | 4 +- jax/_src/lax/lax.py | 12 +++--- jax/_src/lax/linalg.py | 4 +- jax/_src/lax/slicing.py | 8 ++-- jax/_src/lax/windowed_reductions.py | 4 +- jax/_src/pallas/core.py | 4 +- jax/_src/state/discharge.py | 2 +- jax/_src/state/types.py | 4 +- jax/core.py | 1 + jax/experimental/attrs.py | 4 +- jax/experimental/shard_map.py | 2 +- jax/experimental/sparse/bcoo.py | 16 ++++---- jax/experimental/sparse/bcsr.py | 4 +- jax/experimental/sparse/coo.py | 4 +- jax/experimental/sparse/csr.py | 4 +- jax/interpreters/ad.py | 2 - tests/api_test.py | 23 ++++++----- tests/export_test.py | 3 +- 32 files changed, 130 insertions(+), 127 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index fd30119882e7..8c7fe2f489d5 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -514,7 +514,7 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy): prevent_cse=prevent_cse, differentiated=differentiated, policy=policy) out_primals, out_tangents_ = split_list(outs, [len(jaxpr.outvars)]) out_tangents_ = iter(out_tangents_) - out_tangents = [next(out_tangents_) if nz else ad_util.Zero.from_value(p) + out_tangents = [next(out_tangents_) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, out_nz)] return out_primals, out_tangents ad.primitive_jvps[remat_p] = remat_jvp diff --git a/jax/_src/ad_util.py 
b/jax/_src/ad_util.py index 57e881c34f82..c69ff3754dc6 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -65,8 +65,8 @@ def __init__(self, aval: core.AbstractValue): def __repr__(self) -> str: return f'Zero({self.aval})' @staticmethod - def from_value(val: Any) -> Zero: - return Zero(raise_to_shaped(get_aval(val))) + def from_primal_value(val: Any) -> Zero: + return Zero(raise_to_shaped(get_aval(val)).to_tangent_aval()) register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval)) @@ -82,6 +82,7 @@ def _stop_gradient_impl(x: T) -> T: stop_gradient_p.def_abstract_eval(lambda x: x) +# User-facing version of `Zero` class SymbolicZero: def __init__(self, aval: core.AbstractValue) -> None: self.aval = aval @@ -108,6 +109,10 @@ def __getattr__(self, name): else: return attr + @staticmethod + def from_primal_value(val: Any) -> SymbolicZero: + return SymbolicZero(get_aval(val).to_tangent_aval()) + JaxTypeOrTracer = Any def replace_internal_symbolic_zeros( diff --git a/jax/_src/api.py b/jax/_src/api.py index b548cc43fb3b..aae99a28bbea 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1826,7 +1826,7 @@ def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args): def fun(*tangents): tangent_avals = list(map(core.get_aval, tangents)) for primal_aval, tangent_aval in zip(primal_avals, tangent_avals): - if not core.typecompat(primal_aval.at_least_vspace(), tangent_aval): + if not core.typecompat(primal_aval.to_tangent_aval(), tangent_aval): raise ValueError("linearized function called on tangent values inconsistent with " "the original primal values: " f"got {tangent_aval} for primal aval {primal_aval}") @@ -1869,7 +1869,7 @@ def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_): f"got {in_tree}, but expected to match {in_tree_expected}") for arg, aval in zip(args, out_primal_avals): ct_aval = shaped_abstractify(arg) - ct_aval_expected = aval.at_least_vspace() + ct_aval_expected = aval.to_tangent_aval() 
if (not core.typecompat(ct_aval, ct_aval_expected) and not _temporary_dtype_exception(ct_aval, ct_aval_expected)): raise ValueError( diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 1167914e51c9..e67f624fc32e 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -980,7 +980,7 @@ def jvp(*xs): out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents) out_primals, nz_out_tangents = split_list(out, [len(out_zeros)]) nz_out_tangents_ = iter(nz_out_tangents) - out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace()) + out_tangents = [SymbolicZero(core.get_aval(p).to_tangent_aval()) if z else next(nz_out_tangents_) for p, z in zip(out_primals, out_zeros)] assert next(nz_out_tangents_, None) is None diff --git a/jax/_src/core.py b/jax/_src/core.py index 51933a9f8bbf..057a79925e2e 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1414,9 +1414,13 @@ def definitely_equal(x, y): class AbstractValue: __slots__: list[str] = [] - def at_least_vspace(self): + def to_tangent_aval(self): raise NotImplementedError("must override") + # TODO(dougalm): deprecate this alias + def at_least_vspace(self): + return self.to_tangent_aval() + def __repr__(self): try: kv_pairs = (f'{k}={v}' for k, v in self.__dict__.items()) @@ -1524,6 +1528,12 @@ def get_aval(x): else: return concrete_aval(x) +def get_type(x): + aval = get_aval(x) + if isinstance(aval, ConcreteArray): + return raise_to_shaped(aval) + else: + return aval def concretization_function_error(fun, suggest_astype=False): fname = getattr(fun, "__name__", fun) @@ -1647,7 +1657,7 @@ def __repr__(self): _oct = concretization_function_error(oct) _index = concretization_function_error(operator.index) - def at_least_vspace(self) -> AbstractValue: + def to_tangent_aval(self) -> AbstractValue: return UnshapedArray(primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) @@ -1786,7 +1796,7 @@ def __hash__(self): return hash((self.shape, self.dtype, self.weak_type, getattr(self, 
'sharding', None))) - def at_least_vspace(self): + def to_tangent_aval(self): return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) @@ -1945,7 +1955,7 @@ def join(self, other): else: raise TypeError(self, other) - def at_least_vspace(self): + def to_tangent_aval(self): return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) @@ -2076,7 +2086,7 @@ def join(self, other): else: assert False, f"Cannot join {self} with {other}" def str_short(self, short_dtypes=False): return 'Tok' - def at_least_vspace(self): return self + def to_tangent_aval(self): return self abstract_token: AbstractToken = AbstractToken() # Singleton shaped array used by all abstract tokens when shape/dtype is needed. diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 019948c36683..05ede08d219c 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -67,7 +67,7 @@ def _sum_tangents(_, x, *xs): return reduce(ad.add_tangents, xs, x) def _zeros_like_pytree(x): - return tree_map(Zero.from_value, x) + return tree_map(Zero.from_primal_value, x) _stop_gradient = partial( tree_map, @@ -327,24 +327,27 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): "shapes/dtypes of:\n" f""" {str(ty_tree_).replace("'", "")}""") raise TypeError(m) - # TODO(mattjj): compare primals' tangent types to tangent objects' types - primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) - for x in primals_out] + primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out] + expected_tangent_avals_out = [ + raise_to_shaped(core.get_aval(x), weak_type=False).to_tangent_aval() + for x in primals_out] tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False) if type(t) is not SymbolicZero else t.aval.strip_weak_type() for t in tangents_out] - if primal_avals_out != tangent_avals_out: - if len(primal_avals_out) == 1: - 
(av1,), (av2,) = primal_avals_out, tangent_avals_out + if expected_tangent_avals_out != tangent_avals_out: + if len(expected_tangent_avals_out) == 1: + (av_p,), (av_et,), (av_t,) = primal_avals_out, expected_tangent_avals_out, tangent_avals_out msg = ("Custom JVP rule must produce primal and tangent outputs with " - "equal shapes and dtypes, but got {} and {} respectively.") - raise TypeError(msg.format(av1.str_short(), av2.str_short())) + "corresponding shapes and dtypes. Expected {} (tangent type of {}) but got {}.") + raise TypeError(msg.format(av_et.str_short(), av_p.str_short(), av_t.str_short())) else: msg = ("Custom JVP rule must produce primal and tangent outputs with " - "equal shapes and dtypes, but got:\n{}") + "corresponding shapes and dtypes, but got:\n{}") disagreements = ( - f" primal {av1.str_short()} for tangent {av2.str_short()}" - for av1, av2 in zip(primal_avals_out, tangent_avals_out) if av1 != av2) + f" primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}" + for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out) + if av_et != av_t) + raise TypeError(msg.format('\n'.join(disagreements))) yield primals_out + tangents_out, (out_tree, primal_avals) @@ -392,7 +395,7 @@ def jvp(*xs): out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents) out_primals, nz_out_tangents = split_list(out, [len(out_zeros)]) nz_out_tangents_ = iter(nz_out_tangents) - out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace()) + out_tangents = [SymbolicZero(core.get_aval(p).to_tangent_aval()) if z else next(nz_out_tangents_) for p, z in zip(out_primals, out_zeros)] assert next(nz_out_tangents_, None) is None @@ -780,10 +783,10 @@ def append(x, d): raise TypeError(msg.format(in_tree2, in_tree)) from None results = [] for kp, a, ct in zip(keypaths, in_avals, cts_in_flat): - if ct is zero or a != a.at_least_vspace(): - results.append(Zero(a.at_least_vspace())) + if ct is zero or a != 
a.to_tangent_aval(): + results.append(Zero(a.to_tangent_aval())) elif type(ct) is SymbolicZero: - if not core.typecompat(a.at_least_vspace(), a_ := ct.aval): + if not core.typecompat(a.to_tangent_aval(), a_ := ct.aval): msg = ("Custom VJP bwd rule produced a SymbolicZero with a shape/dtype " "that does not match the corresponding input tangent shape/dtype: " f"at output{keystr(kp)} the SymbolicZero had shape/dtype " @@ -794,7 +797,7 @@ def append(x, d): raise ValueError(msg) results.append(Zero(ct.aval)) else: - if (not core.typecompat(a.at_least_vspace(), a_ := core.get_aval(ct)) + if (not core.typecompat(a.to_tangent_aval(), a_ := core.get_aval(ct)) and not (_temporary_dtype_exception(a, a_) or _temporary_shape_exception(a, a_))): msg = ("Custom VJP bwd rule must produce an output with the same " @@ -908,16 +911,12 @@ def _custom_vjp_call_jaxpr_jvp( _, res_tree = out_trees() res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args) res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] + avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] args_dot = map(ad.instantiate_zeros, args_dot) - # Cast float0 to zeros with the primal dtype because custom vjp rules don't - # currently handle float0s - args_dot = map(ad.replace_float0s, args, args_dot) tangents_out = ad.custom_lin_p.bind( *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out, symbolic_zeros=symbolic_zeros) tangents_out = map(lax.tie_p.bind, primals_out, tangents_out) - tangents_out = map(ad.recast_to_float0, primals_out, tangents_out) return primals_out, tangents_out ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp @@ -1039,7 +1038,7 @@ def fwd(*args, **kwargs): ans, rule = fun(*args, **kwargs) ans_flat, out_tree = tree_flatten((ans,)) rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree) - ans_avals = 
[core.get_aval(x).at_least_vspace() for x in ans_flat] + ans_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat] jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals) return ans, Residuals(jaxpr, in_tree(), out_tree, consts) @@ -1153,7 +1152,7 @@ def _maybe_perturbed(x: Any) -> bool: elif isinstance(x, pe.DynamicJaxprTracer): # If x is a DynamicJaxprTracer then we're staging out; differentiation could # happen later, but some types always have trivial tangents. - vspace = x.aval.at_least_vspace() + vspace = x.aval.to_tangent_aval() return not (vspace is core.abstract_token or getattr(vspace, 'dtype', None) == dtypes.float0) elif not isinstance(x, ad.JVPTracer): @@ -1425,7 +1424,7 @@ def custom_vjp_by_custom_transpose(fun, fwd, bwd): @fun.defjvp def jvp(primals, tangents): outs, residuals = fwd(*primals) - tan_out_types = tree_map(lambda o: core.get_aval(o).at_least_vspace(), outs) + tan_out_types = tree_map(lambda o: core.get_aval(o).to_tangent_aval(), outs) tan_fn = custom_transpose(partial(disallow_jvp, out_avals=tan_out_types)) tan_fn.def_transpose(bwd) return outs, tan_fn(tan_out_types, residuals, tangents) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 81f4180a1c12..d76b80ad3a89 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -784,7 +784,7 @@ def check_user_dtype_supported(dtype, fun_name=None): uint2, uint4, ] - if np_dtype.kind not in "biufc" and not is_custom_dtype: + if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0: msg = f"JAX only supports number and bool dtypes, got dtype {dtype}" msg += f" in {fun_name}" if fun_name else "" raise TypeError(msg) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index d0159f7a4334..7f7773acbd39 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1127,7 +1127,7 @@ def flattened_primal_fun_jax(*args_flat): vjp_in_avals = list( itertools.chain(in_avals, - map(lambda a: a.at_least_vspace(), out_avals))) + 
map(lambda a: a.to_tangent_aval(), out_avals))) if apply_jit: assert device_assignment is not None diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index f1b25cf96a95..f1f46a5c18f7 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -57,7 +57,7 @@ def _update_annotation( # Implicit arguments never have tangents, so generate the tangent part of the # type annotation from explicit arguments only. explicit_avals = [aval for aval, explicit in orig_type if explicit] - tan_types = [(aval.at_least_vspace(), True) + tan_types = [(aval.to_tangent_aval(), True) for nz, aval in zip(explicit_nonzeros, explicit_avals) if nz] return lu.annotate(f, (*orig_type, *tan_types)) @@ -72,7 +72,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, @lu.transformation def jvpfun(instantiate, transform_stack, primals, tangents): - tangents = [Zero.from_value(t) if not isinstance(t, Zero) + tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) and dtype(t) == float0 else t for t in tangents] ctx = (source_info_util.transform_name_stack('jvp') if transform_stack else contextlib.nullcontext()) @@ -124,7 +124,7 @@ def linearize(traceable, *primals, **kwargs): jvpfun, aux = jvp(traceable, has_aux=True) in_pvals = (tuple(pe.PartialVal.known(p) for p in primals) - + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace()) + + tuple(pe.PartialVal.unknown(get_aval(p).to_tangent_aval()) for p in primals)) _, in_tree = tree_flatten(((primals, primals), {})) jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree) @@ -166,18 +166,6 @@ def unpair_pval(pval): aval_1, aval_2 = aval return (aval_1, const_1), (aval_2, const_2) -def replace_float0s(primal, tangent): - if dtype(tangent) == float0: - return zeros_like_jaxval(primal) - else: - return tangent - -def recast_to_float0(primal, tangent): - if core.primal_dtype_to_tangent_dtype(dtype(primal)) == float0: - return Zero(get_aval(primal).at_least_vspace()) - else: - return tangent - # 
NOTE: The FIXMEs below are caused by primal/tangent mixups (type # errors if you will) @@ -203,7 +191,7 @@ def write_cotangent(prim, v, ct): # assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval) def read_cotangent(v): - return ct_env.pop(v, Zero(v.aval.at_least_vspace())) + return ct_env.pop(v, Zero(v.aval.to_tangent_aval())) def read_primal(v): if type(v) is Literal: @@ -295,11 +283,11 @@ def nonzero_tangent_outputs(*args, **kwargs): class JVPTrace(Trace): def pure(self, val): - tangent_zero = Zero(get_aval(val).at_least_vspace()) + tangent_zero = Zero.from_primal_value(val) return JVPTracer(self, val, tangent_zero) def lift(self, val): - tangent_zero = Zero(get_aval(val).at_least_vspace()) + tangent_zero = Zero.from_primal_value(val) return JVPTracer(self, val, tangent_zero) def sublift(self, val): @@ -343,7 +331,7 @@ def new_out_axes_thunk(): result = call_primitive.bind(_update_annotation(f_jvp, f.in_type, which_nz), *args, **new_params) primal_out, tangent_out = tree_unflatten(out_tree(), result) - tangent_out = [Zero(get_aval(p).at_least_vspace()) if t is None else t + tangent_out = [Zero.from_primal_value(p) if t is None else t for p, t in zip(primal_out, tangent_out)] return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)] @@ -374,13 +362,11 @@ def process_custom_jvp_call(self, _, __, f_jvp, tracers, *, symbolic_zeros): primals_in = map(core.full_lower, primals_in) if not symbolic_zeros: tangents_in = map(instantiate_zeros, tangents_in) - tangents_in = map(replace_float0s, primals_in, tangents_in) else: tangents_in = map(replace_internal_symbolic_zeros, tangents_in) outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in)) primals_out, tangents_out = split_list(outs, [len(outs) // 2]) tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out) - tangents_out = map(recast_to_float0, primals_out, tangents_out) return map(partial(JVPTracer, self), primals_out, tangents_out) def post_process_custom_jvp_call(self, 
out_tracers, _): @@ -398,14 +384,13 @@ def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees, res_and_primals_out = fwd.call_wrapped(*fwd_in) _, res_tree = out_trees() res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] + avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to! tangents_in = map(instantiate_zeros, tangents_in) tangents_out = custom_lin_p.bind( *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out, symbolic_zeros=symbolic_zeros) tangents_out = map(lax.tie_p.bind, primals_out, tangents_out) - tangents_out = map(recast_to_float0, primals_out, tangents_out) return map(partial(JVPTracer, self), primals_out, tangents_out) def post_process_custom_vjp_call(self, out_tracers, _): @@ -505,8 +490,8 @@ def linear_jvp(primitive, primals, tangents, **params): val_out = primitive.bind(*primals, **params) if all(type(tangent) is Zero for tangent in tangents): if primitive.multiple_results: - return val_out, map(Zero.from_value, val_out) - return val_out, Zero.from_value(val_out) + return val_out, map(Zero.from_primal_value, val_out) + return val_out, Zero.from_primal_value(val_out) else: tangents = map(instantiate_zeros, tangents) return val_out, primitive.bind(*tangents, **params) @@ -533,7 +518,7 @@ def standard_jvp(jvprules, primitive, primals, tangents, **params): val_out = primitive.bind(*primals, **params) tangents_out = [rule(t, *primals, **params) for rule, t in zip(jvprules, tangents) if rule is not None and type(t) is not Zero] - return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_value(val_out)) + return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_primal_value(val_out)) def defjvp2(primitive, *jvprules): assert isinstance(primitive, Primitive) @@ -545,7 +530,7 @@ def 
standard_jvp2(jvprules, primitive, primals, tangents, **params): tangents_out = (rule(t, val_out, *primals, **params) for rule, t in zip(jvprules, tangents) if rule is not None and type(t) is not Zero) tangents_out = list(tangents_out) - return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_value(val_out)) + return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_primal_value(val_out)) def add_tangents(x, y): if type(x) is Zero: @@ -580,7 +565,7 @@ def defjvp_zero(primitive): def zero_jvp(primitive, primals, tangents, **params): r = primitive.bind(*primals, **params) - return r, Zero.from_value(r) + return r, Zero.from_primal_value(r) deflinear2(add_jaxvals_p, lambda t, *args: (t, t)) @@ -591,7 +576,7 @@ def instantiate_zeros(tangent): @lu.transformation_with_aux def traceable(in_tree, *primals_and_tangents): primals, tangents = tree_unflatten(in_tree, primals_and_tangents) - tangents = [Zero(get_aval(p).at_least_vspace()) if t is None else t + tangents = [Zero.from_primal_value(p) if t is None else t for p, t in zip(primals, tangents)] primals_out, tangents_out = yield (primals, tangents), {} tangents_out = [None if type(t) is Zero else t for t in tangents_out] @@ -695,7 +680,7 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate): f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False), nonzeros) - tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] + tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] avals_in = list(it.chain(jaxpr.in_avals, tangent_avals)) jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in) return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros() @@ -705,7 +690,7 @@ def f_jvp_traceable(nonzeros, *primals_and_nztangents): num_primals = len(nonzeros) primals = list(primals_and_nztangents[:num_primals]) nonzero_tangents = 
iter(primals_and_nztangents[num_primals:]) - tangents = [next(nonzero_tangents) if nz else Zero.from_value(p) + tangents = [next(nonzero_tangents) if nz else Zero.from_primal_value(p) for p, nz in zip(primals, nonzeros)] primals_out, tangents_out = yield (primals, tangents), {} out_nonzeros = [type(t) is not Zero for t in tangents_out] diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 374816e001ec..fc2214aaf29f 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2158,6 +2158,7 @@ def post_process_map(self, map_primitive, out_tracers, params): def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): in_avals = [t.aval for t in tracers] + in_tangent_avals = [t.to_tangent_aval() for t in in_avals] with core.new_sublevel(): fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) @@ -2166,7 +2167,7 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): @_memoize def jvp_jaxpr_thunk(*in_zeros): for store in jvp.stores: store and store.reset() - nz_tangent_avals, zero_avals = partition_list(in_zeros, in_avals) + nz_tangent_avals, zero_avals = partition_list(in_zeros, in_tangent_avals) jvp_, out_zeros = _jvp_jaxpr_zeros(jvp, in_zeros, tuple(zero_avals)) in_avals_ = (*in_avals, *nz_tangent_avals) jaxpr, _, out_consts, () = trace_to_subjaxpr_dynamic(jvp_, main_(), in_avals_) diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index f2dbd8d4fa0e..0e037ec774b5 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -373,7 +373,7 @@ def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension, reduction_input_size_override, aggregate_to_topk) if type(tangent) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: arg_shape = arg_out.shape rank = 
len(arg_shape) @@ -385,7 +385,7 @@ def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension, idx = tuple( arg_out if i == reduction_dimension else iotas[i] for i in range(rank)) tangent_out = tangent[idx] - return (val_out, arg_out), (tangent_out, ad_util.Zero.from_value(arg_out)) + return (val_out, arg_out), (tangent_out, ad_util.Zero.from_primal_value(arg_out)) approx_top_k_p = core.Primitive('approx_top_k') diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index b96f9e8c6e40..4cb38d28c36f 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -434,7 +434,7 @@ def _cond_jvp(primals, tangents, branches): out = cond_p.bind(index, *ops, *ops_dot, branches=branches_jvp) out_primals, out_tangents = split_list(out, [len(out_nz)]) out_tangents_iter = iter(out_tangents) - out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) + out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, out_nz)] return out_primals, out_tangents diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 61b9a24644ce..21b522b3d8bb 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -340,7 +340,7 @@ def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear, # into outputs as well. We don't care about these in AD so we throw them out. 
out_primals, out_tangents = split_list(out_flat, [len(primals)]) out_tangents_iter = iter(out_tangents) - out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) + out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, nonzero_tangents)] return out_primals, out_tangents ad.primitive_jvps[for_p] = _for_jvp diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 828728ebdbd2..41d809f8d688 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -547,7 +547,7 @@ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys]) primals_out = carry + ys tangents_out_iter = iter(carry_dot + ys_dot) - tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_value(p) + tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(primals_out, nonzeros_out)] return primals_out, tangents_out @@ -1518,7 +1518,7 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts, out_carry, out_carry_dot = split_list(out, [num_carry]) out_tangents_iter = iter(out_carry_dot) - out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) + out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_carry, nonzeros_out)] return out_carry, out_tangents diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 21105e20aaf8..4e0f5086b121 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -316,7 +316,7 @@ def _tangent_linear_map(func, params, params_dot, *x): this function computes ``∂A @ x``. 
""" assert any(type(p) is not ad_util.Zero for p in params_dot) - zeros = _map(ad_util.Zero.from_value, x) + zeros = _map(ad_util.Zero.from_primal_value, x) _, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped( params + list(x), params_dot + zeros) return out_tangent @@ -352,7 +352,7 @@ def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs): # split into x tangents and aux tangents (these become zero) dx_leaves, daux_leaves = split_list(x_dot, [num_x_leaves]) - daux_leaves = _map(ad_util.Zero.from_value, daux_leaves) + daux_leaves = _map(ad_util.Zero.from_primal_value, daux_leaves) x_dot = dx_leaves + daux_leaves diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8d2c24d6e64c..83a2e5ef2ec8 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2300,7 +2300,7 @@ def _add_jvp(primals, tangents): xdot, ydot = tangents primal_out = add(x, y) if type(xdot) is type(ydot) is ad_util.Zero: - return primal_out, ad_util.Zero.from_value(primal_out) + return primal_out, ad_util.Zero.from_primal_value(primal_out) if type(xdot) is ad_util.Zero: return primal_out, _maybe_broadcast(primal_out.shape, ydot) elif type(ydot) is ad_util.Zero: @@ -2331,7 +2331,7 @@ def _sub_jvp(primals, tangents): xdot, ydot = tangents primal_out = sub(x, y) if type(xdot) is type(ydot) is ad_util.Zero: - return primal_out, ad_util.Zero.from_value(primal_out) + return primal_out, ad_util.Zero.from_primal_value(primal_out) if type(xdot) is ad_util.Zero: return primal_out, _maybe_broadcast(primal_out.shape, neg(ydot)) elif type(ydot) is ad_util.Zero: @@ -3355,7 +3355,7 @@ def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions y = broadcast_in_dim_p.bind(operand, *dyn_shape, shape=shape, broadcast_dimensions=broadcast_dimensions) if type(operand_dot) is ad_util.Zero: - y_dot = ad_util.Zero.from_value(y) + y_dot = ad_util.Zero.from_primal_value(y) else: y_dot = broadcast_in_dim_p.bind(operand_dot, *dyn_shape, shape=shape, 
broadcast_dimensions=broadcast_dimensions) @@ -4525,7 +4525,7 @@ def _top_k_jvp(primals, tangents, *, k): tangent, = tangents primals_out = top_k(operand, k) if type(tangent) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(primals_out[0]) + tangent_out = ad_util.Zero.from_primal_value(primals_out[0]) else: _, k_idxs = primals_out idx_shape = k_idxs.shape @@ -4544,7 +4544,7 @@ def _top_k_jvp(primals, tangents, *, k): collapsed_slice_dims=tuple(range(rank)), start_index_map=tuple(range(rank))) tangent_out = slicing.gather(tangent, gather_indices, dnums, slice_sizes) - return primals_out, (tangent_out, ad_util.Zero.from_value(primals_out[1])) + return primals_out, (tangent_out, ad_util.Zero.from_primal_value(primals_out[1])) def _top_k_batch_rule(batched_args, batch_dims, *, k): operand, = batched_args @@ -4580,7 +4580,7 @@ def _top_k_lower(ctx, operand, k): def _stop_gradient_jvp_rule(primals, tangents): # if we don't call stop_gradient here, we'd only peel off one autodiff tracer x, = primals - return stop_gradient(x), ad_util.Zero.from_value(x) + return stop_gradient(x), ad_util.Zero.from_primal_value(x) def _stop_gradient_batch_rule(batched_args, batch_dims): x, = batched_args diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 8752e0b6d1de..ec0a075dae1b 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -1487,8 +1487,8 @@ def _lu_jvp_rule(primals, tangents): l_dot = l @ _tril(lau, -1) u_dot = _triu(lau) @ u lu_dot = l_dot + u_dot - return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_value(pivots), - ad_util.Zero.from_value(permutation)) + return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_primal_value(pivots), + ad_util.Zero.from_primal_value(permutation)) def _lu_batching_rule(batched_args, batch_dims): diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 39d4b31588c1..5ed1945ecb96 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1362,7 +1362,7 @@ def 
_dynamic_update_slice_jvp(primals, tangents): g_operand, g_update = tangents[:2] val_out = dynamic_update_slice_p.bind(operand, update, *start_indices) if type(g_operand) is ad_util.Zero and type(g_update) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: g_operand = ad.instantiate_zeros(g_operand) g_update = ad.instantiate_zeros(g_update) @@ -2000,7 +2000,7 @@ def _scatter_add_jvp(primals, tangents, *, update_jaxpr, update_consts, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: g_operand = ad.instantiate_zeros(g_operand) g_updates = ad.instantiate_zeros(g_updates) @@ -2180,7 +2180,7 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, unique_indices=unique_indices, mode=mode) if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: g_operand = ad.instantiate_zeros(g_operand) g_updates = ad.instantiate_zeros(g_updates) @@ -2294,7 +2294,7 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, update_consts=update_consts, dimension_numbers=dnums, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) - return val_out, ad_util.Zero.from_value(val_out) + return val_out, ad_util.Zero.from_primal_value(val_out) g_operand = ad.instantiate_zeros(g_operand) g_updates = ad.instantiate_zeros(g_updates) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index dd8e664a095a..089a77de2949 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -707,7 +707,7 @@ def _select_and_scatter_add_jvp( padding) del g_operand if type(g_source) 
is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: tangent_out = _select_and_scatter_add( g_source, operand, select_prim, window_dimensions, @@ -952,7 +952,7 @@ def _select_and_gather_add_jvp( padding, base_dilation, window_dilation) del g_operand if type(g_source) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: tangent_out = _select_and_gather_add( g_source, operand, select_prim, window_dimensions, diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index f8ec3b63339a..7e5768c04092 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -145,9 +145,9 @@ def update(self, inner_aval=None, memory_space=None): memory_space = self.memory_space if memory_space is None else memory_space return AbstractMemoryRef(inner_aval, memory_space) - def at_least_vspace(self): + def to_tangent_aval(self): return AbstractMemoryRef( - self.inner_aval.at_least_vspace(), self.memory_space) + self.inner_aval.to_tangent_aval(), self.memory_space) def __eq__(self, other): return (type(self) is type(other) and self.inner_aval == other.inner_aval diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 4231822965b1..7970440d29a6 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -490,7 +490,7 @@ def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *, len(primals)]) del out_consts out_tangents_iter = iter(out_tangents) - out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) + out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, nonzero_tangents)] return out_primals, out_tangents ad.primitive_jvps[run_state_p] = _run_state_jvp diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 8289f858498b..e64d6258a808 100644 --- a/jax/_src/state/types.py +++ 
b/jax/_src/state/types.py @@ -243,8 +243,8 @@ def _setitem(self, tracer, idx, value) -> None: def __repr__(self) -> str: return f'Ref{{{self.inner_aval.str_short()}}}' - def at_least_vspace(self): - return AbstractRef(self.inner_aval.at_least_vspace()) + def to_tangent_aval(self): + return AbstractRef(self.inner_aval.to_tangent_aval()) def __eq__(self, other): return (type(self) is type(other) and self.inner_aval == other.inner_aval) diff --git a/jax/core.py b/jax/core.py index 1f433d6f5c29..9857fcf88c02 100644 --- a/jax/core.py +++ b/jax/core.py @@ -85,6 +85,7 @@ full_lower as full_lower, gensym as gensym, get_aval as get_aval, + get_type as get_type, get_referent as get_referent, is_constant_dim as is_constant_dim, is_constant_shape as is_constant_shape, diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 8176465c1470..62da0f231d50 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -169,7 +169,7 @@ def linearize(f, *primals, attrs: list[tuple[Any, str]] = []): def _linearize(traceable: lu.WrappedFun, *primals): jvpfun, attrs = _split_attrs(_jvp(traceable)) in_pvals = (tuple(pe.PartialVal.known(p) for p in primals) - + tuple(pe.PartialVal.unknown(core.get_aval(p).at_least_vspace()) + + tuple(pe.PartialVal.unknown(core.get_aval(p).to_tangent_aval()) for p in primals)) _, in_tree = tree_flatten((primals, primals)) jvpfun_flat, out_tree = flatten_fun_nokwargs(jvpfun, in_tree) @@ -211,7 +211,7 @@ def vjp(f, *primals, attrs: list[tuple[Any, str]] = []): f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree) primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize( f_, *attr_primals, *primals_flat) - attr_avals = [core.raise_to_shaped(core.get_aval(jax_getattr(o, a))).at_least_vspace() + attr_avals = [core.raise_to_shaped(core.get_aval(jax_getattr(o, a))).to_tangent_aval() for o, a in attrs_out] f_vjp = _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()), attrs, attrs_out) diff 
--git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 8319e3fba70f..fabd45ca069a 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1405,7 +1405,7 @@ def new_out_names_thunk(): f_jvp, out_tree = ad.traceable(f_jvp, in_tree) result = shard_map_p.bind(f_jvp, *args, **params) primal_out, tangent_out = tree_unflatten(out_tree(), result) - tangent_out = [ad.Zero(core.get_aval(p).at_least_vspace()) if t is None else t + tangent_out = [ad.Zero(core.get_aval(p).to_tangent_aval()) if t is None else t for p, t in zip(primal_out, tangent_out)] return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)] ad.JVPTrace.process_shard_map = _shard_map_jvp diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 9eafa0db0fc2..d200577c2416 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -332,11 +332,11 @@ def _bcoo_fromdense_jvp(primals, tangents, *, nse, n_batch, n_dense, index_dtype data, indices = primals_out if type(Mdot) is ad.Zero: - data_dot = ad.Zero.from_value(data) + data_dot = ad.Zero.from_primal_value(data) else: data_dot = _bcoo_extract(indices, Mdot) - tangents_out = (data_dot, ad.Zero.from_value(indices)) + tangents_out = (data_dot, ad.Zero.from_primal_value(indices)) return primals_out, tangents_out @@ -571,7 +571,7 @@ def _bcoo_transpose_jvp(primals, tangents, *, permutation: Sequence[int], spinfo data_dot, _ = tangents primals_out = _bcoo_transpose(data, indices, permutation=permutation, spinfo=spinfo) data_dot_out, _ = _bcoo_transpose(data_dot, indices, permutation=permutation, spinfo=spinfo) - return primals_out, (data_dot_out, ad.Zero.from_value(indices)) + return primals_out, (data_dot_out, ad.Zero.from_primal_value(indices)) def _bcoo_transpose_transpose(ct, data, indices, *, permutation: Sequence[int], spinfo: SparseInfo): data_ct, indices_ct = ct @@ -1277,7 +1277,7 @@ def _bcoo_spdot_general_jvp(primals, tangents, **kwds): 
data_dot_out += _bcoo_spdot_general(lhs_data_dot, lhs_indices, rhs_data, rhs_indices, **kwds)[0] if type(rhs_data_dot) is not ad.Zero: data_dot_out += _bcoo_spdot_general(lhs_data, lhs_indices, rhs_data_dot, rhs_indices, **kwds)[0] - return primals_out, [data_dot_out, ad.Zero.from_value(primals_out[1])] + return primals_out, [data_dot_out, ad.Zero.from_primal_value(primals_out[1])] # TODO(JVP): transpose rule batching.primitive_batchers[bcoo_spdot_general_p] = _bcoo_spdot_general_batch_rule @@ -1358,8 +1358,8 @@ def _bcoo_sort_indices_jvp(primals, tangents, *, spinfo): permute = nfold_vmap(lambda d, p: d[p], props.n_batch) data_out = permute(data, perm) - indices_dot_out = ad.Zero.from_value(indices) - data_dot_out = ad.Zero.from_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot, perm) + indices_dot_out = ad.Zero.from_primal_value(indices) + data_dot_out = ad.Zero.from_primal_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot, perm) return (data_out, indices_out), (data_dot_out, indices_dot_out) _bcoo_sort_indices_hlo = mlir.lower_fun( @@ -1544,8 +1544,8 @@ def _bcoo_sum_duplicates_jvp(primals, tangents, *, spinfo, nse): permute = lambda x, i, y: x permute = nfold_vmap(permute, props.n_batch) data_out = permute(data_out, mapping, data) - indices_dot_out = ad.Zero.from_value(indices_out) - data_dot_out = ad.Zero.from_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot_out, mapping, data_dot) + indices_dot_out = ad.Zero.from_primal_value(indices_out) + data_dot_out = ad.Zero.from_primal_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot_out, mapping, data_dot) return (data_out, indices_out), (data_dot_out, indices_dot_out) _bcoo_sum_duplicates_hlo = mlir.lower_fun( diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 7f3ebb43c0ec..7275d6bb20aa 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -272,11 +272,11 @@ def 
_bcsr_fromdense_jvp(primals, tangents, *, nse, n_batch, n_dense, index_dtype data, indices, indptr = primals_out if type(Mdot) is ad.Zero: - data_dot = ad.Zero.from_value(data) + data_dot = ad.Zero.from_primal_value(data) else: data_dot = bcsr_extract(indices, indptr, Mdot) - tangents_out = (data_dot, ad.Zero.from_value(indices), ad.Zero.from_value(indptr)) + tangents_out = (data_dot, ad.Zero.from_primal_value(indices), ad.Zero.from_primal_value(indptr)) return primals_out, tangents_out diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py index 8863478df4d3..c65bc87235d6 100644 --- a/jax/experimental/sparse/coo.py +++ b/jax/experimental/sparse/coo.py @@ -348,11 +348,11 @@ def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype): data, row, col = primals_out if type(Mdot) is ad.Zero: - data_dot = ad.Zero.from_value(data) + data_dot = ad.Zero.from_primal_value(data) else: data_dot = _coo_extract(row, col, Mdot) - tangents_out = (data_dot, ad.Zero.from_value(row), ad.Zero.from_value(col)) + tangents_out = (data_dot, ad.Zero.from_primal_value(row), ad.Zero.from_primal_value(col)) return primals_out, tangents_out diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py index c1178943c02a..89d08f109d68 100644 --- a/jax/experimental/sparse/csr.py +++ b/jax/experimental/sparse/csr.py @@ -380,11 +380,11 @@ def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype): data, indices, indptr = primals_out if type(Mdot) is ad.Zero: - data_dot = ad.Zero.from_value(data) + data_dot = ad.Zero.from_primal_value(data) else: data_dot = _csr_extract(indices, indptr, Mdot) - tangents_out = (data_dot, ad.Zero.from_value(indices), ad.Zero.from_value(indptr)) + tangents_out = (data_dot, ad.Zero.from_primal_value(indices), ad.Zero.from_primal_value(indptr)) return primals_out, tangents_out diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 6663df3ac473..6bfc3473ff50 100644 --- a/jax/interpreters/ad.py +++ 
b/jax/interpreters/ad.py @@ -59,9 +59,7 @@ primitive_jvps as primitive_jvps, primitive_transposes as primitive_transposes, rearrange_binders as rearrange_binders, - recast_to_float0 as recast_to_float0, reducing_transposes as reducing_transposes, - replace_float0s as replace_float0s, standard_jvp as standard_jvp, standard_jvp2 as standard_jvp2, traceable as traceable, diff --git a/tests/api_test.py b/tests/api_test.py index 8b75cb624f1b..b0915a1df44b 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7203,10 +7203,11 @@ def foo_jvp(primals, tangents): TypeError, re.escape( "Custom JVP rule must produce primal and tangent outputs " - "with equal shapes and dtypes, but got float32[] and float32[1] " - "respectively."), + "with corresponding shapes and dtypes. " + "Expected float32[] (tangent type of float32[]) but got float32[1]."), lambda: api.jvp(f, (jnp.float32(2.),), (jnp.float32(1.),))) + def test_jvp_rule_doesnt_return_pair_error_message(self): # https://github.com/google/jax/issues/2516 @@ -7536,12 +7537,12 @@ def g_jvp(primals, tangents): self.assertAllClose(tangents, 2 * jnp.arange(3., dtype='float32')) def test_float0(self): + scalar_float0 = jnp.zeros((), dtype=float0) @jax.custom_jvp def f(x, y): return x, y def f_jvp(primals, _): - # we need a defined (non-float0) tangent to trigger the rule - return primals, (2., 1) + return primals, (2., scalar_float0) f.defjvp(f_jvp) primals = (2., 3) @@ -7551,12 +7552,13 @@ def f_jvp(primals, _): (primals, expected_tangents)) def test_float0_initial_style(self): + scalar_float0 = jnp.zeros((), dtype=float0) @jax.custom_jvp def f(x, y): return x, y def f_jvp(primals, _): x, y = primals - return (x, y), (2., 1) + return (x, y), (2., scalar_float0) f.defjvp(f_jvp) def foo(x, y): @@ -7564,8 +7566,9 @@ def foo(x, y): return out primals = (2., 3) - tangents = (np.ones(()), np.zeros((), float0),) - expected_tangents = (2., np.zeros((), float0)) + tangents = (np.ones(()), scalar_float0) + expected_tangents = (2., 
scalar_float0) + self.assertAllClose(api.jvp(foo, primals, tangents), (primals, expected_tangents)) @@ -8730,7 +8733,7 @@ def f(x): def f_fwd(x): return x, (2., x) def f_rev(*_): - return ((2., 1),) + return ((2., jnp.zeros(shape=(), dtype=float0)),) f.defvjp(f_fwd, f_rev) def foo(x, y): @@ -9670,12 +9673,12 @@ def __call__(self, *args): # an option of inferring output types. def custom_transpose(example_out): if isinstance(example_out, Callable): - out_type = core.get_aval(0.).at_least_vspace() + out_type = core.get_aval(0.).to_tangent_aval() return _custom_transpose(out_type, example_out) return partial( _custom_transpose, jax.tree.map( - lambda x: core.get_aval(x).at_least_vspace(), example_out)) + lambda x: core.get_aval(x).to_tangent_aval(), example_out)) class CustomTransposeTest(jtu.JaxTestCase): diff --git a/tests/export_test.py b/tests/export_test.py index b269aef28d79..d5884b7e6b16 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -473,7 +473,8 @@ def f(xi, xf): # Native JAX 1st order vjp (f_outi, f_outf), f_vjp = jax.vjp(f, xi, xf) - f_outi_ct = np.ones(f_outi.shape, dtype=f_outi.dtype) + f_outi_ct = np.ones(f_outi.shape, + dtype=core.primal_dtype_to_tangent_dtype(f_outi.dtype)) f_outf_ct = np.ones(f_outf.shape, dtype=f_outf.dtype) xi_ct, xf_ct = f_vjp((f_outi_ct, f_outf_ct))