From 8ea85769eaf35f591a0b3a9710ba4a49c31a7bc9 Mon Sep 17 00:00:00 2001 From: Lena Martens Date: Mon, 10 Jan 2022 18:21:41 +0000 Subject: [PATCH] Checkify: add way to disable categories of errors. By default only user_asserts are lifted into the checked function. --- jax/experimental/checkify.py | 98 +++++++++++++++++++----------- tests/checkify_test.py | 113 +++++++++++++++++++++++++---------- 2 files changed, 144 insertions(+), 67 deletions(-) diff --git a/jax/experimental/checkify.py b/jax/experimental/checkify.py index 39d314f77ff0..d36908234c9b 100644 --- a/jax/experimental/checkify.py +++ b/jax/experimental/checkify.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import enum from dataclasses import dataclass from functools import partial import itertools as it -from typing import Union, Optional, Callable, Dict, Tuple, TypeVar +from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, Set, FrozenSet import numpy as np @@ -97,6 +98,13 @@ def __init__(self, trace, val): class CheckifyTrace(core.Trace): pure = lift = lambda self, val: CheckifyTracer(self, val) + def __init__(self, main: core.MainTrace, sublevel: core.Sublevel, + enabled_errors: FrozenSet['ErrorCategory']) -> None: + self.main = main + self.level = main.level + self.sublevel = sublevel + self.main.enabled_errors = enabled_errors + def sublift(self, tracer): return CheckifyTracer(self, tracer.val) @@ -104,7 +112,7 @@ def process_primitive(self, primitive, tracers, params): in_vals = [t.val for t in tracers] rule = error_checks.get(primitive) if rule: - out, self.main.error = rule(self.main.error, *in_vals, **params) # type: ignore + out, self.main.error = rule(self.main.error, self.main.enabled_errors, *in_vals, **params) # type: ignore else: out = primitive.bind(*in_vals, **params) if primitive.multiple_results: @@ -166,18 +174,18 @@ def _reduce_any_error(errs, codes): errs_, codes_ = lax.sort_key_val(errs, codes, dimension=0) return errs_[-1], codes_[-1] -ErrorCheckRule = Callable +ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error) error_checks: Dict[core.Primitive, ErrorCheckRule] = {} -def checkify_flat(fun: lu.WrappedFun, *args): +def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'], *args): fun, msgs = checkify_subtrace(fun) - fun = checkify_traceable(fun, tuple(init_error.msgs.items())) + fun = checkify_traceable(fun, tuple(init_error.msgs.items()), enabled_errors) err, code, *outvals = fun.call_wrapped(init_error.err, init_error.code, *args) return (err, code, outvals), msgs() @lu.transformation -def checkify_traceable(msgs, err, code, *args): - with core.new_main(CheckifyTrace) as main: +def checkify_traceable(msgs, enabled_errors, err, code, *args): + with core.new_main(CheckifyTrace, enabled_errors=enabled_errors) as main: outs = yield (main, msgs, err, code, *args), {} del main yield outs @@ -196,13 +204,13 @@ def checkify_subtrace(main, msgs, err, code, *args): # TODO take (error_aval, code_aval) instead of error here? -def checkify_jaxpr(jaxpr, error): +def checkify_jaxpr(jaxpr, error, enabled_errors): f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) - return checkify_fun_to_jaxpr(f, error, jaxpr.in_avals) + return checkify_fun_to_jaxpr(f, error, enabled_errors, jaxpr.in_avals) -def checkify_fun_to_jaxpr(f, error, in_avals): +def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals): f, msgs = checkify_subtrace(f) - f = checkify_traceable(f, tuple(error.msgs.items())) + f = checkify_traceable(f, tuple(error.msgs.items()), enabled_errors) err_aval = core.raise_to_shaped(core.get_aval(error.err)) code_aval = core.raise_to_shaped(core.get_aval(error.code)) avals_in = [err_aval, code_aval, *in_avals] @@ -244,13 +252,15 @@ def assert_abstract_eval(pred, code, *, msgs): def summary() -> str: return str(source_info_util.summarize(source_info_util.current())) -def nan_error_check(prim, error, *in_vals, **params): +def nan_error_check(prim, error, enabled_errors, *in_vals, **params): out = prim.bind(*in_vals, **params) + if ErrorCategory.NAN not in enabled_errors: + return out, error no_nans = jnp.logical_not(jnp.any(jnp.isnan(out))) msg = f"nan generated by primitive {prim.name} at {summary()}" return out, assert_func(error, no_nans, msg) -def gather_error_check(error, operand, start_indices, *, +def gather_error_check(error, enabled_errors, operand, start_indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): out = lax.gather_p.bind( @@ -258,6 +268,9 @@ def gather_error_check(error, operand, start_indices, *, slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value) + if ErrorCategory.OOB not in enabled_errors: + return out, error + # compare to OOB masking logic in lax._gather_translation_rule dnums = dimension_numbers operand_dims = np.array(operand.shape) @@ -270,12 +283,13 @@ def gather_error_check(error, operand, start_indices, *, return out, assert_func(error, all_inbounds, msg) error_checks[lax.gather_p] = gather_error_check -def div_error_check(error, x, y): +def div_error_check(error, enabled_errors, x, y): """Checks for division by zero and NaN.""" - all_nonzero = jnp.logical_not(jnp.any(jnp.equal(y, 0))) - msg = f'divided by zero at {summary()}' - div_by_zero_err = assert_func(error, all_nonzero, msg) - return nan_error_check(lax.div_p, div_by_zero_err, x, y) + if ErrorCategory.DIV in enabled_errors: + all_nonzero = jnp.logical_not(jnp.any(jnp.equal(y, 0))) + msg = f'divided by zero at {summary()}' + error = assert_func(error, all_nonzero, msg) + return nan_error_check(lax.div_p, error, enabled_errors, x, y) error_checks[lax.div_p] = div_error_check def scatter_in_bounds(operand, indices, updates, dnums): @@ -300,10 +314,9 @@ def scatter_in_bounds(operand, indices, updates, dnums): upper_in_bounds = jnp.all(jnp.less_equal(indices, upper_bound)) return jnp.logical_and(lower_in_bounds, upper_in_bounds) -def scatter_error_check(prim, error, operand, indices, updates, *, - update_jaxpr, update_consts, - dimension_numbers, indices_are_sorted, - unique_indices, mode): +def scatter_error_check(prim, error, enabled_errors, operand, indices, updates, + *, update_jaxpr, update_consts, dimension_numbers, + indices_are_sorted, unique_indices, mode): """Checks if indices are within bounds and update does not generate NaN.""" out = prim.bind( operand, indices, updates, update_jaxpr=update_jaxpr, @@ -311,6 +324,9 @@ def scatter_error_check(prim, error, operand, indices, updates, *, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) + if ErrorCategory.OOB not in enabled_errors: + return out, error + in_bounds = scatter_in_bounds(operand, indices, updates, dimension_numbers) oob_msg = f'out-of-bounds indexing while updating at {summary()}' oob_error = assert_func(error, in_bounds, oob_msg) @@ -324,8 +340,8 @@ def scatter_error_check(prim, error, operand, indices, updates, *, error_checks[lax.scatter_min_p] = partial(scatter_error_check, lax.scatter_min_p) error_checks[lax.scatter_max_p] = partial(scatter_error_check, lax.scatter_max_p) -def cond_error_check(error, index, *ops, branches, linear): - new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error) for jxpr in branches) +def cond_error_check(error, enabled_errors, index, *ops, branches, linear): + new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error, enabled_errors) for jxpr in branches) new_linear = (False, False, *linear) err, code, *outs = lax.cond_p.bind( index, error.err, error.code, *ops, @@ -334,9 +350,9 @@ def cond_error_check(error, index, *ops, branches, linear): return outs, Error(err, code, new_msgs) error_checks[lax.cond_p] = cond_error_check -def scan_error_check(error, *in_flat, reverse, length, jaxpr, num_consts, num_carry, linear, unroll): +def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr, num_consts, num_carry, linear, unroll): consts, carry, xs = split_list(in_flat, [num_consts, num_carry]) - checked_jaxpr, msgs_ = checkify_jaxpr(jaxpr, error) + checked_jaxpr, msgs_ = checkify_jaxpr(jaxpr, error, enabled_errors) new_linear = (False, False, *linear) new_in_flat = [*consts, error.err, error.code, *carry, *xs] err, code, *outs = lax.scan_p.bind( @@ -348,14 +364,14 @@ def scan_error_check(error, *in_flat, reverse, length, jaxpr, num_consts, num_ca return outs, Error(err, code, new_msgs) error_checks[lax.scan_p] = scan_error_check -def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error): +def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors): cond_f = core.jaxpr_as_fun(cond_jaxpr) body_f = core.jaxpr_as_fun(body_jaxpr) def new_body_f(*vals): out = body_f(*vals) _ = cond_f(*out) # this checks if the next cond application will error return out - return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, body_jaxpr.in_avals) + return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors, body_jaxpr.in_avals) def ignore_errors_jaxpr(jaxpr, error): """Constructs a jaxpr which takes two extra args but ignores them.""" @@ -369,13 +385,13 @@ def ignore_errors_jaxpr(jaxpr, error): jaxpr.outvars, jaxpr.eqns) return core.ClosedJaxpr(new_jaxpr, consts) -def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr): - checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error) +def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr): + checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error, enabled_errors) checked_cond_fun = core.jaxpr_as_fun(checked_cond_jaxpr) # Check if the first cond application will error. cond_err, cond_code, _ = checked_cond_fun(error.err, error.code, *in_flat) - checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error) + checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors) compat_cond_jaxpr = ignore_errors_jaxpr(cond_jaxpr, error) c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts]) new_in_flat = [*c_consts, *b_consts, cond_err, cond_code, *carry] @@ -453,7 +469,10 @@ def add_nan_check(prim): add_nan_check(lax.max_p) add_nan_check(lax.min_p) -def assert_discharge_rule(error, pred, code, *, msgs): +def assert_discharge_rule(error, enabled_errors, pred, code, *, msgs): + if ErrorCategory.ASSERT not in enabled_errors: + return [], error + out_err = error.err | jnp.logical_not(pred) out_code = lax.select(error.err, error.code, code) return [], Error(out_err, out_code, {**error.msgs, **msgs}) @@ -462,13 +481,24 @@ def assert_discharge_rule(error, pred, code, *, msgs): ## checkify api +ErrorCategory = enum.Enum('ErrorCategory', ['NAN', 'OOB', 'DIV', 'ASSERT']) + +float_errors = {ErrorCategory.NAN, ErrorCategory.DIV} +index_errors = {ErrorCategory.OOB} +automatic_errors = float_errors | index_errors +user_asserts = {ErrorCategory.ASSERT} + Out = TypeVar('Out') -def checkify(fun: Callable[..., Out]) -> Callable[..., Tuple[Error, Out]]: +def checkify(fun: Callable[..., Out], errors: Set[ErrorCategory] = user_asserts) -> Callable[..., Tuple[Error, Out]]: + if not errors: + raise ValueError('Checkify needs to be called with at least one enabled' + ' ErrorCategory, was called with an empty errors set.') + @traceback_util.api_boundary def checked_fun(*args, **kwargs): args_flat, in_tree = tree_flatten((args, kwargs)) f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree) - (err, code, out_flat), msgs = checkify_flat(f, *args_flat) + (err, code, out_flat), msgs = checkify_flat(f, frozenset(errors), *args_flat) out = tree_unflatten(out_tree(), out_flat) return Error(err, code, msgs), out return checked_fun diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 5c27317d085c..0a9e1238f338 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -39,11 +39,12 @@ def f(x1, x2): return y1 + y2 f = jax.jit(f) if jit else f + checked_f = checkify.checkify(f, errors=checkify.float_errors) - err, _ = checkify.checkify(f)(3., 4.) + err, _ = checked_f(3., 4.) self.assertIs(err.get(), None) - err, _ = checkify.checkify(f)(3., jnp.inf) + err, _ = checked_f(3., jnp.inf) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), 'nan generated by primitive sin') @@ -58,16 +59,17 @@ def f(x, i): return w f = jax.jit(f) if jit else f + checked_f = checkify.checkify(f, errors=checkify.index_errors) - err, _ = checkify.checkify(f)(jnp.arange(3), 2) + err, _ = checked_f(jnp.arange(3), 2) self.assertIs(err.get(), None) - err, _ = checkify.checkify(f)(jnp.arange(3), 5) + err, _ = checked_f(jnp.arange(3), 5) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), 'out-of-bounds indexing') @parameterized.named_parameters( - {"testcase_name": f"_update={update_fn}", "update_fn": update_fn} + {"testcase_name": f"_updatefn={update_fn}", "update_fn": update_fn} for update_fn in ["set", "add", "multiply", "divide", "power", "min", "max", "get"]) def test_jit_oob_update(self, update_fn): @@ -75,11 +77,12 @@ def f(x, i): return getattr(x.at[i], update_fn)(1.) f = jax.jit(f) + checked_f = checkify.checkify(f, errors=checkify.index_errors) - err, _ = checkify.checkify(f)(jnp.arange(3), 2) + err, _ = checked_f(jnp.arange(3), 2) self.assertIs(err.get(), None) - err, _ = checkify.checkify(f)(jnp.arange(3), 3) + err, _ = checked_f(jnp.arange(3), 3) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), 'out-of-bounds indexing') @@ -91,15 +94,16 @@ def f(x, y): return x/y f = jax.jit(f) if jit else f + checked_f = checkify.checkify(f, errors=checkify.float_errors) - err, _ = checkify.checkify(f)(jnp.ones((3,)), jnp.ones((3,))) + err, _ = checked_f(jnp.ones((3,)), jnp.ones((3,))) self.assertIs(err.get(), None) - err, _ = checkify.checkify(f)(jnp.ones((3,)), jnp.array([1, 0, 1])) + err, _ = checked_f(jnp.ones((3,)), jnp.array([1, 0, 1])) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "divided by zero") - err, _ = checkify.checkify(f)(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1])) + err, _ = checked_f(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1])) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), 'nan generated by primitive div') @@ -114,18 +118,19 @@ def f(x, i): return z f = jax.jit(f) if jit else f + checked_f = checkify.checkify(f, errors=checkify.automatic_errors) # no error - err, _ = checkify.checkify(f)(jnp.array([0., jnp.inf, 2.]), 2) + err, _ = checked_f(jnp.array([0., jnp.inf, 2.]), 2) self.assertIs(err.get(), None) # oob error - err, _ = checkify.checkify(f)(jnp.array([0., 1., 2.]), 5) + err, _ = checked_f(jnp.array([0., 1., 2.]), 5) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), 'out-of-bounds indexing') # nan error - err, _ = checkify.checkify(f)(jnp.array([0., 1., jnp.inf]), 2) + err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 2) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), 'nan generated by primitive cos') @@ -139,9 +144,10 @@ def f(x, i): return y * z f = jax.jit(f) if jit else f + checked_f = checkify.checkify(f, errors=checkify.automatic_errors) # both oob and nan error, but oob happens first - err, _ = checkify.checkify(f)(jnp.array([0., 1., jnp.inf]), 5) + err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 5) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), 'out-of-bounds indexing') @@ -155,13 +161,14 @@ def f(x1, x2): y1 = jnp.sin(x1) y2 = jnp.sin(x2) return y1 + y2 + checked_f = checkify.checkify(f, errors=checkify.float_errors) xs = jnp.array([0., 2.]) - err, _ = checkify.checkify(f)(xs, xs) + err, _ = checked_f(xs, xs) self.assertIs(err.get(), None) ys = jnp.array([3., jnp.inf]) - err, _ = checkify.checkify(f)(xs, ys) + err, _ = checked_f(xs, ys) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), 'nan generated by primitive sin') @@ -173,14 +180,16 @@ def f(x): lambda: jnp.sin(x), lambda: x) - err, y = checkify.checkify(f)(3.) + checked_f = checkify.checkify(f, errors=checkify.float_errors) + + err, y = checked_f(3.) self.assertIs(err.get(), None) - err, y = checkify.checkify(f)(jnp.inf) + err, y = checked_f(jnp.inf) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), 'nan generated by primitive sin') - err, y = checkify.checkify(f)(-jnp.inf) + err, y = checked_f(-jnp.inf) self.assertIs(err.get(), None) @@ -193,14 +202,16 @@ def scan_body(_, x): def f(xs): return lax.scan(scan_body, None, xs) + checked_f = checkify.checkify(f, errors=checkify.float_errors) + xs = jnp.array([0., 2.]) - err, (_, ch_outs) = checkify.checkify(f)(xs) + err, (_, ch_outs) = checked_f(xs) _, outs = f(xs) self.assertIs(err.get(), None) self.assertArraysEqual(ch_outs, outs) xs = jnp.array([3., jnp.inf]) - err, (_, ch_outs) = checkify.checkify(f)(xs) + err, (_, ch_outs) = checked_f(xs) _, outs = f(xs) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "nan generated by primitive sin") @@ -217,8 +228,10 @@ def scan_body(carry, x): def f(carry, xs): return lax.scan(scan_body, carry, xs) + checked_f = checkify.checkify(f, errors=checkify.float_errors) + carry, xs = 3., jnp.ones((2,)) - err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs) + err, (ch_out_carry, ch_outs) = checked_f(carry, xs) out_carry, outs = f(carry, xs) self.assertIs(err.get(), None) self.assertArraysEqual(ch_outs, outs) @@ -226,7 +239,7 @@ def f(carry, xs): # error happens on first iteration carry, xs = 1., jnp.ones((2,)) - err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs) + err, (ch_out_carry, ch_outs) = checked_f(carry, xs) out_carry, outs = f(carry, xs) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "divided by zero") @@ -235,7 +248,7 @@ def f(carry, xs): # error happens on second iteration carry, xs = 2., jnp.ones((4,)) - err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs) + err, (ch_out_carry, ch_outs) = checked_f(carry, xs) out_carry, outs = f(carry, xs) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "divided by zero") @@ -257,14 +270,16 @@ def while_body(val): def f(init_val): return lax.while_loop(while_cond, while_body, (init_val, 0.)) + checked_f = checkify.checkify(f, errors=checkify.float_errors) + init_val = 1. - err, ch_out = checkify.checkify(f)(init_val) + err, ch_out = checked_f(init_val) out = f(init_val) self.assertIs(err.get(), None) self.assertArraysEqual(ch_out, out) init_val = 0. - err, ch_out = checkify.checkify(f)(init_val) + err, ch_out = checked_f(init_val) out = f(init_val) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "divided by zero") @@ -283,14 +298,16 @@ def while_body(val): def f(init_val): return lax.while_loop(while_cond, while_body, init_val) + checked_f = checkify.checkify(f, errors=checkify.float_errors) + init_val = 1. - err, ch_out = checkify.checkify(f)(init_val) + err, ch_out = checked_f(init_val) out = f(init_val) self.assertIs(err.get(), None) self.assertArraysEqual(ch_out, out) init_val = 0. - err, ch_out = checkify.checkify(f)(init_val) + err, ch_out = checked_f(init_val) out = f(init_val) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "divided by zero") @@ -307,15 +324,17 @@ def while_cond(val): def f(init_val): return lax.while_loop(while_cond, lambda val: val-1, init_val) + checked_f = checkify.checkify(f, errors=checkify.float_errors) + # error on first cond init_val = 0. - err, _ = checkify.checkify(f)(init_val) + err, _ = checked_f(init_val) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "divided by zero") # error on second cond init_val = 1. - err, _ = checkify.checkify(f)(init_val) + err, _ = checked_f(init_val) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "divided by zero") @@ -335,25 +354,53 @@ def while_body(val): def f(cond_val, body_val): return lax.while_loop(while_cond, while_body, (0., cond_val, body_val)) + checked_f = checkify.checkify(f, errors=checkify.float_errors) + cond_val = jnp.inf body_val = 1. - err, _ = checkify.checkify(f)(cond_val, body_val) + err, _ = checked_f(cond_val, body_val) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "nan generated by primitive sin") cond_val = 1. body_val = jnp.inf - err, _ = checkify.checkify(f)(cond_val, body_val) + err, _ = checked_f(cond_val, body_val) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "nan generated by primitive cos") cond_val = jnp.inf body_val = jnp.inf - err, _ = checkify.checkify(f)(cond_val, body_val) + err, _ = checked_f(cond_val, body_val) self.assertIsNotNone(err.get()) # first error which occurs is in cond self.assertStartsWith(err.get(), "nan generated by primitive sin") + def test_empty_enabled_errors(self): + with self.assertRaisesRegex(ValueError, 'called with an empty errors set'): + checkify.checkify(lambda x: x, errors={}) + + @parameterized.named_parameters( + ("assert", checkify.user_asserts, "must be negative!"), + ("div", {checkify.ErrorCategory.DIV}, "divided by zero"), + ("nan", {checkify.ErrorCategory.NAN}, "nan generated"), + ("oob", checkify.index_errors, "out-of-bounds indexing"), + ("automatic_errors", checkify.automatic_errors, "divided by zero"), + ) + @jtu.skip_on_devices('tpu') + def test_enabled_errors(self, error_set, expected_error): + def multi_errors(x): + x = x/0 # DIV + x = jnp.sin(x) # NAN + x = x[500] # OOB + checkify.assert_(x < 0, "must be negative!") # ASSERT + return x + + x = jnp.ones((2,)) + err, _ = checkify.checkify(multi_errors, errors=error_set)(x) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), expected_error) + + class AssertPrimitiveTests(jtu.JaxTestCase): def test_assert_primitive_impl(self): def f():