Skip to content

Commit

Permalink
Allow expr_matches to better handle presence of extra data (#93)
Browse files Browse the repository at this point in the history
* Allow expr_matches to better handle presence of extra data

* Switch to checking isinstance instead of package name

* Add OQPyExpression._expr_matches to give subclasses ability to control expr_matches behavior

* mypy

* fix coverage
  • Loading branch information
PhilReinhold authored Aug 28, 2024
1 parent a1c746c commit 6efb6f5
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 2 deletions.
13 changes: 11 additions & 2 deletions oqpy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@ def __bool__(self) -> bool:
"the equality of expressions using == instead of expr_matches."
)

def _expr_matches(self, other: Any) -> bool:
"""Called by expr_matches to compare expression instances."""
if not isinstance(other, type(self)):
return False
return expr_matches(self.__dict__, other.__dict__)


def _get_type(val: AstConvertible) -> Optional[ast.ClassicalType]:
if isinstance(val, OQPyExpression):
Expand Down Expand Up @@ -318,6 +324,8 @@ def expr_matches(a: Any, b: Any) -> bool:
This bypasses calling ``__eq__`` on expr objects.
"""
if a is b:
return True
if type(a) is not type(b):
return False
if isinstance(a, (list, np.ndarray)):
Expand All @@ -328,8 +336,9 @@ def expr_matches(a: Any, b: Any) -> bool:
if a.keys() != b.keys():
return False
return all(expr_matches(va, b[k]) for k, va in a.items())
if hasattr(a, "__dict__"):
return expr_matches(a.__dict__, b.__dict__)
if isinstance(a, OQPyExpression):
# Bypass `__eq__` which is overloaded on OQPyExpressions
return a._expr_matches(b)
else:
return a == b

Expand Down
50 changes: 50 additions & 0 deletions tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2624,3 +2624,53 @@ def test_box_with_negative_duration():
with pytest.raises(ValueError, match="Expected a non-negative duration, but got -4e-09"):
with Box(prog, -4e-9):
pass


def test_expr_matches_handles_outside_data():
x1 = oqpy.FloatVar(3, name="x")
x2 = oqpy.FloatVar(3, name="x")
class MyEntity:
def __init__(self):
self.self_ref = self

def __eq__(self, other):
return True

x1._entity = MyEntity()
x2._entity = MyEntity()
assert oqpy.base.expr_matches(x1, x2)

class MyEntityNoEq:
def __init__(self):
self.self_ref = self
def __eq__(self, other):
raise RuntimeError("Eq not allowed")

x1._entity = MyEntityNoEq()
x2._entity = x1._entity
oqpy.base.expr_matches(x1, x2)

class MyFloatVar(oqpy.FloatVar):
...

x1 = MyFloatVar(3, name="x")
x2 = MyFloatVar(3, name="x")
assert not x1._expr_matches(oqpy.FloatVar(3, name="x"))
assert oqpy.base.expr_matches(x1, x2)

class MyFloatVarWithIgnoredData(oqpy.FloatVar):
ignored: int
def _expr_matches(self, other):
if not isinstance(other, type(self)):
return False
d1 = self.__dict__.copy()
d2 = other.__dict__.copy()
d1.pop("ignored")
d2.pop("ignored")
return oqpy.base.expr_matches(d1, d2)

x1 = MyFloatVarWithIgnoredData(3, name="x")
x1.ignored = 1
x2 = MyFloatVarWithIgnoredData(3, name="x")
x2.ignored = 2
assert oqpy.base.expr_matches(x1, x2)

0 comments on commit 6efb6f5

Please sign in to comment.