Skip to content

Commit

Permalink
Set operations and tests added (#5983)
Browse files Browse the repository at this point in the history
**Context:**
Set operations are useful for Wires and Registers and would be a welcome
addition/feature in the PennyLane ecosystem

**Description of the Change:**
This PR implements the 4 set operations for Wires and you can access
them with their respective operators (| & - ^) as well as comprehensive
tests.

**Benefits:**
Nice syntactic sugar

**Possible Drawbacks:**

**Related GitHub Issues:**
[sc-67925]

---------

Co-authored-by: Isaac De Vlugt <34751083+isaacdevlugt@users.noreply.github.com>
Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 26, 2024
1 parent c534c17 commit 9808d0c
Show file tree
Hide file tree
Showing 3 changed files with 314 additions and 0 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@
* `QuantumScript.hash` is now cached, leading to performance improvements.
[(#5919)](https://github.com/PennyLaneAI/pennylane/pull/5919)

* Set operations are now supported by Wires.
[(#5983)](https://github.com/PennyLaneAI/pennylane/pull/5983)

* `qml.dynamic_one_shot` now supports circuits using the `"tensorflow"` interface.
[(#5973)](https://github.com/PennyLaneAI/pennylane/pull/5973)

Expand Down
201 changes: 201 additions & 0 deletions pennylane/wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,207 @@ def unique_wires(list_of_wires):

return Wires(tuple(unique), _override=True)

def union(self, other):
"""Return the union of the current Wires object and either another Wires object or an
iterable that can be interpreted like a Wires object e.g., List.
Args:
other (Any): Wires or any iterable that can be interpreted like a Wires object
to perform the union with. See _process for details on the interpretation.
Returns:
Wires: A new Wires object representing the union of the two Wires objects.
**Example**
>>> from pennylane.wires import Wires
>>> wires1 = Wires([1, 2, 3])
>>> wires2 = Wires([3, 4, 5])
>>> wires1.union(wires2)
Wires([1, 2, 3, 4, 5])
Alternatively, use the | operator:
>>> wires1 | wires2
Wires([1, 2, 3, 4, 5])
"""
return Wires((set(self.labels) | set(_process(other))))

def __or__(self, other):
"""Return the union of the current Wires object and either another Wires object or an
iterable that can be interpreted like a Wires object e.g., List.
Args:
other (Any): Wires or any iterable that can be interpreted like a Wires object
to perform the union with. See _process for details on the interpretation.
Returns:
Wires: A new Wires object representing the union of the two Wires objects.
**Example**
>>> from pennylane.wires import Wires
>>> wires1 = Wires([1, 2, 3])
>>> wires2 = Wires([3, 4, 5])
>>> wires1 | wires2
Wires([1, 2, 3, 4, 5])
"""
return self.union(other)

def __ror__(self, other):
"""Right-hand version of __or__."""
return self.union(other)

def intersection(self, other):
"""Return the intersection of the current Wires object and either another Wires object or
an iterable that can be interpreted like a Wires object e.g., List.
Args:
other (Any): Wires or any iterable that can be interpreted like a Wires object
to perform the union with. See _process for details on the interpretation.
Returns:
Wires: A new Wires object representing the intersection of the two Wires objects.
**Example**
>>> from pennylane.wires import Wires
>>> wires1 = Wires([1, 2, 3])
>>> wires2 = Wires([2, 3, 4])
>>> wires1.intersection(wires2)
Wires([2, 3])
Alternatively, use the & operator:
>>> wires1 & wires2
Wires([2, 3])
"""
return Wires((set(self.labels) & set(_process(other))))

def __and__(self, other):
"""Return the intersection of the current Wires object and either another Wires object or
an iterable that can be interpreted like a Wires object e.g., List.
Args:
other (Any): Wires or any iterable that can be interpreted like a Wires object
to perform the union with. See _process for details on the interpretation.
Returns:
Wires: A new Wires object representing the intersection of the two Wires objects.
**Example**
>>> from pennylane.wires import Wires
>>> wires1 = Wires([1, 2, 3])
>>> wires2 = Wires([2, 3, 4])
>>> wires1 & wires2
Wires([2, 3])
"""
return self.intersection(other)

def __rand__(self, other):
"""Right-hand version of __and__."""
return self.intersection(other)

def difference(self, other):
"""Return the difference of the current Wires object and either another Wires object or
an iterable that can be interpreted like a Wires object e.g., List.
Args:
other (Any): Wires object or any iterable that can be interpreted like a Wires object
to perform the union with. See _process for details on the interpretation.
Returns:
Wires: A new Wires object representing the difference of the two Wires objects.
**Example**
>>> from pennylane.wires import Wires
>>> wires1 = Wires([1, 2, 3])
>>> wires2 = Wires([2, 3, 4])
>>> wires1.difference(wires2)
Wires([1])
Alternatively, use the - operator:
>>> wires1 - wires2
Wires([1])
"""
return Wires((set(self.labels) - set(_process(other))))

def __sub__(self, other):
"""Return the difference of the current Wires object and either another Wires object or
an iterable that can be interpreted like a Wires object e.g., List.
Args:
other (Any): Wires or any iterable that can be interpreted like a Wires object
to perform the union with. See _process for details on the interpretation.
Returns:
Wires: A new Wires object representing the difference of the two Wires objects.
**Example**
>>> from pennylane.wires import Wires
>>> wires1 = Wires([1, 2, 3])
>>> wires2 = Wires([2, 3, 4])
>>> wires1 - wires2
Wires([1])
"""
return self.difference(other)

def __rsub__(self, other):
"""Right-hand version of __sub__."""
return Wires((set(_process(other)) - set(self.labels)))

def symmetric_difference(self, other):
"""Return the symmetric difference of the current Wires object and either another Wires
object or an iterable that can be interpreted like a Wires object e.g., List.
Args:
other (Any): Wires or any iterable that can be interpreted like a Wires object
to perform the union with. See _process for details on the interpretation.
Returns:
Wires: A new Wires object representing the symmetric difference of the two Wires objects.
**Example**
>>> from pennylane.wires import Wires
>>> wires1 = Wires([1, 2, 3])
>>> wires2 = Wires([3, 4, 5])
>>> wires1.symmetric_difference(wires2)
Wires([1, 2, 4, 5])
Alternatively, use the ^ operator:
>>> wires1 ^ wires2
Wires([1, 2, 4, 5])
"""

return Wires((set(self.labels) ^ set(_process(other))))

def __xor__(self, other):
"""Return the symmetric difference of the current Wires object and either another Wires
object or an iterable that can be interpreted like a Wires object e.g., List.
Args:
other (Any): Wires or any iterable that can be interpreted like a Wires object
to perform the union with. See _process for details on the interpretation.
Returns:
Wires: A new Wires object representing the symmetric difference of the two Wires objects.
**Example**
>>> from pennylane.wires import Wires
>>> wires1 = Wires([1, 2, 3])
>>> wires2 = Wires([3, 4, 5])
>>> wires1 ^ wires2
Wires([1, 2, 4, 5])
"""
return self.symmetric_difference(other)

def __rxor__(self, other):
"""Right-hand version of __xor__."""
return Wires((set(_process(other)) ^ set(self.labels)))


# Register Wires as a PyTree-serializable class
register_pytree(Wires, Wires._flatten, Wires._unflatten) # pylint: disable=protected-access
110 changes: 110 additions & 0 deletions tests/test_wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,113 @@ def test_wires_pytree(self, source):
wires2 = tree_unflatten(tree, wires_flat)
assert isinstance(wires2, Wires), f"{wires2} is not Wires"
assert wires == wires2, f"{wires} != {wires2}"

@pytest.mark.parametrize(
"wire_a, wire_b, expected",
[
(Wires([0, 1]), Wires([2, 3]), Wires([0, 1, 2, 3])),
(Wires([0, 1]), [2, 3], Wires([0, 1, 2, 3])),
([], Wires([1, 2, 3]), Wires([1, 2, 3])),
({4, 5}, Wires([1, 2, 3]), Wires([1, 2, 3, 4, 5])),
(Wires([1, 2]), Wires([1, 2, 3]), Wires([1, 2, 3])),
],
)
def test_union(self, wire_a, wire_b, expected):
"""
Test the union operation (|) between two Wires objects.
"""
assert wire_a | wire_b == expected
assert wire_b | wire_a == expected

@pytest.mark.parametrize(
"wire_a, wire_b, expected",
[
(Wires([0, 1, 2]), Wires([2, 3, 4]), Wires([2])),
(Wires([1, 2, 3]), Wires([]), Wires([])),
(Wires([1, 2, 3]), [], Wires([])),
(Wires([1, 2, 3]), "2", Wires([])),
(Wires([1, 2, 3]), {3}, Wires([3])),
(Wires([1, 2, 3]), Wires([1, 2, 3, 4]), Wires([1, 2, 3])),
],
)
def test_intersection(self, wire_a, wire_b, expected):
"""
Test the intersection operation (&) between two Wires objects.
"""
assert wire_a & wire_b == expected
assert wire_b & wire_a == expected

@pytest.mark.parametrize(
"wire_a, wire_b, expected",
[
(Wires([0, 1, 2, 3]), Wires([2, 3, 4]), Wires([0, 1])),
(Wires([1, 2, 3]), Wires([]), Wires([1, 2, 3])),
(Wires([1, 2, 3]), Wires([1, 2, 3, 4]), Wires([])),
(Wires([1, 2, 3]), [], Wires([1, 2, 3])),
([1, 2, 3], Wires([]), Wires([1, 2, 3])),
([], Wires([]), Wires([])),
(Wires([]), [], Wires([])),
],
)
def test_difference(self, wire_a, wire_b, expected):
"""
Test the difference operation (-) between two Wires objects.
"""
assert wire_a - wire_b == expected

@pytest.mark.parametrize(
"wire_a, wire_b, expected",
[
(Wires([0, 1, 2]), Wires([2, 3, 4]), Wires([0, 1, 3, 4])),
([0, 1, 2], Wires([2, 3, 4]), Wires([0, 1, 3, 4])),
(Wires([0, 1, 2]), [2, 3, 4], Wires([0, 1, 3, 4])),
(Wires([]), Wires([1, 2, 3]), Wires([1, 2, 3])),
(Wires([1, 2, 3]), Wires([1, 2, 3]), Wires([])),
],
)
def test_symmetric_difference(self, wire_a, wire_b, expected):
"""
Test the symmetric difference operation (^) between two Wires objects.
"""
assert wire_a ^ wire_b == expected

@pytest.mark.parametrize(
"wire_a, wire_b, wire_c, wire_d, expected",
[
(
Wires([0, 1]),
Wires([2, 3]),
Wires([4, 5]),
Wires([6, 7]),
Wires([0, 1, 2, 3, 4, 5, 6, 7]),
),
(Wires([0, 1]), Wires([1, 2]), Wires([2, 3]), Wires([3, 4]), Wires([0, 1, 2, 3, 4])),
(Wires([]), Wires([1, 2]), Wires([2, 3]), Wires([3, 4, 5]), Wires([1, 2, 3, 4, 5])),
],
)
# pylint: disable=too-many-arguments
def test_multiple_union(self, wire_a, wire_b, wire_c, wire_d, expected):
"""
Test the union operation (|) with multiple Wires objects.
"""
result = wire_a | wire_b | wire_c | wire_d
assert result == expected
assert wire_a.union(wire_b.union(wire_c.union(wire_d))) == expected

def test_complex_operation(self):
"""
Test a complex operation involving multiple set operations.
This test combines union, intersection, difference, and symmetric difference operations.
"""
wire_a = Wires([0, 1, 2, 3])
wire_b = Wires([2, 3, 4, 5])
wire_c = Wires([4, 5, 6, 7])
wire_d = Wires([6, 7, 8, 9])

# ((A ∪ B) ∩ (C ∪ D)) ^ ((A - D) ∪ (C - B))
result = ((wire_a | wire_b) & (wire_c | wire_d)) ^ ((wire_a - wire_d) | (wire_c - wire_b))
assert (wire_a | wire_b) & (wire_c | wire_d) == Wires([4, 5])
assert (wire_a - wire_d) | (wire_c - wire_b) == Wires([0, 1, 2, 3, 6, 7])

expected = Wires([0, 1, 2, 3, 4, 5, 6, 7])
assert result == expected

0 comments on commit 9808d0c

Please sign in to comment.