Skip to content

Commit

Permalink
Speed up check_jaxpr().
Browse files Browse the repository at this point in the history
(check_jaxpr() is only used when debugging.)

Don't eagerly pretty print jaxprs: only do so if we are going to raise an error.
Don't eagerly form error messages. Delete typecheck_assert.

PiperOrigin-RevId: 422594126
  • Loading branch information
hawkinsp authored and jax authors committed Jan 18, 2022
1 parent e30b96c commit 4c423c3
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 84 deletions.
73 changes: 36 additions & 37 deletions jax/_src/lax/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,13 @@ def _abstractify(x):
return raise_to_shaped(core.get_aval(x))

def _typecheck_param(prim, param, name, msg_required, pred):
msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
f'{msg_required} required:')
param_str = str(param)
sep = os.linesep if os.linesep in param_str else ' '
msg = sep.join([msg, param_str])
core.typecheck_assert(pred, msg)
if not pred:
msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
f'{msg_required} required:')
param_str = str(param)
sep = os.linesep if os.linesep in param_str else ' '
msg = sep.join([msg, param_str])
raise core.JaxprTypeError(msg)


### fori_loop and while_loop
Expand Down Expand Up @@ -1340,47 +1341,45 @@ def _cond_typecheck(*avals, branches, linear):
tc(linear, 'linear', 'tuple of bool',
type(linear) is tuple and all(type(x) is bool for x in linear))

core.typecheck_assert(
len(branches) > 0,
'cond requires at least one branch function')
core.typecheck_assert(
len(linear) + 1 == len(avals),
f'cond given {len(linear)} linear flags for '
f'{len(avals) - 1} non-predicate operands')
if len(branches) == 0:
raise core.JaxprTypeError('cond requires at least one branch function')
if len(linear) + 1 != len(avals):
raise core.JaxprTypeError(f'cond given {len(linear)} linear flags for '
f'{len(avals) - 1} non-predicate operands')

jaxpr0 = branches[0]
jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals)
jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals)

for i, jaxpr in enumerate(branches[1:]):
core.typecheck_assert(
len(jaxpr0.in_avals) == len(jaxpr.in_avals),
if len(jaxpr0.in_avals) != len(jaxpr.in_avals):
raise core.JaxprTypeError(
f'cond branch 0 takes {len(jaxpr0.in_avals)} inputs, '
f'branch {i+1} takes {len(jaxpr.in_avals)}')
core.typecheck_assert(
len(jaxpr0.out_avals) == len(jaxpr.out_avals),
if len(jaxpr0.out_avals) != len(jaxpr.out_avals):
raise core.JaxprTypeError(
f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, '
f'branch {i+1} outputs {len(jaxpr.out_avals)}')
core.typecheck_assert(
all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)),
if not all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)):
raise core.JaxprTypeError(
f'cond branches 0 and {i+1} have mismatching input types: '
f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}')
core.typecheck_assert(
all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)),
if not all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)):
raise core.JaxprTypeError(
f'cond branches 0 and {i+1} have mismatching output types: '
f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}')

core.typecheck_assert(
len(avals) == 1 + len(jaxpr0.in_avals),
if len(avals) != 1 + len(jaxpr0.in_avals):
raise core.JaxprTypeError(
f'cond called with {len(avals) - 1} non-predicate operands, '
f'but branches take {len(jaxpr0.in_avals)} inputs')

index_aval, *op_avals = avals
core.typecheck_assert(
index_aval.dtype == np.int32,
if index_aval.dtype != np.int32:
raise core.JaxprTypeError(
f'cond called with index of type {index_aval.dtype} instead of int32')
core.typecheck_assert(
all(_map(core.typecompat, jaxpr0.in_avals, op_avals)),
if not all(_map(core.typecompat, jaxpr0.in_avals, op_avals)):
raise core.JaxprTypeError(
f'cond branches take input types {jaxpr0_in_avals_str}, '
f'called with operands of type {_avals_short(op_avals)}')

Expand Down Expand Up @@ -2177,8 +2176,8 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
tc(length, 'length', 'non-negative int',
type(length) in length_types and length >= 0)

core.typecheck_assert(
len(linear) == len(avals),
if len(linear) != len(avals):
raise core.JaxprTypeError(
f'scan param linear has length {len(linear)} for {len(avals)} operands')

const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry])
Expand All @@ -2187,20 +2186,20 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
carry_avals_jaxpr, _ = split_list(jaxpr.out_avals, [num_carry])
x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals)

core.typecheck_assert(
all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)),
if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)):
raise core.JaxprTypeError(
f'scan input carry input and output types mismatch: '
f'\n{_avals_short(init_avals_jaxpr)}\nvs\n{_avals_short(carry_avals_jaxpr)}')
core.typecheck_assert(
all(_map(core.typecompat, const_avals_jaxpr, const_avals)),
if not all(_map(core.typecompat, const_avals_jaxpr, const_avals)):
raise core.JaxprTypeError(
f'scan jaxpr takes input const types\n{_avals_short(const_avals_jaxpr)},\n'
f'called with consts of type\n{_avals_short(const_avals)}')
core.typecheck_assert(
all(_map(core.typecompat, init_avals_jaxpr, init_avals)),
if not all(_map(core.typecompat, init_avals_jaxpr, init_avals)):
raise core.JaxprTypeError(
f'scan jaxpr takes input carry types\n{_avals_short(init_avals_jaxpr)},\n'
f'called with initial carry of type\n{_avals_short(init_avals)}')
core.typecheck_assert(
all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)),
if not all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)):
raise core.JaxprTypeError(
f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n'
f'called with sequence of type\n{_avals_short(x_avals)}')

Expand Down
97 changes: 53 additions & 44 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import collections
from collections import namedtuple
from contextlib import contextmanager
import functools
from functools import partial, partialmethod, total_ordering
import gc
import itertools as it
Expand Down Expand Up @@ -2050,10 +2051,6 @@ def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:

class JaxprTypeError(TypeError): pass

def typecheck_assert(pred, msg):
if not pred:
raise JaxprTypeError(msg)

custom_typechecks: Dict[Primitive, Callable] = {}

def check_jaxpr(jaxpr: Jaxpr):
Expand All @@ -2067,13 +2064,17 @@ def check_jaxpr(jaxpr: Jaxpr):
Raises `JaxprTypeError` if `jaxpr` is determined invalid. Returns `None`
otherwise.
"""
ctx = JaxprPpContext()
try: pp_jaxpr(jaxpr, ctx) # side-effect on ctx, build variable names
except: pass
@functools.lru_cache(maxsize=None)
def ctx_factory():
ctx = JaxprPpContext()
try: pp_jaxpr(jaxpr, ctx) # side-effect on ctx, build variable names
except: pass
return ctx

try:
_check_jaxpr(ctx, jaxpr, [v.aval for v in jaxpr.invars])
_check_jaxpr(ctx_factory, jaxpr, [v.aval for v in jaxpr.invars])
except JaxprTypeError as e:
ctx = ctx_factory()
if len(e.args) == 2:
msg, eqnidx = e.args
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqnidx - 10, eqnidx + 10, ctx))
Expand All @@ -2083,22 +2084,28 @@ def check_jaxpr(jaxpr: Jaxpr):
msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
raise JaxprTypeError(msg) from None

def _check_jaxpr(ctx: 'JaxprPpContext', jaxpr: Jaxpr,
def _check_jaxpr(ctx_factory: Callable[[], 'JaxprPpContext'], jaxpr: Jaxpr,
in_avals: Sequence[AbstractValue]) -> None:

def read(v: Atom) -> AbstractValue:
if isinstance(v, Literal):
return raise_to_shaped(get_aval(v.val))
else:
typecheck_assert(v in env, f"Variable '{pp_var(v, ctx)}' not defined")
if v not in env:
ctx = ctx_factory()
raise JaxprTypeError(f"Variable '{pp_var(v, ctx)}' not defined")
return env[v]

def write(v: Var, a: AbstractValue) -> None:
typecheck_assert(v not in env, f"Variable '{pp_var(v, ctx)}' already bound")
if v in env:
ctx = ctx_factory()
raise JaxprTypeError(f"Variable '{pp_var(v, ctx)}' already bound")
if not isinstance(v, DropVar):
typecheck_assert(typecompat(v.aval, a),
f"Variable '{pp_var(v, ctx)}' inconsistently typed as "
f"{pp_aval(a, ctx)}, bound as {pp_aval(v.aval, ctx)}")
if not typecompat(v.aval, a):
ctx = ctx_factory()
raise JaxprTypeError(
f"Variable '{pp_var(v, ctx)}' inconsistently typed as "
f"{pp_aval(a, ctx)}, bound as {pp_aval(v.aval, ctx)}")
env[v] = a

env : Dict[Var, AbstractValue] = {}
Expand All @@ -2111,20 +2118,21 @@ def write(v: Var, a: AbstractValue) -> None:
prim = eqn.primitive
try:
in_avals = map(read, eqn.invars)
typecheck_assert(all(not isinstance(ina, ConcreteArray) for ina in in_avals),
"Equation given ConcreteArray type inputs")
if any(isinstance(ina, ConcreteArray) for ina in in_avals):
raise JaxprTypeError("Equation given ConcreteArray type inputs")
if prim in custom_typechecks:
out_avals = custom_typechecks[prim](*in_avals, **eqn.params)
if out_avals is None:
out_avals = [v.aval for v in eqn.outvars]
elif prim.call_primitive:
out_avals = check_call(ctx, prim, in_avals, eqn.params)
out_avals = check_call(ctx_factory, prim, in_avals, eqn.params)
elif prim.map_primitive:
out_avals = check_map(ctx, prim, in_avals, eqn.params)
out_avals = check_map(ctx_factory, prim, in_avals, eqn.params)
else:
out_avals = check_eqn(prim, in_avals, eqn.params)
map(write, eqn.outvars, out_avals)
except JaxprTypeError as e:
ctx = ctx_factory()
msg, = e.args
src = source_info_util.summarize(eqn.source_info)
msg = "\n\n".join([msg, "in equation:", str(pp.nest(2, pp_eqn(eqn, ctx))),
Expand All @@ -2142,57 +2150,58 @@ def check_eqn(prim, in_avals, params):
out_avals = [out_avals]
return out_avals

def check_call(ctx, prim, in_avals, params):
typecheck_assert("call_jaxpr" in params,
f"Call primitive {prim} missing 'call_jaxpr' parameter")
def check_call(ctx_factory, prim, in_avals, params):
if "call_jaxpr" not in params:
raise JaxprTypeError(
f"Call primitive {prim} missing 'call_jaxpr' parameter")
call_jaxpr = params["call_jaxpr"]

# These checks also happen in recursive call, but give better errors here.
typecheck_assert(len(in_avals) == len(call_jaxpr.invars),
f"Call primitive {prim} with {len(call_jaxpr.invars)} "
f"operands cannot call jaxpr with {len(call_jaxpr.invars)} "
f"inputs")
if len(in_avals) != len(call_jaxpr.invars):
raise JaxprTypeError(f"Call primitive {prim} with {len(call_jaxpr.invars)} "
f"operands cannot call jaxpr with {len(call_jaxpr.invars)} "
f"inputs")
binder_avals = [v.aval for v in call_jaxpr.invars]
for binder_aval, in_aval in zip(binder_avals, in_avals):
typecheck_assert(typecompat(binder_aval, in_aval),
f"Call primitive {prim} passes operand {in_aval} "
f"to jaxpr expecting {binder_aval}")
if not typecompat(binder_aval, in_aval):
raise JaxprTypeError(f"Call primitive {prim} passes operand {in_aval} "
f"to jaxpr expecting {binder_aval}")

_check_jaxpr(ctx, call_jaxpr, in_avals)
_check_jaxpr(ctx_factory, call_jaxpr, in_avals)

out_avals = [v.aval for v in call_jaxpr.outvars]
return out_avals

def check_map(ctx, prim, in_avals, params):
typecheck_assert("call_jaxpr" in params,
f"Map primitive {prim} missing 'call_jaxpr' parameter")
def check_map(ctx_factory, prim, in_avals, params):
if "call_jaxpr" not in params:
raise JaxprTypeError(f"Map primitive {prim} missing 'call_jaxpr' parameter")
call_jaxpr = params["call_jaxpr"]
typecheck_assert("axis_size" in params,
f"Map primitive {prim} missing 'axis_size' parameter")
if "axis_size" not in params:
raise JaxprTypeError(f"Map primitive {prim} missing 'axis_size' parameter")
axis_size = params["axis_size"]
typecheck_assert("axis_name" in params,
f"Map primitive {prim} missing 'axis_name' parameter")
if "axis_name" not in params:
raise JaxprTypeError(f"Map primitive {prim} missing 'axis_name' parameter")
axis_name = params["axis_name"]
typecheck_assert("in_axes" in params,
f"Map primitive {prim} missing 'in_axes' parameter")
if "in_axes" not in params:
raise JaxprTypeError(f"Map primitive {prim} missing 'in_axes' parameter")
in_axes = params["in_axes"]
typecheck_assert("out_axes" in params,
f"Map primitive {prim} missing 'out_axes' parameter")
if "out_axes" not in params:
raise JaxprTypeError(f"Map primitive {prim} missing 'out_axes' parameter")
out_axes = params["out_axes"]

binder_avals = [unmapped_aval(axis_size, axis_name, in_axis, v.aval)
if in_axis is not None else v.aval
for v, in_axis in zip(call_jaxpr.invars, in_axes)]
for binder_aval, in_aval in zip(binder_avals, in_avals):
typecheck_assert(typecompat(binder_aval, in_aval),
f"Call primitive {prim} passes operand {in_aval} "
f"to jaxpr expecting {binder_aval}")
if not typecompat(binder_aval, in_aval):
raise JaxprTypeError(f"Call primitive {prim} passes operand {in_aval} "
f"to jaxpr expecting {binder_aval}")

mapped_avals = [mapped_aval(axis_size, in_axis, aval)
if in_axis is not None else aval
for aval, in_axis in zip(in_avals, in_axes)]
with extend_axis_env(params['axis_name'], axis_size, None):
_check_jaxpr(ctx, call_jaxpr, mapped_avals)
_check_jaxpr(ctx_factory, call_jaxpr, mapped_avals)

mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval) if out_axis is not None else aval
Expand Down
7 changes: 4 additions & 3 deletions jax/experimental/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,14 +951,15 @@ def _typecheck_xmap(
binder_in_avals = [_insert_aval_axes(v.aval, a_in_axes, local_axis_sizes)
for v, a_in_axes in zip(call_jaxpr.invars, in_axes)]
for binder_in_aval, in_aval in zip(binder_in_avals, in_avals):
core.typecheck_assert(
core.typecompat(binder_in_aval, in_aval),
if not core.typecompat(binder_in_aval, in_aval):
raise core.JaxprTypeError(
f"xmap passes operand {in_aval} to jaxpr expecting {binder_in_aval}")

mapped_in_avals = [_delete_aval_axes(a, a_in_axes, global_axis_sizes)
for a, a_in_axes in zip(in_avals, in_axes)]
with core.extend_axis_env_nd(global_axis_sizes.items()):
core._check_jaxpr(core.JaxprPpContext(), call_jaxpr, mapped_in_avals)
core._check_jaxpr(lambda: core.JaxprPpContext(), call_jaxpr,
mapped_in_avals)

mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)
Expand Down

0 comments on commit 4c423c3

Please sign in to comment.