Skip to content

Commit

Permalink
Allow expr_matches to better handle presence of extra data
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilReinhold committed Aug 26, 2024
1 parent a1c746c commit 2df694b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
4 changes: 3 additions & 1 deletion oqpy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ def expr_matches(a: Any, b: Any) -> bool:
This bypasses calling ``__eq__`` on expr objects.
"""
if a is b:
return True
if type(a) is not type(b):
return False
if isinstance(a, (list, np.ndarray)):
Expand All @@ -328,7 +330,7 @@ def expr_matches(a: Any, b: Any) -> bool:
if a.keys() != b.keys():
return False
return all(expr_matches(va, b[k]) for k, va in a.items())
if hasattr(a, "__dict__"):
if hasattr(a, "__dict__") and type(a).__module__.startswith("oqpy"):
return expr_matches(a.__dict__, b.__dict__)
else:
return a == b
Expand Down
25 changes: 25 additions & 0 deletions tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2624,3 +2624,28 @@ def test_box_with_negative_duration():
with pytest.raises(ValueError, match="Expected a non-negative duration, but got -4e-09"):
with Box(prog, -4e-9):
pass


def test_expr_matches_handles_outside_data():
x1 = oqpy.FloatVar(3, name="x")
x2 = oqpy.FloatVar(3, name="x")
class MyEntity:
def __init__(self):
self.self_ref = self

def __eq__(self, other):
return True

x1._entity = MyEntity()
x2._entity = MyEntity()
assert oqpy.base.expr_matches(x1, x2)

class MyEntityNoEq:
def __init__(self):
self.self_ref = self
def __eq__(self, other):
raise RuntimeError("Eq not allowed")

x1._entity = MyEntityNoEq()
x2._entity = x1._entity
oqpy.base.expr_matches(x1, x2)

0 comments on commit 2df694b

Please sign in to comment.