From bdcd358b65e9cd3bc99379cab95bb835686ecfc6 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sat, 3 Aug 2024 21:49:58 +0000 Subject: [PATCH] improve while_loop carry pytree/type mismatch errors Now we call into the same error utility as we use in scan. --- jax/_src/lax/control_flow/loops.py | 61 +++++++++++++------------ tests/lax_control_flow_test.py | 72 ++++++++++++++++-------------- 2 files changed, 71 insertions(+), 62 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 16ebd25e8961..24029b92873e 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -47,7 +47,7 @@ from jax._src.lax import slicing from jax._src.lax import windowed_reductions from jax._src.lax.control_flow.common import ( - _abstractify, _avals_short, _check_tree_and_avals, _initial_style_jaxpr, + _abstractify, _avals_short, _initial_style_jaxpr, _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros, _typecheck_param) from jax._src.lib.mlir import ir @@ -285,7 +285,7 @@ def _create_jaxpr(init): in_flat, jaxpr, consts, out_tree, out_tree_children, attrs_tracked = rest num_carry = len(init_flat) - _check_scan_carry_type(f, init, out_tree_children[0], carry_avals_out) + _check_carry_type('scan body', f, init, out_tree_children[0], carry_avals_out) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(jaxpr.effects) if disallowed_effects: raise NotImplementedError( @@ -328,7 +328,7 @@ def _get_states(attrs_tracked): vals.extend(leaves) return vals -def _check_scan_carry_type(body_fun, in_carry, out_carry_tree, out_avals): +def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): try: sig = inspect.signature(body_fun) except (ValueError, TypeError): @@ -353,33 +353,39 @@ def _check_scan_carry_type(body_fun, in_carry, out_carry_tree, out_avals): differences = [f'the input tree structure is:\n{in_carry_tree}\n', f'the output tree structure 
is:\n{out_carry_tree}\n'] else: - differences = '\n'.join( - f' * {component(path)} is a {thing1} but the corresponding component ' - f'of the carry output is a {thing2}, so {explanation}\n' - for path, thing1, thing2, explanation - in equality_errors(in_carry, out_carry)) + diffs = [f'{component(path)} is a {thing1} but the corresponding component ' + f'of the carry output is a {thing2}, so {explanation}' + for path, thing1, thing2, explanation + in equality_errors(in_carry, out_carry)] + if len(diffs) == 1: + differences = f'{diffs[0]}.\n'.capitalize() + else: + differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1]) + + f' * {diffs[-1]}.\n') raise TypeError( - "Scanned function carry input and carry output must have the same " - "pytree structure, but they differ:\n" + f"{name} function carry input and carry output must have the same " + "pytree structure, but they differ:\n\n" f"{differences}\n" - "Revise the scanned function so that its output is a pair where the " - "first element has the same pytree structure as the first argument." 
- ) + "Revise the function so that the carry output has the same pytree " + "structure as the carry input.") if not all(_map(core.typematch, in_avals, out_avals)): - differences = '\n'.join( - f' * {component(path)} has type {in_aval.str_short()}' - ' but the corresponding output carry component has type ' - f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}\n' - for path, in_aval, out_aval in zip(paths, in_avals, out_avals) - if not core.typematch(in_aval, out_aval)) + diffs = [f'{component(path)} has type {in_aval.str_short()}' + ' but the corresponding output carry component has type ' + f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}' + for path, in_aval, out_aval in zip(paths, in_avals, out_avals) + if not core.typematch(in_aval, out_aval)] + if len(diffs) == 1: + differences = f'{diffs[0]}.\n'.capitalize() + else: + differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1]) + + f' * {diffs[-1]}.\n') raise TypeError( - "Scanned function carry input and carry output must have equal types " + f"{name} function carry input and carry output must have equal types " "(e.g. shapes and dtypes of arrays), " - "but they differ:\n" + "but they differ:\n\n" f"{differences}\n" - "Revise the scanned function so that all output types (e.g. shapes " - "and dtypes) match the corresponding input types." - ) + "Revise the function so that all output types (e.g. shapes " + "and dtypes) match the corresponding input types.") def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str: assert not core.typematch(a1, a2) @@ -1314,16 +1320,15 @@ def _create_jaxpr(init_val): # necessary, a second time with modified init values. 
init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val) new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals) + new_init_val, = tree_unflatten(in_tree, new_init_vals) if changed: - new_init_val, = tree_unflatten(in_tree, new_init_vals) init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val) cond_jaxpr, cond_consts, body_consts, body_tree = rest in_tree_children = in_tree.children() assert len(in_tree_children) == 1 - _check_tree_and_avals("body_fun output and input", - body_tree, body_jaxpr.out_avals, - in_tree_children[0], init_avals) + _check_carry_type('while_loop body', body_fun, new_init_val, body_tree, + body_jaxpr.out_avals) joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 8a4798c6010b..d52862ec42ac 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -300,21 +300,25 @@ def testWhileTypeErrors(self): """Test typing error messages for while.""" tuple_treedef = jax.tree.structure((1., 1.)) leaf_treedef = jax.tree.structure(0.) - with self.assertRaisesRegex(TypeError, + with self.assertRaisesRegex( + TypeError, re.escape(f"cond_fun must return a boolean scalar, but got pytree {tuple_treedef}.")): lax.while_loop(lambda c: (1., 1.), lambda c: c, 0.) 
- with self.assertRaisesRegex(TypeError, + with self.assertRaisesRegex( + TypeError, re.escape("cond_fun must return a boolean scalar, but got output type(s) [ShapedArray(float32[])].")): lax.while_loop(lambda c: np.float32(1.), lambda c: c, np.float32(0.)) - with self.assertRaisesRegex(TypeError, - re.escape("body_fun output and input must have same type structure, " - f"got {tuple_treedef} and {leaf_treedef}.")): + with self.assertRaisesRegex( + TypeError, + re.escape("while_loop body function carry input and carry output must " + "have the same pytree structure, but they differ:\n\n" + "The input carry c is a")): lax.while_loop(lambda c: True, lambda c: (1., 1.), 0.) - with self.assertRaisesWithLiteralMatch(TypeError, - ("body_fun output and input must have identical types, got\n" - "('ShapedArray(bool[])', " - "'DIFFERENT ShapedArray(bool[]) vs. " - "ShapedArray(float32[])').")): + with self.assertRaisesRegex( + TypeError, + r"The input carry component c\[1\] has type float32\[\] but the " + r"corresponding output carry component has type bool\[\], so the " + "dtypes do not match."): lax.while_loop(lambda c: True, lambda c: (True, True), (np.bool_(True), np.float32(0.))) @@ -1882,39 +1886,39 @@ def testScanBodyOutputError(self): def testScanBodyCarryPytreeMismatchErrors(self): with self.assertRaisesRegex( TypeError, - re.escape("Scanned function carry input and carry output must have " - "the same pytree structure, but they differ:\n" - " * the input carry c is a tuple of length 2")): + re.escape("function carry input and carry output must have " + "the same pytree structure, but they differ:\n\n" + "The input carry c is a tuple of length 2")): lax.scan(lambda c, x: ((0, 0, 0), x), (1, (2, 3)), jnp.arange(5.)) with self.assertRaisesRegex( TypeError, - re.escape("Scanned function carry input and carry output must have the " - "same pytree structure, but they differ:\n" - " * the input carry x is a tuple of length 2")): + re.escape("function carry input and carry 
output must have the " + "same pytree structure, but they differ:\n\n" + "The input carry x is a tuple of length 2")): lax.scan(lambda x, _: ((x[0].astype('float32'),), None), (jnp.array(0, 'int32'),) * 2, None, length=1) with self.assertRaisesRegex( TypeError, - re.escape("Scanned function carry input and carry output must have the " - "same pytree structure, but they differ:\n" - " * the input carry x is a <class 'tuple'> but the corres")): + re.escape("function carry input and carry output must have the " + "same pytree structure, but they differ:\n\n" + "The input carry x is a <class 'tuple'> but the corres")): jax.lax.scan(lambda x, _: ([x[0].astype('float32'),] * 2, None), (jnp.array(0, 'int32'),) * 2, None, length=1) with self.assertRaisesRegex( TypeError, - re.escape("Scanned function carry input and carry output must have the " - "same pytree structure, but they differ:\n" - " * the input carry x is a <class 'dict'> with 1 child but")): + re.escape("function carry input and carry output must have the " + "same pytree structure, but they differ:\n\n" + "The input carry x is a <class 'dict'> with 1 child but")): jax.lax.scan(lambda x, _: ({'a': x['a'], 'b': x['a']}, None), {'a': jnp.array(0, 'int32')}, None, length=1) with self.assertRaisesRegex( TypeError, - re.escape("Scanned function carry input and carry output must have the " - "same pytree structure, but they differ:\n" + re.escape("function carry input and carry output must have the " + "same pytree structure, but they differ:\n\n" " * the input carry component x[0] is a <class 'dict'> with " "1 child but the corresponding component of the carry " "output is a <class 'dict'> with 2 children")): @@ -1924,9 +1928,9 @@ def testScanBodyCarryPytreeMismatchErrors(self): def testScanBodyCarryTypeMismatchErrors(self): with self.assertRaisesRegex( TypeError, - re.escape("Scanned function carry input and carry output must have equal " - "types (e.g. 
shapes and dtypes of arrays), but they differ:\n" - " * the input carry x has type int32[] but the corresponding " + re.escape("function carry input and carry output must have equal " + "types (e.g. shapes and dtypes of arrays), but they differ:\n\n" + "The input carry x has type int32[] but the corresponding " "output carry component has type float32[], so the dtypes do " "not match" )): @@ -1935,9 +1939,9 @@ def testScanBodyCarryTypeMismatchErrors(self): with self.assertRaisesRegex( TypeError, - re.escape("Scanned function carry input and carry output must have equal " - "types (e.g. shapes and dtypes of arrays), but they differ:\n" - " * the input carry component x[1] has type int32[] but the " + re.escape("function carry input and carry output must have equal " + "types (e.g. shapes and dtypes of arrays), but they differ:\n\n" + "The input carry component x[1] has type int32[] but the " "corresponding output carry component has type float32[], " "so the dtypes do not match" )): @@ -1946,14 +1950,14 @@ def testScanBodyCarryTypeMismatchErrors(self): with self.assertRaisesRegex( TypeError, - re.escape("Scanned function carry input and carry output must have equal " - "types (e.g. shapes and dtypes of arrays), but they differ:\n" + re.escape("function carry input and carry output must have equal " + "types (e.g. shapes and dtypes of arrays), but they differ:\n\n" " * the input carry component x[0] has type int32[] but the " "corresponding output carry component has type float32[], " - "so the dtypes do not match\n\n" + "so the dtypes do not match;\n" " * the input carry component x[1] has type int32[] but the " "corresponding output carry component has type float32[1,1], " - "so the dtypes do not match and also the shapes do not match" + "so the dtypes do not match and also the shapes do not match." )): jax.lax.scan(lambda x, _: ((x[0].astype('float32'), x[1].astype('float32').reshape(1, 1),