Skip to content

Commit

Permalink
improve while_loop carry pytree/type mismatch errors
Browse files Browse the repository at this point in the history
Now we call into the same error utility as we use in scan.
  • Loading branch information
mattjj committed Aug 3, 2024
1 parent 09beb33 commit bdcd358
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 62 deletions.
61 changes: 33 additions & 28 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from jax._src.lax import slicing
from jax._src.lax import windowed_reductions
from jax._src.lax.control_flow.common import (
_abstractify, _avals_short, _check_tree_and_avals, _initial_style_jaxpr,
_abstractify, _avals_short, _initial_style_jaxpr,
_initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros,
_typecheck_param)
from jax._src.lib.mlir import ir
Expand Down Expand Up @@ -285,7 +285,7 @@ def _create_jaxpr(init):
in_flat, jaxpr, consts, out_tree, out_tree_children, attrs_tracked = rest
num_carry = len(init_flat)

_check_scan_carry_type(f, init, out_tree_children[0], carry_avals_out)
_check_carry_type('scan body', f, init, out_tree_children[0], carry_avals_out)
disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(jaxpr.effects)
if disallowed_effects:
raise NotImplementedError(
Expand Down Expand Up @@ -328,7 +328,7 @@ def _get_states(attrs_tracked):
vals.extend(leaves)
return vals

def _check_scan_carry_type(body_fun, in_carry, out_carry_tree, out_avals):
def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
try:
sig = inspect.signature(body_fun)
except (ValueError, TypeError):
Expand All @@ -353,33 +353,39 @@ def _check_scan_carry_type(body_fun, in_carry, out_carry_tree, out_avals):
differences = [f'the input tree structure is:\n{in_carry_tree}\n',
f'the output tree structure is:\n{out_carry_tree}\n']
else:
differences = '\n'.join(
f' * {component(path)} is a {thing1} but the corresponding component '
f'of the carry output is a {thing2}, so {explanation}\n'
for path, thing1, thing2, explanation
in equality_errors(in_carry, out_carry))
diffs = [f'{component(path)} is a {thing1} but the corresponding component '
f'of the carry output is a {thing2}, so {explanation}'
for path, thing1, thing2, explanation
in equality_errors(in_carry, out_carry)]
if len(diffs) == 1:
differences = f'{diffs[0]}.\n'.capitalize()
else:
differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1])
+ f' * {diffs[-1]}.\n')
raise TypeError(
"Scanned function carry input and carry output must have the same "
"pytree structure, but they differ:\n"
f"{name} function carry input and carry output must have the same "
"pytree structure, but they differ:\n\n"
f"{differences}\n"
"Revise the scanned function so that its output is a pair where the "
"first element has the same pytree structure as the first argument."
)
"Revise the function so that the carry output has the same pytree "
"structure as the carry input.")
if not all(_map(core.typematch, in_avals, out_avals)):
differences = '\n'.join(
f' * {component(path)} has type {in_aval.str_short()}'
' but the corresponding output carry component has type '
f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}\n'
for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
if not core.typematch(in_aval, out_aval))
diffs = [f'{component(path)} has type {in_aval.str_short()}'
' but the corresponding output carry component has type '
f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}'
for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
if not core.typematch(in_aval, out_aval)]
if len(diffs) == 1:
differences = f'{diffs[0]}.\n'.capitalize()
else:
differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1])
+ f' * {diffs[-1]}.\n')
raise TypeError(
"Scanned function carry input and carry output must have equal types "
f"{name} function carry input and carry output must have equal types "
"(e.g. shapes and dtypes of arrays), "
"but they differ:\n"
"but they differ:\n\n"
f"{differences}\n"
"Revise the scanned function so that all output types (e.g. shapes "
"and dtypes) match the corresponding input types."
)
"Revise the function so that all output types (e.g. shapes "
"and dtypes) match the corresponding input types.")

def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str:
assert not core.typematch(a1, a2)
Expand Down Expand Up @@ -1314,16 +1320,15 @@ def _create_jaxpr(init_val):
# necessary, a second time with modified init values.
init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals)
new_init_val, = tree_unflatten(in_tree, new_init_vals)
if changed:
new_init_val, = tree_unflatten(in_tree, new_init_vals)
init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val)
cond_jaxpr, cond_consts, body_consts, body_tree = rest

in_tree_children = in_tree.children()
assert len(in_tree_children) == 1
_check_tree_and_avals("body_fun output and input",
body_tree, body_jaxpr.out_avals,
in_tree_children[0], init_avals)
_check_carry_type('while_loop body', body_fun, new_init_val, body_tree,
body_jaxpr.out_avals)
joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
if disallowed_effects:
Expand Down
72 changes: 38 additions & 34 deletions tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,21 +300,25 @@ def testWhileTypeErrors(self):
"""Test typing error messages for while."""
tuple_treedef = jax.tree.structure((1., 1.))
leaf_treedef = jax.tree.structure(0.)
with self.assertRaisesRegex(TypeError,
with self.assertRaisesRegex(
TypeError,
re.escape(f"cond_fun must return a boolean scalar, but got pytree {tuple_treedef}.")):
lax.while_loop(lambda c: (1., 1.), lambda c: c, 0.)
with self.assertRaisesRegex(TypeError,
with self.assertRaisesRegex(
TypeError,
re.escape("cond_fun must return a boolean scalar, but got output type(s) [ShapedArray(float32[])].")):
lax.while_loop(lambda c: np.float32(1.), lambda c: c, np.float32(0.))
with self.assertRaisesRegex(TypeError,
re.escape("body_fun output and input must have same type structure, "
f"got {tuple_treedef} and {leaf_treedef}.")):
with self.assertRaisesRegex(
TypeError,
re.escape("while_loop body function carry input and carry output must "
"have the same pytree structure, but they differ:\n\n"
"The input carry c is a")):
lax.while_loop(lambda c: True, lambda c: (1., 1.), 0.)
with self.assertRaisesWithLiteralMatch(TypeError,
("body_fun output and input must have identical types, got\n"
"('ShapedArray(bool[])', "
"'DIFFERENT ShapedArray(bool[]) vs. "
"ShapedArray(float32[])').")):
with self.assertRaisesRegex(
TypeError,
r"The input carry component c\[1\] has type float32\[\] but the "
r"corresponding output carry component has type bool\[\], so the "
"dtypes do not match."):
lax.while_loop(lambda c: True, lambda c: (True, True),
(np.bool_(True), np.float32(0.)))

Expand Down Expand Up @@ -1882,39 +1886,39 @@ def testScanBodyOutputError(self):
def testScanBodyCarryPytreeMismatchErrors(self):
with self.assertRaisesRegex(
TypeError,
re.escape("Scanned function carry input and carry output must have "
"the same pytree structure, but they differ:\n"
" * the input carry c is a tuple of length 2")):
re.escape("function carry input and carry output must have "
"the same pytree structure, but they differ:\n\n"
"The input carry c is a tuple of length 2")):
lax.scan(lambda c, x: ((0, 0, 0), x), (1, (2, 3)), jnp.arange(5.))

with self.assertRaisesRegex(
TypeError,
re.escape("Scanned function carry input and carry output must have the "
"same pytree structure, but they differ:\n"
" * the input carry x is a tuple of length 2")):
re.escape("function carry input and carry output must have the "
"same pytree structure, but they differ:\n\n"
"The input carry x is a tuple of length 2")):
lax.scan(lambda x, _: ((x[0].astype('float32'),), None),
(jnp.array(0, 'int32'),) * 2, None, length=1)

with self.assertRaisesRegex(
TypeError,
re.escape("Scanned function carry input and carry output must have the "
"same pytree structure, but they differ:\n"
" * the input carry x is a <class 'tuple'> but the corres")):
re.escape("function carry input and carry output must have the "
"same pytree structure, but they differ:\n\n"
"The input carry x is a <class 'tuple'> but the corres")):
jax.lax.scan(lambda x, _: ([x[0].astype('float32'),] * 2, None),
(jnp.array(0, 'int32'),) * 2, None, length=1)

with self.assertRaisesRegex(
TypeError,
re.escape("Scanned function carry input and carry output must have the "
"same pytree structure, but they differ:\n"
" * the input carry x is a <class 'dict'> with 1 child but")):
re.escape("function carry input and carry output must have the "
"same pytree structure, but they differ:\n\n"
"The input carry x is a <class 'dict'> with 1 child but")):
jax.lax.scan(lambda x, _: ({'a': x['a'], 'b': x['a']}, None),
{'a': jnp.array(0, 'int32')}, None, length=1)

with self.assertRaisesRegex(
TypeError,
re.escape("Scanned function carry input and carry output must have the "
"same pytree structure, but they differ:\n"
re.escape("function carry input and carry output must have the "
"same pytree structure, but they differ:\n\n"
" * the input carry component x[0] is a <class 'dict'> with "
"1 child but the corresponding component of the carry "
"output is a <class 'dict'> with 2 children")):
Expand All @@ -1924,9 +1928,9 @@ def testScanBodyCarryPytreeMismatchErrors(self):
def testScanBodyCarryTypeMismatchErrors(self):
with self.assertRaisesRegex(
TypeError,
re.escape("Scanned function carry input and carry output must have equal "
"types (e.g. shapes and dtypes of arrays), but they differ:\n"
" * the input carry x has type int32[] but the corresponding "
re.escape("function carry input and carry output must have equal "
"types (e.g. shapes and dtypes of arrays), but they differ:\n\n"
"The input carry x has type int32[] but the corresponding "
"output carry component has type float32[], so the dtypes do "
"not match"
)):
Expand All @@ -1935,9 +1939,9 @@ def testScanBodyCarryTypeMismatchErrors(self):

with self.assertRaisesRegex(
TypeError,
re.escape("Scanned function carry input and carry output must have equal "
"types (e.g. shapes and dtypes of arrays), but they differ:\n"
" * the input carry component x[1] has type int32[] but the "
re.escape("function carry input and carry output must have equal "
"types (e.g. shapes and dtypes of arrays), but they differ:\n\n"
"The input carry component x[1] has type int32[] but the "
"corresponding output carry component has type float32[], "
"so the dtypes do not match"
)):
Expand All @@ -1946,14 +1950,14 @@ def testScanBodyCarryTypeMismatchErrors(self):

with self.assertRaisesRegex(
TypeError,
re.escape("Scanned function carry input and carry output must have equal "
"types (e.g. shapes and dtypes of arrays), but they differ:\n"
re.escape("function carry input and carry output must have equal "
"types (e.g. shapes and dtypes of arrays), but they differ:\n\n"
" * the input carry component x[0] has type int32[] but the "
"corresponding output carry component has type float32[], "
"so the dtypes do not match\n\n"
"so the dtypes do not match;\n"
" * the input carry component x[1] has type int32[] but the "
"corresponding output carry component has type float32[1,1], "
"so the dtypes do not match and also the shapes do not match"
"so the dtypes do not match and also the shapes do not match."
)):
jax.lax.scan(lambda x, _: ((x[0].astype('float32'),
x[1].astype('float32').reshape(1, 1),
Expand Down

0 comments on commit bdcd358

Please sign in to comment.