diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py index de8a90d1599f..e80d12a76f18 100644 --- a/jax/_src/lax/control_flow.py +++ b/jax/_src/lax/control_flow.py @@ -121,12 +121,13 @@ def _abstractify(x): return raise_to_shaped(core.get_aval(x)) def _typecheck_param(prim, param, name, msg_required, pred): - msg = (f'invalid {prim} param {name} of type {type(param).__name__}, ' - f'{msg_required} required:') - param_str = str(param) - sep = os.linesep if os.linesep in param_str else ' ' - msg = sep.join([msg, param_str]) - core.typecheck_assert(pred, msg) + if not pred: + msg = (f'invalid {prim} param {name} of type {type(param).__name__}, ' + f'{msg_required} required:') + param_str = str(param) + sep = os.linesep if os.linesep in param_str else ' ' + msg = sep.join([msg, param_str]) + raise core.JaxprTypeError(msg) ### fori_loop and while_loop @@ -1340,47 +1341,45 @@ def _cond_typecheck(*avals, branches, linear): tc(linear, 'linear', 'tuple of bool', type(linear) is tuple and all(type(x) is bool for x in linear)) - core.typecheck_assert( - len(branches) > 0, - 'cond requires at least one branch function') - core.typecheck_assert( - len(linear) + 1 == len(avals), - f'cond given {len(linear)} linear flags for ' - f'{len(avals) - 1} non-predicate operands') + if len(branches) == 0: + raise core.JaxprTypeError('cond requires at least one branch function') + if len(linear) + 1 != len(avals): + raise core.JaxprTypeError(f'cond given {len(linear)} linear flags for ' + f'{len(avals) - 1} non-predicate operands') jaxpr0 = branches[0] jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals) jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals) for i, jaxpr in enumerate(branches[1:]): - core.typecheck_assert( - len(jaxpr0.in_avals) == len(jaxpr.in_avals), + if len(jaxpr0.in_avals) != len(jaxpr.in_avals): + raise core.JaxprTypeError( f'cond branch 0 takes {len(jaxpr0.in_avals)} inputs, ' f'branch {i+1} takes {len(jaxpr.in_avals)}') - core.typecheck_assert( - len(jaxpr0.out_avals) == len(jaxpr.out_avals), + if len(jaxpr0.out_avals) != len(jaxpr.out_avals): + raise core.JaxprTypeError( f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, ' f'branch {i+1} outputs {len(jaxpr.out_avals)}') - core.typecheck_assert( - all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)), + if not all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)): + raise core.JaxprTypeError( f'cond branches 0 and {i+1} have mismatching input types: ' f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}') - core.typecheck_assert( - all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)), + if not all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)): + raise core.JaxprTypeError( f'cond branches 0 and {i+1} have mismatching output types: ' f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}') - core.typecheck_assert( - len(avals) == 1 + len(jaxpr0.in_avals), + if len(avals) != 1 + len(jaxpr0.in_avals): + raise core.JaxprTypeError( f'cond called with {len(avals) - 1} non-predicate operands, ' f'but branches take {len(jaxpr0.in_avals)} inputs') index_aval, *op_avals = avals - core.typecheck_assert( - index_aval.dtype == np.int32, + if index_aval.dtype != np.int32: + raise core.JaxprTypeError( f'cond called with index of type {index_aval.dtype} instead of int32') - core.typecheck_assert( - all(_map(core.typecompat, jaxpr0.in_avals, op_avals)), + if not all(_map(core.typecompat, jaxpr0.in_avals, op_avals)): + raise core.JaxprTypeError( f'cond branches take input types {jaxpr0_in_avals_str}, ' f'called with operands of type {_avals_short(op_avals)}') @@ -2177,8 +2176,8 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry, tc(length, 'length', 'non-negative int', type(length) in length_types and length >= 0) - core.typecheck_assert( - len(linear) == len(avals), + if len(linear) != len(avals): + raise core.JaxprTypeError( f'scan param linear has length {len(linear)} for {len(avals)} operands') const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry]) @@ -2187,20 +2186,20 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry, carry_avals_jaxpr, _ = split_list(jaxpr.out_avals, [num_carry]) x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals) - core.typecheck_assert( - all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)), + if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)): + raise core.JaxprTypeError( f'scan input carry input and output types mismatch: ' f'\n{_avals_short(init_avals_jaxpr)}\nvs\n{_avals_short(carry_avals_jaxpr)}') - core.typecheck_assert( - all(_map(core.typecompat, const_avals_jaxpr, const_avals)), + if not all(_map(core.typecompat, const_avals_jaxpr, const_avals)): + raise core.JaxprTypeError( f'scan jaxpr takes input const types\n{_avals_short(const_avals_jaxpr)},\n' f'called with consts of type\n{_avals_short(const_avals)}') - core.typecheck_assert( - all(_map(core.typecompat, init_avals_jaxpr, init_avals)), + if not all(_map(core.typecompat, init_avals_jaxpr, init_avals)): + raise core.JaxprTypeError( f'scan jaxpr takes input carry types\n{_avals_short(init_avals_jaxpr)},\n' f'called with initial carry of type\n{_avals_short(init_avals)}') - core.typecheck_assert( - all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)), + if not all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)): + raise core.JaxprTypeError( f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n' f'called with sequence of type\n{_avals_short(x_avals)}') diff --git a/jax/core.py b/jax/core.py index db7fd96d7393..30c37e48192a 100644 --- a/jax/core.py +++ b/jax/core.py @@ -16,6 +16,7 @@ import collections from collections import namedtuple from contextlib import contextmanager +import functools from functools import partial, partialmethod, total_ordering import gc import itertools as it @@ -2050,10 +2051,6 @@ def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool: class JaxprTypeError(TypeError): pass -def typecheck_assert(pred, msg): - if not pred: - raise JaxprTypeError(msg) - custom_typechecks: Dict[Primitive, Callable] = {} def check_jaxpr(jaxpr: Jaxpr): @@ -2067,13 +2064,17 @@ def check_jaxpr(jaxpr: Jaxpr): Raises `JaxprTypeError` if `jaxpr` is determined invalid. Returns `None` otherwise. """ - ctx = JaxprPpContext() - try: pp_jaxpr(jaxpr, ctx) # side-effect on ctx, build variable names - except: pass + @functools.lru_cache(maxsize=None) + def ctx_factory(): + ctx = JaxprPpContext() + try: pp_jaxpr(jaxpr, ctx) # side-effect on ctx, build variable names + except: pass + return ctx try: - _check_jaxpr(ctx, jaxpr, [v.aval for v in jaxpr.invars]) + _check_jaxpr(ctx_factory, jaxpr, [v.aval for v in jaxpr.invars]) except JaxprTypeError as e: + ctx = ctx_factory() if len(e.args) == 2: msg, eqnidx = e.args jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqnidx - 10, eqnidx + 10, ctx)) @@ -2083,22 +2084,28 @@ def check_jaxpr(jaxpr: Jaxpr): msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str]) raise JaxprTypeError(msg) from None -def _check_jaxpr(ctx: 'JaxprPpContext', jaxpr: Jaxpr, +def _check_jaxpr(ctx_factory: Callable[[], 'JaxprPpContext'], jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]) -> None: def read(v: Atom) -> AbstractValue: if isinstance(v, Literal): return raise_to_shaped(get_aval(v.val)) else: - typecheck_assert(v in env, f"Variable '{pp_var(v, ctx)}' not defined") + if v not in env: + ctx = ctx_factory() + raise JaxprTypeError(f"Variable '{pp_var(v, ctx)}' not defined") return env[v] def write(v: Var, a: AbstractValue) -> None: - typecheck_assert(v not in env, f"Variable '{pp_var(v, ctx)}' already bound") + if v in env: + ctx = ctx_factory() + raise JaxprTypeError(f"Variable '{pp_var(v, ctx)}' already bound") if not isinstance(v, DropVar): - typecheck_assert(typecompat(v.aval, a), - f"Variable '{pp_var(v, ctx)}' inconsistently typed as " - f"{pp_aval(a, ctx)}, bound as {pp_aval(v.aval, ctx)}") + if not typecompat(v.aval, a): + ctx = ctx_factory() + raise JaxprTypeError( + f"Variable '{pp_var(v, ctx)}' inconsistently typed as " + f"{pp_aval(a, ctx)}, bound as {pp_aval(v.aval, ctx)}") env[v] = a env : Dict[Var, AbstractValue] = {} @@ -2111,20 +2118,21 @@ def write(v: Var, a: AbstractValue) -> None: prim = eqn.primitive try: in_avals = map(read, eqn.invars) - typecheck_assert(all(not isinstance(ina, ConcreteArray) for ina in in_avals), - "Equation given ConcreteArray type inputs") + if any(isinstance(ina, ConcreteArray) for ina in in_avals): + raise JaxprTypeError("Equation given ConcreteArray type inputs") if prim in custom_typechecks: out_avals = custom_typechecks[prim](*in_avals, **eqn.params) if out_avals is None: out_avals = [v.aval for v in eqn.outvars] elif prim.call_primitive: - out_avals = check_call(ctx, prim, in_avals, eqn.params) + out_avals = check_call(ctx_factory, prim, in_avals, eqn.params) elif prim.map_primitive: - out_avals = check_map(ctx, prim, in_avals, eqn.params) + out_avals = check_map(ctx_factory, prim, in_avals, eqn.params) else: out_avals = check_eqn(prim, in_avals, eqn.params) map(write, eqn.outvars, out_avals) except JaxprTypeError as e: + ctx = ctx_factory() msg, = e.args src = source_info_util.summarize(eqn.source_info) msg = "\n\n".join([msg, "in equation:", str(pp.nest(2, pp_eqn(eqn, ctx))), @@ -2142,57 +2150,58 @@ def check_eqn(prim, in_avals, params): out_avals = [out_avals] return out_avals -def check_call(ctx, prim, in_avals, params): - typecheck_assert("call_jaxpr" in params, - f"Call primitive {prim} missing 'call_jaxpr' parameter") +def check_call(ctx_factory, prim, in_avals, params): + if "call_jaxpr" not in params: + raise JaxprTypeError( + f"Call primitive {prim} missing 'call_jaxpr' parameter") call_jaxpr = params["call_jaxpr"] # These checks also happen in recursive call, but give better errors here. - typecheck_assert(len(in_avals) == len(call_jaxpr.invars), - f"Call primitive {prim} with {len(call_jaxpr.invars)} " - f"operands cannot call jaxpr with {len(call_jaxpr.invars)} " - f"inputs") + if len(in_avals) != len(call_jaxpr.invars): + raise JaxprTypeError(f"Call primitive {prim} with {len(call_jaxpr.invars)} " + f"operands cannot call jaxpr with {len(call_jaxpr.invars)} " + f"inputs") binder_avals = [v.aval for v in call_jaxpr.invars] for binder_aval, in_aval in zip(binder_avals, in_avals): - typecheck_assert(typecompat(binder_aval, in_aval), - f"Call primitive {prim} passes operand {in_aval} " - f"to jaxpr expecting {binder_aval}") + if not typecompat(binder_aval, in_aval): + raise JaxprTypeError(f"Call primitive {prim} passes operand {in_aval} " + f"to jaxpr expecting {binder_aval}") - _check_jaxpr(ctx, call_jaxpr, in_avals) + _check_jaxpr(ctx_factory, call_jaxpr, in_avals) out_avals = [v.aval for v in call_jaxpr.outvars] return out_avals -def check_map(ctx, prim, in_avals, params): - typecheck_assert("call_jaxpr" in params, - f"Map primitive {prim} missing 'call_jaxpr' parameter") +def check_map(ctx_factory, prim, in_avals, params): + if "call_jaxpr" not in params: + raise JaxprTypeError(f"Map primitive {prim} missing 'call_jaxpr' parameter") call_jaxpr = params["call_jaxpr"] - typecheck_assert("axis_size" in params, - f"Map primitive {prim} missing 'axis_size' parameter") + if "axis_size" not in params: + raise JaxprTypeError(f"Map primitive {prim} missing 'axis_size' parameter") axis_size = params["axis_size"] - typecheck_assert("axis_name" in params, - f"Map primitive {prim} missing 'axis_name' parameter") + if "axis_name" not in params: + raise JaxprTypeError(f"Map primitive {prim} missing 'axis_name' parameter") axis_name = params["axis_name"] - typecheck_assert("in_axes" in params, - f"Map primitive {prim} missing 'in_axes' parameter") + if "in_axes" not in params: + raise JaxprTypeError(f"Map primitive {prim} missing 'in_axes' parameter") in_axes = params["in_axes"] - typecheck_assert("out_axes" in params, - f"Map primitive {prim} missing 'out_axes' parameter") + if "out_axes" not in params: + raise JaxprTypeError(f"Map primitive {prim} missing 'out_axes' parameter") out_axes = params["out_axes"] binder_avals = [unmapped_aval(axis_size, axis_name, in_axis, v.aval) if in_axis is not None else v.aval for v, in_axis in zip(call_jaxpr.invars, in_axes)] for binder_aval, in_aval in zip(binder_avals, in_avals): - typecheck_assert(typecompat(binder_aval, in_aval), - f"Call primitive {prim} passes operand {in_aval} " - f"to jaxpr expecting {binder_aval}") + if not typecompat(binder_aval, in_aval): + raise JaxprTypeError(f"Call primitive {prim} passes operand {in_aval} " + f"to jaxpr expecting {binder_aval}") mapped_avals = [mapped_aval(axis_size, in_axis, aval) if in_axis is not None else aval for aval, in_axis in zip(in_avals, in_axes)] with extend_axis_env(params['axis_name'], axis_size, None): - _check_jaxpr(ctx, call_jaxpr, mapped_avals) + _check_jaxpr(ctx_factory, call_jaxpr, mapped_avals) mapped_out_avals = [v.aval for v in call_jaxpr.outvars] out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval) if out_axis is not None else aval diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index 5272e1bed431..e37142f4b8b6 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -951,14 +951,15 @@ def _typecheck_xmap( binder_in_avals = [_insert_aval_axes(v.aval, a_in_axes, local_axis_sizes) for v, a_in_axes in zip(call_jaxpr.invars, in_axes)] for binder_in_aval, in_aval in zip(binder_in_avals, in_avals): - core.typecheck_assert( - core.typecompat(binder_in_aval, in_aval), + if not core.typecompat(binder_in_aval, in_aval): + raise core.JaxprTypeError( f"xmap passes operand {in_aval} to jaxpr expecting {binder_in_aval}") mapped_in_avals = [_delete_aval_axes(a, a_in_axes, global_axis_sizes) for a, a_in_axes in zip(in_avals, in_axes)] with core.extend_axis_env_nd(global_axis_sizes.items()): - core._check_jaxpr(core.JaxprPpContext(), call_jaxpr, mapped_in_avals) + core._check_jaxpr(lambda: core.JaxprPpContext(), call_jaxpr, + mapped_in_avals) mapped_out_avals = [v.aval for v in call_jaxpr.outvars] out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)