Skip to content

Commit

Permalink
Merge pull request #9201 from LenaMartens:changelist/420794552
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 422589737
  • Loading branch information
jax authors committed Jan 18, 2022
2 parents 6411f8a + 8ea8576 commit e30b96c
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 67 deletions.
98 changes: 64 additions & 34 deletions jax/experimental/checkify.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
from dataclasses import dataclass
from functools import partial
import itertools as it
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, Set, FrozenSet

import numpy as np

Expand Down Expand Up @@ -97,14 +98,21 @@ def __init__(self, trace, val):
class CheckifyTrace(core.Trace):
pure = lift = lambda self, val: CheckifyTracer(self, val)

def __init__(self, main: core.MainTrace, sublevel: core.Sublevel,
enabled_errors: FrozenSet['ErrorCategory']) -> None:
self.main = main
self.level = main.level
self.sublevel = sublevel
self.main.enabled_errors = enabled_errors

def sublift(self, tracer):
return CheckifyTracer(self, tracer.val)

def process_primitive(self, primitive, tracers, params):
in_vals = [t.val for t in tracers]
rule = error_checks.get(primitive)
if rule:
out, self.main.error = rule(self.main.error, *in_vals, **params) # type: ignore
out, self.main.error = rule(self.main.error, self.main.enabled_errors, *in_vals, **params) # type: ignore
else:
out = primitive.bind(*in_vals, **params)
if primitive.multiple_results:
Expand Down Expand Up @@ -166,18 +174,18 @@ def _reduce_any_error(errs, codes):
errs_, codes_ = lax.sort_key_val(errs, codes, dimension=0)
return errs_[-1], codes_[-1]

ErrorCheckRule = Callable
ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}

def checkify_flat(fun: lu.WrappedFun, *args):
def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'], *args):
fun, msgs = checkify_subtrace(fun)
fun = checkify_traceable(fun, tuple(init_error.msgs.items()))
fun = checkify_traceable(fun, tuple(init_error.msgs.items()), enabled_errors)
err, code, *outvals = fun.call_wrapped(init_error.err, init_error.code, *args)
return (err, code, outvals), msgs()

@lu.transformation
def checkify_traceable(msgs, err, code, *args):
with core.new_main(CheckifyTrace) as main:
def checkify_traceable(msgs, enabled_errors, err, code, *args):
with core.new_main(CheckifyTrace, enabled_errors=enabled_errors) as main:
outs = yield (main, msgs, err, code, *args), {}
del main
yield outs
Expand All @@ -196,13 +204,13 @@ def checkify_subtrace(main, msgs, err, code, *args):


# TODO take (error_aval, code_aval) instead of error here?
def checkify_jaxpr(jaxpr, error):
def checkify_jaxpr(jaxpr, error, enabled_errors):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
return checkify_fun_to_jaxpr(f, error, jaxpr.in_avals)
return checkify_fun_to_jaxpr(f, error, enabled_errors, jaxpr.in_avals)

def checkify_fun_to_jaxpr(f, error, in_avals):
def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals):
f, msgs = checkify_subtrace(f)
f = checkify_traceable(f, tuple(error.msgs.items()))
f = checkify_traceable(f, tuple(error.msgs.items()), enabled_errors)
err_aval = core.raise_to_shaped(core.get_aval(error.err))
code_aval = core.raise_to_shaped(core.get_aval(error.code))
avals_in = [err_aval, code_aval, *in_avals]
Expand Down Expand Up @@ -244,20 +252,25 @@ def assert_abstract_eval(pred, code, *, msgs):
def summary() -> str:
return str(source_info_util.summarize(source_info_util.current()))

def nan_error_check(prim, error, *in_vals, **params):
def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
out = prim.bind(*in_vals, **params)
if ErrorCategory.NAN not in enabled_errors:
return out, error
no_nans = jnp.logical_not(jnp.any(jnp.isnan(out)))
msg = f"nan generated by primitive {prim.name} at {summary()}"
return out, assert_func(error, no_nans, msg)

def gather_error_check(error, operand, start_indices, *,
def gather_error_check(error, enabled_errors, operand, start_indices, *,
dimension_numbers, slice_sizes, unique_indices,
indices_are_sorted, mode, fill_value):
out = lax.gather_p.bind(
operand, start_indices, dimension_numbers=dimension_numbers,
slice_sizes=slice_sizes, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)

if ErrorCategory.OOB not in enabled_errors:
return out, error

# compare to OOB masking logic in lax._gather_translation_rule
dnums = dimension_numbers
operand_dims = np.array(operand.shape)
Expand All @@ -270,12 +283,13 @@ def gather_error_check(error, operand, start_indices, *,
return out, assert_func(error, all_inbounds, msg)
error_checks[lax.gather_p] = gather_error_check

def div_error_check(error, x, y):
def div_error_check(error, enabled_errors, x, y):
"""Checks for division by zero and NaN."""
all_nonzero = jnp.logical_not(jnp.any(jnp.equal(y, 0)))
msg = f'divided by zero at {summary()}'
div_by_zero_err = assert_func(error, all_nonzero, msg)
return nan_error_check(lax.div_p, div_by_zero_err, x, y)
if ErrorCategory.DIV in enabled_errors:
all_nonzero = jnp.logical_not(jnp.any(jnp.equal(y, 0)))
msg = f'divided by zero at {summary()}'
error = assert_func(error, all_nonzero, msg)
return nan_error_check(lax.div_p, error, enabled_errors, x, y)
error_checks[lax.div_p] = div_error_check

def scatter_in_bounds(operand, indices, updates, dnums):
Expand All @@ -300,17 +314,19 @@ def scatter_in_bounds(operand, indices, updates, dnums):
upper_in_bounds = jnp.all(jnp.less_equal(indices, upper_bound))
return jnp.logical_and(lower_in_bounds, upper_in_bounds)

def scatter_error_check(prim, error, operand, indices, updates, *,
update_jaxpr, update_consts,
dimension_numbers, indices_are_sorted,
unique_indices, mode):
def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
*, update_jaxpr, update_consts, dimension_numbers,
indices_are_sorted, unique_indices, mode):
"""Checks if indices are within bounds and update does not generate NaN."""
out = prim.bind(
operand, indices, updates, update_jaxpr=update_jaxpr,
update_consts=update_consts, dimension_numbers=dimension_numbers,
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
mode=mode)

if ErrorCategory.OOB not in enabled_errors:
return out, error

in_bounds = scatter_in_bounds(operand, indices, updates, dimension_numbers)
oob_msg = f'out-of-bounds indexing while updating at {summary()}'
oob_error = assert_func(error, in_bounds, oob_msg)
Expand All @@ -324,8 +340,8 @@ def scatter_error_check(prim, error, operand, indices, updates, *,
error_checks[lax.scatter_min_p] = partial(scatter_error_check, lax.scatter_min_p)
error_checks[lax.scatter_max_p] = partial(scatter_error_check, lax.scatter_max_p)

def cond_error_check(error, index, *ops, branches, linear):
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error) for jxpr in branches)
def cond_error_check(error, enabled_errors, index, *ops, branches, linear):
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error, enabled_errors) for jxpr in branches)
new_linear = (False, False, *linear)
err, code, *outs = lax.cond_p.bind(
index, error.err, error.code, *ops,
Expand All @@ -334,9 +350,9 @@ def cond_error_check(error, index, *ops, branches, linear):
return outs, Error(err, code, new_msgs)
error_checks[lax.cond_p] = cond_error_check

def scan_error_check(error, *in_flat, reverse, length, jaxpr, num_consts, num_carry, linear, unroll):
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr, num_consts, num_carry, linear, unroll):
consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
checked_jaxpr, msgs_ = checkify_jaxpr(jaxpr, error)
checked_jaxpr, msgs_ = checkify_jaxpr(jaxpr, error, enabled_errors)
new_linear = (False, False, *linear)
new_in_flat = [*consts, error.err, error.code, *carry, *xs]
err, code, *outs = lax.scan_p.bind(
Expand All @@ -348,14 +364,14 @@ def scan_error_check(error, *in_flat, reverse, length, jaxpr, num_consts, num_ca
return outs, Error(err, code, new_msgs)
error_checks[lax.scan_p] = scan_error_check

def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error):
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors):
cond_f = core.jaxpr_as_fun(cond_jaxpr)
body_f = core.jaxpr_as_fun(body_jaxpr)
def new_body_f(*vals):
out = body_f(*vals)
_ = cond_f(*out) # this checks if the next cond application will error
return out
return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, body_jaxpr.in_avals)
return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors, body_jaxpr.in_avals)

def ignore_errors_jaxpr(jaxpr, error):
"""Constructs a jaxpr which takes two extra args but ignores them."""
Expand All @@ -369,13 +385,13 @@ def ignore_errors_jaxpr(jaxpr, error):
jaxpr.outvars, jaxpr.eqns)
return core.ClosedJaxpr(new_jaxpr, consts)

def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error)
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error, enabled_errors)
checked_cond_fun = core.jaxpr_as_fun(checked_cond_jaxpr)
# Check if the first cond application will error.
cond_err, cond_code, _ = checked_cond_fun(error.err, error.code, *in_flat)

checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error)
checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors)
compat_cond_jaxpr = ignore_errors_jaxpr(cond_jaxpr, error)
c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
new_in_flat = [*c_consts, *b_consts, cond_err, cond_code, *carry]
Expand Down Expand Up @@ -453,7 +469,10 @@ def add_nan_check(prim):
add_nan_check(lax.max_p)
add_nan_check(lax.min_p)

def assert_discharge_rule(error, pred, code, *, msgs):
def assert_discharge_rule(error, enabled_errors, pred, code, *, msgs):
if ErrorCategory.ASSERT not in enabled_errors:
return [], error

out_err = error.err | jnp.logical_not(pred)
out_code = lax.select(error.err, error.code, code)
return [], Error(out_err, out_code, {**error.msgs, **msgs})
Expand All @@ -462,13 +481,24 @@ def assert_discharge_rule(error, pred, code, *, msgs):

## checkify api

ErrorCategory = enum.Enum('ErrorCategory', ['NAN', 'OOB', 'DIV', 'ASSERT'])

float_errors = {ErrorCategory.NAN, ErrorCategory.DIV}
index_errors = {ErrorCategory.OOB}
automatic_errors = float_errors | index_errors
user_asserts = {ErrorCategory.ASSERT}

Out = TypeVar('Out')
def checkify(fun: Callable[..., Out]) -> Callable[..., Tuple[Error, Out]]:
def checkify(fun: Callable[..., Out], errors: Set[ErrorCategory] = user_asserts) -> Callable[..., Tuple[Error, Out]]:
if not errors:
raise ValueError('Checkify needs to be called with at least one enabled'
' ErrorCategory, was called with an empty errors set.')

@traceback_util.api_boundary
def checked_fun(*args, **kwargs):
args_flat, in_tree = tree_flatten((args, kwargs))
f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
(err, code, out_flat), msgs = checkify_flat(f, *args_flat)
(err, code, out_flat), msgs = checkify_flat(f, frozenset(errors), *args_flat)
out = tree_unflatten(out_tree(), out_flat)
return Error(err, code, msgs), out
return checked_fun
Loading

0 comments on commit e30b96c

Please sign in to comment.