diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index 695ca6cd21d9..f3da7a0e55fe 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -353,7 +353,7 @@ symbolic constraints: E.g., `floordiv(a, b) == c` works by replacing all occurences of `floordiv(a, b)` with `c`. Equality constraints must not contain addition or - subtraction at the top-leve on the left-hand-side. Examples of + subtraction at the top-level on the left-hand-side. Examples of valid left-hand-sides are `a * b`, or `4 * a`, or `floordiv(a + c, b)`. @@ -530,7 +530,7 @@ Array([[ 9, 8, 7], >>> k, = export.symbolic_shape("k", constraints=["k <= 10"]) >>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): -KeyError: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments +UnexpectedDimVar: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments ``` diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 0173df4fd345..77786cbf1a9d 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -80,6 +80,8 @@ def __init__(self, message: str): # https://github.com/python/mypy/issues/5887 super().__init__(error_msg) +class UnexpectedDimVar(Exception): + pass class Comparator(Enum): EQ = 1 @@ -87,12 +89,14 @@ class Comparator(Enum): @dataclasses.dataclass(frozen=True) class _SymbolicConstraint: + # Either e1 == e2 if cmp == Comparator.EQ else e1 >= e2 cmp: Comparator debug_str: str # The form in which the user expressed it, for error messages - diff: _DimExpr # For GEQ: diff >= 0, and for EQ: diff == 0 + e1: DimSize # This has been normalized w.r.t. previous constraints only + e2: DimSize # This has been normalized w.r.t. 
previous constraints only def __repr__(self): - return f"Constraint({self.debug_str}: {self.diff})" + return f"Constraint({self.debug_str})" class _DimFactor: @@ -209,15 +213,22 @@ def __ge__(self, other: _DimFactor): """Lexicographic comparison""" return self._syntactic_cmp(other) >= 0 - def evaluate(self, env: DimVarEnv): + def evaluate(self, env: DimVarEnv, scope: SymbolicScope): if self.var is not None: try: return env[self.var] except KeyError: + # Perhaps there is a normalization rule for this variable + normalized_var = _DimExpr._from_var(self.var, scope) + if core.is_constant_dim(normalized_var): + return normalized_var + non_trivial_normalization = (v1 := normalized_var._to_var()) is None or v1 != self.var # type: ignore + if non_trivial_normalization: + return normalized_var._evaluate(env) # type: ignore err_msg = ( f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the function arguments.\n" "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") - raise KeyError(err_msg) + raise UnexpectedDimVar(err_msg) else: operand_values = [opnd._evaluate(env) for opnd in self.operands] if self.operation == _DimFactor.FLOORDIV: @@ -370,11 +381,11 @@ def divide(self, divisor: _DimTerm) -> _DimTerm: raise InconclusiveDimensionOperation(f"Cannot divide {self} by {divisor}.") return _DimTerm(new_factors) - def evaluate(self, env: DimVarEnv): + def evaluate(self, env: DimVarEnv, scope: SymbolicScope): prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1) def pow_opt(v, p: int): return v if p == 1 else prod([v] * p) - return prod([pow_opt(f.evaluate(env), exp) for f, exp in self._factors]) + return prod([pow_opt(f.evaluate(env, scope), exp) for f, exp in self._factors]) def __deepcopy__(self, memo): return _DimTerm(copy.deepcopy(self._factors, memo)) @@ -404,7 +415,7 @@ class _DimExpr: def __init__(self, 
sorted_terms: SortedTerms, scope: SymbolicScope): # Do not construct _DimExpr directly, unless you are sure that `terms` is - # normalized; Use _DimExpr.normalize. + # normalized; Use _DimExpr._normalize_sorted_terms. self._sorted_terms = tuple(sorted_terms) or ((_DimTerm_one, 0),) self._scope = scope self._hash = None @@ -426,8 +437,8 @@ def _from_term(t: _DimTerm, t_k: int, scope: SymbolicScope) -> DimSize: return _DimExpr._normalize_sorted_terms(((t, t_k),), scope) @staticmethod - def _from_var(v: str, scope: SymbolicScope) -> _DimExpr: - return _DimExpr(((_DimTerm.from_var(v), 1),), scope) + def _from_var(v: str, scope: SymbolicScope) -> DimSize: + return _DimExpr._normalize_sorted_terms(((_DimTerm.from_var(v), 1),), scope) @staticmethod def _from_operation(operation: str, *operands: DimSize, @@ -475,8 +486,9 @@ def _add_coeff(coeffs: dict[_DimTerm, int], t: _DimTerm, coeff: int): def _normalize_term(t: _DimTerm, t_k: int, scope: SymbolicScope) -> Sequence[tuple[_DimTerm, int]]: # If (t, t_k) is among the scope normalization rules, then return - # a list of updates to apply to the expression containing (t, t_k). - # Returns empty sequence if no normalizations are necessary. + # a list of `term * coefficient` to add to the expression containing (t, t_k). + # Returns the empty sequence if no normalizations are necessary. 
+ if not scope._normalization_rules: return [] updates = [] after, t_k_after = scope._normalization_rules.get(t, (None, 0)) if after is not None and t_k % t_k_after == 0: @@ -899,7 +911,7 @@ def _divmod(self, divisor: DimSize) -> tuple[DimSize, int]: def _evaluate(self, env: DimVarEnv): # Evaluates as a value of dtype=core.dim_value_dtype() - terms = [_evaluate_multiply(t.evaluate(env), core.dim_constant(t_k)) + terms = [_evaluate_multiply(t.evaluate(env, self.scope), core.dim_constant(t_k)) for t, t_k in self._sorted_terms] return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0] @@ -1046,8 +1058,6 @@ def _parse_and_process_explicit_constraint(self, c_str: str): raise ValueError(f"Unsatisfiable explicit constraint: {c_str}") return - constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, diff=diff) # type: ignore[arg-type] - self._explicit_constraints.append(constr) if cmp == Comparator.EQ: if not isinstance(e1, _DimExpr): raise ValueError("Invalid equality constraint: {e1} == {e2}. " @@ -1063,6 +1073,9 @@ def _parse_and_process_explicit_constraint(self, c_str: str): f"Found multiple equality constraints with the same left-hand-side: {before}") self._normalization_rules[before] = (after, before_k) + constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, e1=e1, e2=e2) + self._explicit_constraints.append(constr) + def _check_same_scope(self, other: _DimExpr, when: str = "", self_descr: str = " ", @@ -2016,7 +2029,7 @@ def _solve_dim_equations( # Returns a shape environment and the shape constraints if it can solve all # dimension variables. Raises an exception if it cannot. 
shape_env: DimVarEnv = {} - solution_error_message_pieces: list[str | _DimExpr] = [ + solution_error_message_pieces: list[str | DimSize] = [ " Obtained dimension variables: " ] # Error message describing the solution # Prepare error message piece describing the polymorphic shape specs @@ -2050,8 +2063,8 @@ def process_one_eqn(eqn: _DimEquation) -> bool: for term, term_k in eqn.aval_dim_expr._sorted_terms: # Perhaps we can already evaluate this term (all vars solved) try: - term_value = term.evaluate(shape_env) - except KeyError: + term_value = term.evaluate(shape_env, scope) + except UnexpectedDimVar: # `mon` still uses some variables not yet solved. We handle only the # case when `mon` is a single variable. v = term.to_var() @@ -2118,14 +2131,19 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv): if not shape_env: return assert scope is not None for constr in scope._explicit_constraints: - c_value = constr.diff._evaluate(shape_env) + # We can't just construct constr.e1 - constr.e2 because for an equality + # constraint it would be reduced to 0. + c_e1 = constr.e1._evaluate(shape_env) if not core.is_constant_dim(constr.e1) else constr.e1 # type: ignore + c_e2 = constr.e2._evaluate(shape_env) if not core.is_constant_dim(constr.e2) else constr.e2 # type: ignore + c_diff = c_e1 - c_e2 shape_constraints.add_constraint( - constr.cmp, c_value, 0, + constr.cmp, c_diff, 0, error_message_pieces=[ f"Input shapes do not match the symbolic shape constraint {constr.debug_str}. " - f"Expected '{constr.diff}' to be " + f"Expected '{constr.e1} - {constr.e2}' to be " f"{'greater or equal' if constr.cmp == Comparator.GEQ else 'equal'} to 0, " - "but found ", c_value, + "but found ", c_diff, + ". 
" + poly_specs_err_msg ] + solution_error_message_pieces + [ solution_err_msg_trailer_errors]) diff --git a/jax/_src/export/shape_poly_decision.py b/jax/_src/export/shape_poly_decision.py index e325722b0c26..4bad8b7be06d 100644 --- a/jax/_src/export/shape_poly_decision.py +++ b/jax/_src/export/shape_poly_decision.py @@ -23,6 +23,7 @@ import numpy as np +from jax._src import core from jax._src.export import shape_poly from jax._src.export.shape_poly import ( _DimExpr, _DimTerm, _DimFactor, @@ -84,7 +85,10 @@ def initialize(self) -> _DecisionByElimination: # the result (albeit, for now, without a good feedback loop to understand # how the order matters for inequalities). for constr in self.scope._explicit_constraints: - self.add_implicit_constraints_expr(constr.diff) + if not core.is_constant_dim(constr.e1): + self.add_implicit_constraints_expr(constr.e1) # type: ignore + if not core.is_constant_dim(constr.e2): + self.add_implicit_constraints_expr(constr.e2) # type: ignore # The equality constraints are not needed for inequality decisions, # because the LHS should always be rewritten in terms of the RHS. # In fact, adding them may break the assumption that if we eliminate @@ -92,7 +96,7 @@ def initialize(self) -> _DecisionByElimination: # may appear in the rest and may be rewritten to something larger. # However, we want to add the implicit constraints within. if constr.cmp == Comparator.GEQ: - self.combine_and_add_constraint(constr.cmp, constr.diff, 0, + self.combine_and_add_constraint(constr.cmp, constr.e1 - constr.e2, 0, constr.debug_str) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 545945c91ffd..741887abf24b 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1253,7 +1253,7 @@ def __init__(self, trace: TensorFlowTrace, val: TfVal, # We have a TF value with known shape, and the abstract shape is a shape variable. 
try: aval_int = int(_eval_shape([aval_dim])) # type: ignore - except (TypeError, KeyError): + except (TypeError, KeyError, shape_poly.UnexpectedDimVar): continue assert aval_int == val_dim, f"expected {phys_aval.shape} == {val_shape}. Found {aval_int} != {val_dim}." diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index d5b32cdbd7fc..357b2e08d091 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -995,6 +995,16 @@ def test_constraints_ge_override(self): self.assertEqual(_bounds(a), (10, np.inf)) self.assertEqual(_bounds(b), (1, 10)) + def test_constraint_eq_0(self): + a, b, c, d = shape_poly.symbolic_shape( + "a, b, c, d", + constraints=("b == a", "c == a + b", "d == 5")) + # Check that we have already applied the normalization rules + self.assertEqual(a._to_var(), "a") + self.assertEqual(b._to_var(), "a") + self.assertEqual(c._to_single_term(), (0, 2, a._to_term())) + self.assertIs(d, 5) + def test_constraints_eq_1(self): # Some constaints override other a, b, c = shape_poly.symbolic_shape("a, b, c", @@ -1073,6 +1083,20 @@ def test_constraints_eq_7(self): self.assertEqual(128 * (t1_ceil // 128), t1_ceil) self.assertEqual(128 * b1 * (t1_ceil // 128), b1 * t1_ceil) + def test_constraints_eq_bug_23456(self): + b, = jax.export.symbolic_shape('b', constraints=['b==5']) + jax.eval_shape(lambda k: jnp.tile(k, 3), jax.ShapeDtypeStruct((b,), jnp.float32)) + + def test_constraints_eq_bug_23437(self): + def f1(x, y): + return x + y + + x = jnp.ones((4,), dtype=jnp.int32) + y = jnp.ones((4,), dtype=jnp.int32) + args_specs = jax.export.symbolic_args_specs((x, y), ("a*2", "b*2"), constraints=("a==b",)) + exp = jax.export.export(jax.jit(f1))(*args_specs) + self.assertEqual(exp.in_avals[0], exp.in_avals[1]) + def test_constraints_eq_threefry(self): # Test equalities that arise out of the threefree lowering # x : i32[a] # a may be even or odd @@ -1106,12 +1130,9 @@ def test_constraints_a_minus_4d_eq(self): assumptions1 = ["m1 >= 0", "m1 <= 3", 
"a1 == 4*d1 + m1"] scope1 = shape_poly.SymbolicScope(assumptions1) a1, d1, m1 = shape_poly.symbolic_shape("a1, d1, m1", scope=scope1) - # TODO: The incompleteness is due to the way we combine external constraints self.assertEqual(_bounds(a1 - 4*d1), (1, 3)) # a - 4d = m >= 1 self.assertEqual(_bounds(a1 - 2*d1), (3, np.inf)) # a - 2d = m + 2d >= 3 - # TODO: The incompleteness is due to the way we combine external constraints - self.assertEqual(_bounds(a1), - _expect(best=(5, np.inf), current=(-np.inf, np.inf))) # a >= 4d + m >= 5 + self.assertEqual(_bounds(a1), (5, np.inf)) # a >= 4d + m >= 5 def test_constraints_error_msg(self): a, b = shape_poly.symbolic_shape("a, b", @@ -1642,8 +1663,7 @@ def f(x): # x: i32[a, b] _ = export.export(jax.jit(f))( jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), np.int32)) - - def test_constraints_compile_time_check(self): + def test_constraints_ge_compile_time_check(self): def f(x): # x: i32[a] a = x.shape[0] assert _bounds(a) == (2, 4) @@ -1669,9 +1689,45 @@ def f(x): # x: i32[a] with self.assertRaisesRegex( ValueError, - re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")): + re.escape("Expected '4 - a' to be greater or equal to 0, but found -1")): exp.call(np.arange(5, dtype=np.int32)) + def test_constraints_eq_0_compile_time_check(self): + def f(x): # x: i32[a, b] + return x + + x_spec = jax.ShapeDtypeStruct( + export.symbolic_shape("a, b", + constraints=["max(a, b) == b"]), np.int32) + exp = export.export(jax.jit(f))(x_spec) + with self.assertRaisesRegex( + ValueError, + re.escape("Expected 'max(a, b) - b' to be equal to 0, but found 1")): + exp.call(np.ones((3, 2), dtype=np.int32)) + + def test_constraints_eq_1_compile_time_check(self): + def f(x): # x: i32[a, b] + return x + + x_spec = jax.ShapeDtypeStruct( + export.symbolic_shape("a, b", + constraints=["a == b"]), np.int32) + exp = export.export(jax.jit(f))(x_spec) + exp.call(np.ones((3, 3), dtype=np.int32)) + + def 
test_constraints_eq_2_compile_time_check(self): + def f(x): # x: i32[a, b] + return x + + x_spec = jax.ShapeDtypeStruct( + export.symbolic_shape("a, b", + constraints=["max(a, b) == 4", "a == b"]), np.int32) + exp = export.export(jax.jit(f))(x_spec) + with self.assertRaisesRegex( + ValueError, + re.escape("Expected 'max(a, b) - 4' to be equal to 0, but found -1")): + exp.call(np.ones((3, 3), dtype=np.int32)) + def test_caching_with_scopes(self): f_tracing_count = 0 expected_a_bounds = (1, np.inf)