diff --git a/oqpy/base.py b/oqpy/base.py index 82b0bd4..1bf15d3 100644 --- a/oqpy/base.py +++ b/oqpy/base.py @@ -180,6 +180,12 @@ def __bool__(self) -> bool: "the equality of expressions using == instead of expr_matches." ) + def _expr_matches(self, other: Any) -> bool: + """Called by expr_matches to compare expression instances.""" + if not isinstance(other, type(self)): + return False + return expr_matches(self.__dict__, other.__dict__) + def _get_type(val: AstConvertible) -> Optional[ast.ClassicalType]: if isinstance(val, OQPyExpression): @@ -318,6 +324,8 @@ def expr_matches(a: Any, b: Any) -> bool: This bypasses calling ``__eq__`` on expr objects. """ + if a is b: + return True if type(a) is not type(b): return False if isinstance(a, (list, np.ndarray)): @@ -328,8 +336,9 @@ def expr_matches(a: Any, b: Any) -> bool: if a.keys() != b.keys(): return False return all(expr_matches(va, b[k]) for k, va in a.items()) - if hasattr(a, "__dict__"): - return expr_matches(a.__dict__, b.__dict__) + if isinstance(a, OQPyExpression): + # Bypass `__eq__` which is overloaded on OQPyExpressions + return a._expr_matches(b) else: return a == b diff --git a/tests/test_directives.py b/tests/test_directives.py index f114c2d..ae1f359 100644 --- a/tests/test_directives.py +++ b/tests/test_directives.py @@ -2624,3 +2624,53 @@ def test_box_with_negative_duration(): with pytest.raises(ValueError, match="Expected a non-negative duration, but got -4e-09"): with Box(prog, -4e-9): pass + + +def test_expr_matches_handles_outside_data(): + x1 = oqpy.FloatVar(3, name="x") + x2 = oqpy.FloatVar(3, name="x") + class MyEntity: + def __init__(self): + self.self_ref = self + + def __eq__(self, other): + return True + + x1._entity = MyEntity() + x2._entity = MyEntity() + assert oqpy.base.expr_matches(x1, x2) + + class MyEntityNoEq: + def __init__(self): + self.self_ref = self + def __eq__(self, other): + raise RuntimeError("Eq not allowed") + + x1._entity = MyEntityNoEq() + x2._entity = x1._entity + oqpy.base.expr_matches(x1, x2) + + class MyFloatVar(oqpy.FloatVar): + ... + + x1 = MyFloatVar(3, name="x") + x2 = MyFloatVar(3, name="x") + assert not x1._expr_matches(oqpy.FloatVar(3, name="x")) + assert oqpy.base.expr_matches(x1, x2) + + class MyFloatVarWithIgnoredData(oqpy.FloatVar): + ignored: int + def _expr_matches(self, other): + if not isinstance(other, type(self)): + return False + d1 = self.__dict__.copy() + d2 = other.__dict__.copy() + d1.pop("ignored") + d2.pop("ignored") + return oqpy.base.expr_matches(d1, d2) + + x1 = MyFloatVarWithIgnoredData(3, name="x") + x1.ignored = 1 + x2 = MyFloatVarWithIgnoredData(3, name="x") + x2.ignored = 2 + assert oqpy.base.expr_matches(x1, x2)