Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[shape_poly] Improve handling of equality shape constraints #23470

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/export/shape_poly.md
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ symbolic constraints:
E.g., `floordiv(a, b) == c` works by replacing all
occurences of `floordiv(a, b)` with `c`.
Equality constraints must not contain addition or
subtraction at the top-leve on the left-hand-side. Examples of
subtraction at the top-level on the left-hand-side. Examples of
valid left-hand-sides are `a * b`, or `4 * a`, or
`floordiv(a + c, b)`.

Expand Down Expand Up @@ -530,7 +530,7 @@ Array([[ 9, 8, 7],
>>> k, = export.symbolic_shape("k", constraints=["k <= 10"])
>>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
KeyError: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments
UnexpectedDimVar: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments

```

Expand Down
60 changes: 39 additions & 21 deletions jax/_src/export/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,23 @@ def __init__(self, message: str):
# https://github.com/python/mypy/issues/5887
super().__init__(error_msg)

class UnexpectedDimVar(Exception):
pass

class Comparator(Enum):
EQ = 1
GEQ = 2

@dataclasses.dataclass(frozen=True)
class _SymbolicConstraint:
# Either e1 == e2 if cmp == Comparator.EQ else e1 >= e2
cmp: Comparator
debug_str: str # The form in which the user expressed it, for error messages
diff: _DimExpr # For GEQ: diff >= 0, and for EQ: diff == 0
e1: DimSize # This has been normalized w.r.t. previous constraints only
e2: DimSize # This has been normalized w.r.t. previous constraints only

def __repr__(self):
return f"Constraint({self.debug_str}: {self.diff})"
return f"Constraint({self.debug_str})"


class _DimFactor:
Expand Down Expand Up @@ -209,15 +213,22 @@ def __ge__(self, other: _DimFactor):
"""Lexicographic comparison"""
return self._syntactic_cmp(other) >= 0

def evaluate(self, env: DimVarEnv):
def evaluate(self, env: DimVarEnv, scope: SymbolicScope):
if self.var is not None:
try:
return env[self.var]
except KeyError:
# Perhaps there is a normalization rule for this variable
normalized_var = _DimExpr._from_var(self.var, scope)
if core.is_constant_dim(normalized_var):
return normalized_var
non_trivial_normalization = (v1 := normalized_var._to_var()) is None or v1 != self.var # type: ignore
if non_trivial_normalization:
return normalized_var._evaluate(env) # type: ignore
err_msg = (
f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the function arguments.\n"
"Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.")
raise KeyError(err_msg)
raise UnexpectedDimVar(err_msg)
else:
operand_values = [opnd._evaluate(env) for opnd in self.operands]
if self.operation == _DimFactor.FLOORDIV:
Expand Down Expand Up @@ -370,11 +381,11 @@ def divide(self, divisor: _DimTerm) -> _DimTerm:
raise InconclusiveDimensionOperation(f"Cannot divide {self} by {divisor}.")
return _DimTerm(new_factors)

def evaluate(self, env: DimVarEnv):
def evaluate(self, env: DimVarEnv, scope: SymbolicScope):
prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1)
def pow_opt(v, p: int):
return v if p == 1 else prod([v] * p)
return prod([pow_opt(f.evaluate(env), exp) for f, exp in self._factors])
return prod([pow_opt(f.evaluate(env, scope), exp) for f, exp in self._factors])

def __deepcopy__(self, memo):
return _DimTerm(copy.deepcopy(self._factors, memo))
Expand Down Expand Up @@ -404,7 +415,7 @@ class _DimExpr:
def __init__(self, sorted_terms: SortedTerms,
scope: SymbolicScope):
# Do not construct _DimExpr directly, unless you are sure that `terms` is
# normalized; Use _DimExpr.normalize.
# normalized; Use _DimExpr._normalize_sorted_terms.
self._sorted_terms = tuple(sorted_terms) or ((_DimTerm_one, 0),)
self._scope = scope
self._hash = None
Expand All @@ -426,8 +437,8 @@ def _from_term(t: _DimTerm, t_k: int, scope: SymbolicScope) -> DimSize:
return _DimExpr._normalize_sorted_terms(((t, t_k),), scope)

@staticmethod
def _from_var(v: str, scope: SymbolicScope) -> _DimExpr:
return _DimExpr(((_DimTerm.from_var(v), 1),), scope)
def _from_var(v: str, scope: SymbolicScope) -> DimSize:
return _DimExpr._normalize_sorted_terms(((_DimTerm.from_var(v), 1),), scope)

@staticmethod
def _from_operation(operation: str, *operands: DimSize,
Expand Down Expand Up @@ -475,8 +486,9 @@ def _add_coeff(coeffs: dict[_DimTerm, int], t: _DimTerm, coeff: int):
def _normalize_term(t: _DimTerm, t_k: int,
scope: SymbolicScope) -> Sequence[tuple[_DimTerm, int]]:
# If (t, t_k) is among the scope normalization rules, then return
# a list of updates to apply to the expression containing (t, t_k).
# Returns empty sequence if no normalizations are necessary.
# a list of `term * coefficient` to add to the expression containing (t, t_k).
# Returns the empty sequence if no normalizations are necessary.
if not scope._normalization_rules: return []
updates = []
after, t_k_after = scope._normalization_rules.get(t, (None, 0))
if after is not None and t_k % t_k_after == 0:
Expand Down Expand Up @@ -899,7 +911,7 @@ def _divmod(self, divisor: DimSize) -> tuple[DimSize, int]:

def _evaluate(self, env: DimVarEnv):
# Evaluates as a value of dtype=core.dim_value_dtype()
terms = [_evaluate_multiply(t.evaluate(env), core.dim_constant(t_k))
terms = [_evaluate_multiply(t.evaluate(env, self.scope), core.dim_constant(t_k))
for t, t_k in self._sorted_terms]
return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0]

Expand Down Expand Up @@ -1046,8 +1058,6 @@ def _parse_and_process_explicit_constraint(self, c_str: str):
raise ValueError(f"Unsatisfiable explicit constraint: {c_str}")
return

constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, diff=diff) # type: ignore[arg-type]
self._explicit_constraints.append(constr)
if cmp == Comparator.EQ:
if not isinstance(e1, _DimExpr):
raise ValueError("Invalid equality constraint: {e1} == {e2}. "
Expand All @@ -1063,6 +1073,9 @@ def _parse_and_process_explicit_constraint(self, c_str: str):
f"Found multiple equality constraints with the same left-hand-side: {before}")
self._normalization_rules[before] = (after, before_k)

constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, e1=e1, e2=e2)
self._explicit_constraints.append(constr)

def _check_same_scope(self, other: _DimExpr,
when: str = "",
self_descr: str = " ",
Expand Down Expand Up @@ -2016,7 +2029,7 @@ def _solve_dim_equations(
# Returns a shape environment and the shape constraints if it can solve all
# dimension variables. Raises an exception if it cannot.
shape_env: DimVarEnv = {}
solution_error_message_pieces: list[str | _DimExpr] = [
solution_error_message_pieces: list[str | DimSize] = [
" Obtained dimension variables: "
] # Error message describing the solution
# Prepare error message piece describing the polymorphic shape specs
Expand Down Expand Up @@ -2050,8 +2063,8 @@ def process_one_eqn(eqn: _DimEquation) -> bool:
for term, term_k in eqn.aval_dim_expr._sorted_terms:
# Perhaps we can already evaluate this term (all vars solved)
try:
term_value = term.evaluate(shape_env)
except KeyError:
term_value = term.evaluate(shape_env, scope)
except UnexpectedDimVar:
# `mon` still uses some variables not yet solved. We handle only the
# case when `mon` is a single variable.
v = term.to_var()
Expand Down Expand Up @@ -2118,14 +2131,19 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv):
if not shape_env: return
assert scope is not None
for constr in scope._explicit_constraints:
c_value = constr.diff._evaluate(shape_env)
# We can't just construct constr.e1 - constr.e2 because for an equality
# constraint it would be reduced to 0.
c_e1 = constr.e1._evaluate(shape_env) if not core.is_constant_dim(constr.e1) else constr.e1 # type: ignore
c_e2 = constr.e2._evaluate(shape_env) if not core.is_constant_dim(constr.e2) else constr.e2 # type: ignore
c_diff = c_e1 - c_e2
shape_constraints.add_constraint(
constr.cmp, c_value, 0,
constr.cmp, c_diff, 0,
error_message_pieces=[
f"Input shapes do not match the symbolic shape constraint {constr.debug_str}. "
f"Expected '{constr.diff}' to be "
f"Expected '{constr.e1} - {constr.e2}' to be "
f"{'greater or equal' if constr.cmp == Comparator.GEQ else 'equal'} to 0, "
"but found ", c_value,
"but found ", c_diff,

". " + poly_specs_err_msg
] + solution_error_message_pieces + [
solution_err_msg_trailer_errors])
Expand Down
8 changes: 6 additions & 2 deletions jax/_src/export/shape_poly_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import numpy as np

from jax._src import core
from jax._src.export import shape_poly
from jax._src.export.shape_poly import (
_DimExpr, _DimTerm, _DimFactor,
Expand Down Expand Up @@ -84,15 +85,18 @@ def initialize(self) -> _DecisionByElimination:
# the result (albeit, for now, without a good feedback loop to understand
# how the order matters for inequalities).
for constr in self.scope._explicit_constraints:
self.add_implicit_constraints_expr(constr.diff)
if not core.is_constant_dim(constr.e1):
self.add_implicit_constraints_expr(constr.e1) # type: ignore
if not core.is_constant_dim(constr.e2):
self.add_implicit_constraints_expr(constr.e2) # type: ignore
# The equality constraints are not needed for inequality decisions,
# because the LHS should always be rewritten in terms of the RHS.
# In fact, adding them may break the assumption that if we eliminate
# the leading term we end up with only smaller terms, because the LHS
# may appear in the rest and may be rewritten to something larger.
# However, we want to add the implicit constraints within.
if constr.cmp == Comparator.GEQ:
self.combine_and_add_constraint(constr.cmp, constr.diff, 0,
self.combine_and_add_constraint(constr.cmp, constr.e1 - constr.e2, 0,
constr.debug_str)


Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,7 +1253,7 @@ def __init__(self, trace: TensorFlowTrace, val: TfVal,
# We have a TF value with known shape, and the abstract shape is a shape variable.
try:
aval_int = int(_eval_shape([aval_dim])) # type: ignore
except (TypeError, KeyError):
except (TypeError, KeyError, shape_poly.UnexpectedDimVar):
continue
assert aval_int == val_dim, f"expected {phys_aval.shape} == {val_shape}. Found {aval_int} != {val_dim}."

Expand Down
70 changes: 63 additions & 7 deletions tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,16 @@ def test_constraints_ge_override(self):
self.assertEqual(_bounds(a), (10, np.inf))
self.assertEqual(_bounds(b), (1, 10))

def test_constraint_eq_0(self):
a, b, c, d = shape_poly.symbolic_shape(
"a, b, c, d",
constraints=("b == a", "c == a + b", "d == 5"))
# Check that we have already applied the normalizaton rules
self.assertEqual(a._to_var(), "a")
self.assertEqual(b._to_var(), "a")
self.assertEqual(c._to_single_term(), (0, 2, a._to_term()))
self.assertIs(d, 5)

def test_constraints_eq_1(self):
# Some constaints override other
a, b, c = shape_poly.symbolic_shape("a, b, c",
Expand Down Expand Up @@ -1073,6 +1083,20 @@ def test_constraints_eq_7(self):
self.assertEqual(128 * (t1_ceil // 128), t1_ceil)
self.assertEqual(128 * b1 * (t1_ceil // 128), b1 * t1_ceil)

def test_constraints_eq_bug_23456(self):
b, = jax.export.symbolic_shape('b', constraints=['b==5'])
jax.eval_shape(lambda k: jnp.tile(k, 3), jax.ShapeDtypeStruct((b,), jnp.float32))

def test_constraints_eq_bug_23437(self):
def f1(x, y):
return x + y

x = jnp.ones((4,), dtype=jnp.int32)
y = jnp.ones((4,), dtype=jnp.int32)
args_specs = jax.export.symbolic_args_specs((x, y), ("a*2", "b*2"), constraints=("a==b",))
exp = jax.export.export(jax.jit(f1))(*args_specs)
self.assertEqual(exp.in_avals[0], exp.in_avals[1])

def test_constraints_eq_threefry(self):
# Test equalities that arise out of the threefree lowering
# x : i32[a] # a may be even or odd
Expand Down Expand Up @@ -1106,12 +1130,9 @@ def test_constraints_a_minus_4d_eq(self):
assumptions1 = ["m1 >= 0", "m1 <= 3", "a1 == 4*d1 + m1"]
scope1 = shape_poly.SymbolicScope(assumptions1)
a1, d1, m1 = shape_poly.symbolic_shape("a1, d1, m1", scope=scope1)
# TODO: The incompleteness is due to the way we combine external constraints
self.assertEqual(_bounds(a1 - 4*d1), (1, 3)) # a - 4d = m >= 1
self.assertEqual(_bounds(a1 - 2*d1), (3, np.inf)) # a - 2d = m + 2d >= 3
# TODO: The incompleteness is due to the way we combine external constraints
self.assertEqual(_bounds(a1),
_expect(best=(5, np.inf), current=(-np.inf, np.inf))) # a >= 4d + m >= 5
self.assertEqual(_bounds(a1), (5, np.inf)) # a >= 4d + m >= 5

def test_constraints_error_msg(self):
a, b = shape_poly.symbolic_shape("a, b",
Expand Down Expand Up @@ -1642,8 +1663,7 @@ def f(x): # x: i32[a, b]
_ = export.export(jax.jit(f))(
jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), np.int32))


def test_constraints_compile_time_check(self):
def test_constraints_ge_compile_time_check(self):
def f(x): # x: i32[a]
a = x.shape[0]
assert _bounds(a) == (2, 4)
Expand All @@ -1669,9 +1689,45 @@ def f(x): # x: i32[a]

with self.assertRaisesRegex(
ValueError,
re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")):
re.escape("Expected '4 - a' to be greater or equal to 0, but found -1")):
exp.call(np.arange(5, dtype=np.int32))

def test_constraints_eq_0_compile_time_check(self):
def f(x): # x: i32[a, b]
return x

x_spec = jax.ShapeDtypeStruct(
export.symbolic_shape("a, b",
constraints=["max(a, b) == b"]), np.int32)
exp = export.export(jax.jit(f))(x_spec)
with self.assertRaisesRegex(
ValueError,
re.escape("Expected 'max(a, b) - b' to be equal to 0, but found 1")):
exp.call(np.ones((3, 2), dtype=np.int32))

def test_constraints_eq_1_compile_time_check(self):
def f(x): # x: i32[a, b]
return x

x_spec = jax.ShapeDtypeStruct(
export.symbolic_shape("a, b",
constraints=["a == b"]), np.int32)
exp = export.export(jax.jit(f))(x_spec)
exp.call(np.ones((3, 3), dtype=np.int32))

def test_constraints_eq_2_compile_time_check(self):
def f(x): # x: i32[a, b]
return x

x_spec = jax.ShapeDtypeStruct(
export.symbolic_shape("a, b",
constraints=["max(a, b) == 4", "a == b"]), np.int32)
exp = export.export(jax.jit(f))(x_spec)
with self.assertRaisesRegex(
ValueError,
re.escape("Expected 'max(a, b) - 4' to be equal to 0, but found -1")):
exp.call(np.ones((3, 3), dtype=np.int32))

def test_caching_with_scopes(self):
f_tracing_count = 0
expected_a_bounds = (1, np.inf)
Expand Down