Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set operations and tests added #5983

Merged
merged 41 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
fce3d8d
Set operations and tests added
austingmhuang Jul 10, 2024
9b8c9b6
changelog and docstrigns
austingmhuang Jul 10, 2024
e20e0cd
codefactor unhappy
austingmhuang Jul 10, 2024
1e8c249
pylint
austingmhuang Jul 10, 2024
9c3fe5d
fix errors for py3.9
austingmhuang Jul 11, 2024
e966a81
wrong place for future import
austingmhuang Jul 11, 2024
39b6084
indentation error...
austingmhuang Jul 11, 2024
8b21c06
black
austingmhuang Jul 11, 2024
63124ec
maybe dont need from future?
austingmhuang Jul 11, 2024
b1de740
Update pennylane/wires.py
austingmhuang Jul 15, 2024
0e68949
Update pennylane/wires.py
austingmhuang Jul 15, 2024
be5844e
Update pennylane/wires.py
austingmhuang Jul 15, 2024
effec8e
Update pennylane/wires.py
austingmhuang Jul 15, 2024
af54d3a
address comments and update tests accordingly
austingmhuang Jul 15, 2024
a714ccd
docstring
austingmhuang Jul 15, 2024
2a6980d
Merge branch 'master' into set-based-operations
austingmhuang Jul 15, 2024
f0cf271
fixes to tests
austingmhuang Jul 15, 2024
b9414e9
Merge branch 'master' into set-based-operations
austingmhuang Jul 15, 2024
3ee6dd0
duck typing for wires, removed tests related to typeErrors that are n…
austingmhuang Jul 16, 2024
2a43240
Merge branch 'master' into set-based-operations
austingmhuang Jul 18, 2024
e1c3746
Update pennylane/wires.py
austingmhuang Jul 19, 2024
3fd94b3
Update pennylane/wires.py
austingmhuang Jul 19, 2024
138eb31
Update pennylane/wires.py
austingmhuang Jul 19, 2024
c5847a3
Update pennylane/wires.py
austingmhuang Jul 19, 2024
e0fe314
Update pennylane/wires.py
austingmhuang Jul 19, 2024
2ca88d7
Update pennylane/wires.py
austingmhuang Jul 19, 2024
c4e7f95
Update pennylane/wires.py
austingmhuang Jul 19, 2024
8acb343
small impelementation change
austingmhuang Jul 19, 2024
7023a6e
Update pennylane/wires.py
austingmhuang Jul 19, 2024
0e6caa7
revert
austingmhuang Jul 19, 2024
5dd10ff
Update pennylane/wires.py
austingmhuang Jul 24, 2024
9f349de
Update pennylane/wires.py
austingmhuang Jul 24, 2024
475333b
Update pennylane/wires.py
austingmhuang Jul 24, 2024
c234fad
Update pennylane/wires.py
austingmhuang Jul 24, 2024
796219c
Update pennylane/wires.py
austingmhuang Jul 24, 2024
b7a9c8d
fixes
austingmhuang Jul 25, 2024
675dbcb
Merge branch 'master' into set-based-operations
austingmhuang Jul 25, 2024
018c7f5
Merge branch 'master' into set-based-operations
austingmhuang Jul 26, 2024
5a859ae
Merge branch 'master' into set-based-operations
austingmhuang Jul 26, 2024
4d5ef07
Update doc/releases/changelog-dev.md
austingmhuang Jul 26, 2024
d21786c
Merge branch 'master' into set-based-operations
austingmhuang Jul 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
* `QuantumScript.hash` is now cached, leading to performance improvements.
[(#5919)](https://github.com/PennyLaneAI/pennylane/pull/5919)

* Set operations are now supported by Wires.
[(#5983)](https://github.com/PennyLaneAI/pennylane/pull/5983)

* The representation for `Wires` has now changed to be more copy-paste friendly.
[(#5958)](https://github.com/PennyLaneAI/pennylane/pull/5958)

Expand Down
180 changes: 180 additions & 0 deletions pennylane/wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,186 @@ def unique_wires(list_of_wires):

return Wires(tuple(unique), _override=True)

def union(self, other):
"""Return the union of the current Wires object and another Wires object.

Args:
other (Wires): Another Wires object to perform the union with.
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved

Returns:
Wires: A new Wires object representing the union of the two Wires objects.

Raises:
TypeError: If `other` is not an instance of Wires.
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved

**Example**
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved

>>> wires1 = Wires([1, 2, 3])
>>> wires2 = Wires([3, 4, 5])
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved
>>> wires1.union(wires2)
Wires([1, 2, 3, 4, 5])
"""
return Wires(set(self.labels) | set(other.labels))

def __or__(self, other):
"""Return the union of the current Wires object and another Wires object.

Args:
other (Wires): Another Wires object to perform the union with.
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved

Returns:
Wires: A new Wires object representing the union of the two Wires objects.

Raises:
TypeError: If `other` is not an instance of Wires.

**Example**

>>> wires1 = Wires([1, 2, 3])
>>> wires2 = Wires([3, 4, 5])
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved
>>> wires1 | wires2
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved
Wires([1, 2, 3, 4, 5])
"""
if not isinstance(other, Wires):
raise TypeError(f"Can only do the union of Wires with Wires. Got {type(other)}.")
return self.union(other)

def intersection(self, other):
"""Return the intersection of the current Wires object and another Wires object.

Args:
other (Wires): Another Wires object to perform the intersection with.

Returns:
Wires: A new Wires object representing the intersection of the two Wires objects.

Raises:
TypeError: If `other` is not an instance of Wires.
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved

**Example**

>>> wires1 = Wires([1, 2, 3])
>>> wires2 = Wires([2, 3, 4])
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved
>>> wires1.intersection(wires2)
Wires([2, 3])
"""
return Wires(set(self.labels) & set(other.labels))

def __and__(self, other):
"""Return the intersection of the current Wires object and another Wires object.

Args:
other (Wires): Another Wires object to perform the intersection with.

Returns:
Wires: A new Wires object representing the intersection of the two Wires objects.

Raises:
TypeError: If `other` is not an instance of Wires.

**Example**

>>> wires1 = Wires([1, 2, 3])
>>> wires2 = Wires([2, 3, 4])
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved
>>> wires1 & wires2
Wires([2, 3])
"""
if not isinstance(other, Wires):
raise TypeError(f"Can only do the intersection of Wires with Wires. Got {type(other)}.")
return self.intersection(other)
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved

def difference(self, other):
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved
"""Return the difference of the current Wires object and another Wires object.

Args:
other (Wires): Another Wires object to perform the difference with.

Returns:
Wires: A new Wires object representing the difference of the two Wires objects.

Raises:
TypeError: If `other` is not an instance of Wires.
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved

**Example**

>>> wires1 = Wires([1, 2, 3])
>>> wires2 = Wires([2, 3, 4])
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved
>>> wires1.difference(wires2)
Wires([1])
"""
return Wires(set(self.labels) - set(other.labels))
isaacdevlugt marked this conversation as resolved.
Show resolved Hide resolved

def __sub__(self, other):
"""Return the difference of the current Wires object and another Wires object.

Args:
other (Wires): Another Wires object to perform the difference with.

Returns:
Wires: A new Wires object representing the difference of the two Wires objects.

Raises:
TypeError: If `other` is not an instance of Wires.

**Example**

>>> wires1 = Wires([1, 2, 3])
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved
>>> wires2 = Wires([2, 3, 4])
>>> wires1 - wires2
Wires([1])
"""
if not isinstance(other, Wires):
raise TypeError(
f"Can only do the difference of Wires with other Wires. Got {type(other)}."
)
return self.difference(other)

def symmetric_difference(self, other):
"""Return the symmetric difference of the current Wires object and another Wires object.

Args:
other (Wires): Another Wires object to perform the symmetric difference with.

Returns:
Wires: A new Wires object representing the symmetric difference of the two Wires objects.

Raises:
TypeError: If `other` is not an instance of Wires.
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved

**Example**

>>> wires1 = Wires([1, 2, 3])
>>> wires2 = Wires([3, 4, 5])
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved
>>> wires1.symmetric_difference(wires2)
Wires([1, 2, 4, 5])
"""
return Wires(set(self.labels) ^ set(other.labels))

def __xor__(self, other):
"""Return the symmetric difference of the current Wires object and another Wires object.

Args:
other (Wires): Another Wires object to perform the symmetric difference with.

Returns:
Wires: A new Wires object representing the symmetric difference of the two Wires objects.

Raises:
TypeError: If `other` is not an instance of Wires.

**Example**

>>> wires1 = Wires([1, 2, 3])
>>> wires2 = Wires([3, 4, 5])
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved
>>> wires1 ^ wires2
Wires([1, 2, 4, 5])
"""
if not isinstance(other, Wires):
raise TypeError(
f"Can only the symmetric difference of Wires with other Wires. Got {type(other)}."
)
return self.symmetric_difference(other)


# Register Wires as a PyTree-serializable class
register_pytree(Wires, Wires._flatten, Wires._unflatten) # pylint: disable=protected-access
135 changes: 135 additions & 0 deletions tests/test_wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,138 @@ def test_wires_pytree(self, source):
wires2 = tree_unflatten(tree, wires_flat)
assert isinstance(wires2, Wires), f"{wires2} is not Wires"
assert wires == wires2, f"{wires} != {wires2}"

@pytest.mark.parametrize(
"wire_a, wire_b, expected",
[
(Wires([0, 1]), Wires([2, 3]), Wires([0, 1, 2, 3])),
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved
(Wires([]), Wires([1, 2, 3]), Wires([1, 2, 3])),
(Wires([1, 2]), Wires([1, 2, 3]), Wires([1, 2, 3])),
],
)
def test_union(self, wire_a, wire_b, expected):
"""
Test the union operation (|) between two Wires objects.
"""
assert wire_a | wire_b == expected
assert wire_a.union(wire_b) == expected

@pytest.mark.parametrize(
"wire_a, wire_b, expected",
[
(Wires([0, 1, 2]), Wires([2, 3, 4]), Wires([2])),
(Wires([1, 2, 3]), Wires([]), Wires([])),
(Wires([1, 2, 3]), Wires([1, 2, 3, 4]), Wires([1, 2, 3])),
],
)
def test_intersection(self, wire_a, wire_b, expected):
"""
Test the intersection operation (&) between two Wires objects.
"""
assert wire_a & wire_b == expected
assert wire_a.intersection(wire_b) == expected

@pytest.mark.parametrize(
"wire_a, wire_b, expected",
[
(Wires([0, 1, 2, 3]), Wires([2, 3, 4]), Wires([0, 1])),
(Wires([1, 2, 3]), Wires([]), Wires([1, 2, 3])),
(Wires([1, 2, 3]), Wires([1, 2, 3, 4]), Wires([])),
],
)
def test_difference(self, wire_a, wire_b, expected):
"""
Test the difference operation (-) between two Wires objects.
"""
assert wire_a - wire_b == expected
assert wire_a.difference(wire_b) == expected

@pytest.mark.parametrize(
"wire_a, wire_b, expected",
[
(Wires([0, 1, 2]), Wires([2, 3, 4]), Wires([0, 1, 3, 4])),
(Wires([]), Wires([1, 2, 3]), Wires([1, 2, 3])),
(Wires([1, 2, 3]), Wires([1, 2, 3]), Wires([])),
],
)
def test_symmetric_difference(self, wire_a, wire_b, expected):
"""
Test the symmetric difference operation (^) between two Wires objects.
"""
assert wire_a ^ wire_b == expected
assert wire_a.symmetric_difference(wire_b) == expected

@pytest.mark.parametrize(
"wire_a, wire_b, wire_c, wire_d, expected",
[
(
Wires([0, 1]),
Wires([2, 3]),
Wires([4, 5]),
Wires([6, 7]),
Wires([0, 1, 2, 3, 4, 5, 6, 7]),
),
(Wires([0, 1]), Wires([1, 2]), Wires([2, 3]), Wires([3, 4]), Wires([0, 1, 2, 3, 4])),
(Wires([]), Wires([1, 2]), Wires([2, 3]), Wires([3, 4, 5]), Wires([1, 2, 3, 4, 5])),
],
)
# pylint: disable=too-many-arguments
def test_multiple_union(self, wire_a, wire_b, wire_c, wire_d, expected):
"""
Test the union operation (|) with multiple Wires objects.
"""
result = wire_a | wire_b | wire_c | wire_d
assert result == expected
assert wire_a.union(wire_b.union(wire_c.union(wire_d))) == expected

def test_complex_operation(self):
"""
Test a complex operation involving multiple set operations.
This test combines union, intersection, difference, and symmetric difference operations.
"""
wire_a = Wires([0, 1, 2, 3])
wire_b = Wires([2, 3, 4, 5])
wire_c = Wires([4, 5, 6, 7])
wire_d = Wires([6, 7, 8, 9])

# ((A ∪ B) ∩ (C ∪ D)) ^ ((A - D) ∪ (C - B))
result = ((wire_a | wire_b) & (wire_c | wire_d)) ^ ((wire_a - wire_d) | (wire_c - wire_b))
soranjh marked this conversation as resolved.
Show resolved Hide resolved
assert (wire_a | wire_b) & (wire_c | wire_d) == Wires([4, 5])
assert (wire_a - wire_d) | (wire_c - wire_b) == Wires([0, 1, 2, 3, 6, 7])

expected = Wires([0, 1, 2, 3, 4, 5, 6, 7])
assert result == expected

@pytest.mark.parametrize(
"operation, error_message",
[
(lambda a, b: a | b, "Can only do the union of Wires with Wires"),
(lambda a, b: a & b, "Can only do the intersection of Wires with Wires"),
(lambda a, b: a - b, "Can only do the difference of Wires with other Wires"),
(lambda a, b: a ^ b, "Can only the symmetric difference of Wires with other Wires"),
austingmhuang marked this conversation as resolved.
Show resolved Hide resolved
],
)
def test_type_error_raised(self, operation, error_message):
"""
Test that TypeError is raised with the correct error message when performing
set operations between a Wires object and an incompatible type.

This test covers union (|), intersection (&), difference (-),
and symmetric difference (^) operations.
"""
wire = Wires([1, 2, 3])
invalid_operands = [
42,
"string",
[1, 2, 3],
{1, 2, 3},
(1, 2, 3),
]

for invalid_operand in invalid_operands:
with pytest.raises(TypeError, match=error_message):
operation(wire, invalid_operand)

# Ensure the reverse also has a TypeError
with pytest.raises(TypeError, match="unsupported operand"):
operation(invalid_operand, wire)
Loading