From 9808d0ca19f104033970792bbbfe019e236679fc Mon Sep 17 00:00:00 2001 From: Austin Huang <65315367+austingmhuang@users.noreply.github.com> Date: Fri, 26 Jul 2024 16:59:15 -0400 Subject: [PATCH] Set operations and tests added (#5983) **Context:** Set operations are useful for Wires and Registers and would be a welcome addition/feature in the PennyLane ecosystem **Description of the Change:** This PR implements the 4 set operations for Wires and you can access them with their respective operators (| & - ^) as well as comprehensive tests. **Benefits:** Nice syntactic sugar **Possible Drawbacks:** **Related GitHub Issues:** [sc-67925] --------- Co-authored-by: Isaac De Vlugt <34751083+isaacdevlugt@users.noreply.github.com> Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com> --- doc/releases/changelog-dev.md | 3 + pennylane/wires.py | 201 ++++++++++++++++++++++++++++++++++ tests/test_wires.py | 110 +++++++++++++++++++ 3 files changed, 314 insertions(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 9c030886b8d..f45e8223db6 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -60,6 +60,9 @@ * `QuantumScript.hash` is now cached, leading to performance improvements. [(#5919)](https://github.com/PennyLaneAI/pennylane/pull/5919) +* Set operations are now supported by Wires. + [(#5983)](https://github.com/PennyLaneAI/pennylane/pull/5983) + * `qml.dynamic_one_shot` now supports circuits using the `"tensorflow"` interface. [(#5973)](https://github.com/PennyLaneAI/pennylane/pull/5973) diff --git a/pennylane/wires.py b/pennylane/wires.py index 3ee494bba3a..c26a4a0d253 100644 --- a/pennylane/wires.py +++ b/pennylane/wires.py @@ -493,6 +493,207 @@ def unique_wires(list_of_wires): return Wires(tuple(unique), _override=True) + def union(self, other): + """Return the union of the current Wires object and either another Wires object or an + iterable that can be interpreted like a Wires object e.g., List. + + Args: + other (Any): Wires or any iterable that can be interpreted like a Wires object + to perform the union with. See _process for details on the interpretation. + + Returns: + Wires: A new Wires object representing the union of the two Wires objects. + + **Example** + + >>> from pennylane.wires import Wires + >>> wires1 = Wires([1, 2, 3]) + >>> wires2 = Wires([3, 4, 5]) + >>> wires1.union(wires2) + Wires([1, 2, 3, 4, 5]) + + Alternatively, use the | operator: + >>> wires1 | wires2 + Wires([1, 2, 3, 4, 5]) + """ + return Wires((set(self.labels) | set(_process(other)))) + + def __or__(self, other): + """Return the union of the current Wires object and either another Wires object or an + iterable that can be interpreted like a Wires object e.g., List. + + Args: + other (Any): Wires or any iterable that can be interpreted like a Wires object + to perform the union with. See _process for details on the interpretation. + + Returns: + Wires: A new Wires object representing the union of the two Wires objects. + + **Example** + + >>> from pennylane.wires import Wires + >>> wires1 = Wires([1, 2, 3]) + >>> wires2 = Wires([3, 4, 5]) + >>> wires1 | wires2 + Wires([1, 2, 3, 4, 5]) + """ + return self.union(other) + + def __ror__(self, other): + """Right-hand version of __or__.""" + return self.union(other) + + def intersection(self, other): + """Return the intersection of the current Wires object and either another Wires object or + an iterable that can be interpreted like a Wires object e.g., List. + + Args: + other (Any): Wires or any iterable that can be interpreted like a Wires object + to perform the union with. See _process for details on the interpretation. + + Returns: + Wires: A new Wires object representing the intersection of the two Wires objects. + + **Example** + + >>> from pennylane.wires import Wires + >>> wires1 = Wires([1, 2, 3]) + >>> wires2 = Wires([2, 3, 4]) + >>> wires1.intersection(wires2) + Wires([2, 3]) + + Alternatively, use the & operator: + >>> wires1 & wires2 + Wires([2, 3]) + """ + return Wires((set(self.labels) & set(_process(other)))) + + def __and__(self, other): + """Return the intersection of the current Wires object and either another Wires object or + an iterable that can be interpreted like a Wires object e.g., List. + + Args: + other (Any): Wires or any iterable that can be interpreted like a Wires object + to perform the union with. See _process for details on the interpretation. + + Returns: + Wires: A new Wires object representing the intersection of the two Wires objects. + + **Example** + + >>> from pennylane.wires import Wires + >>> wires1 = Wires([1, 2, 3]) + >>> wires2 = Wires([2, 3, 4]) + >>> wires1 & wires2 + Wires([2, 3]) + """ + return self.intersection(other) + + def __rand__(self, other): + """Right-hand version of __and__.""" + return self.intersection(other) + + def difference(self, other): + """Return the difference of the current Wires object and either another Wires object or + an iterable that can be interpreted like a Wires object e.g., List. + + Args: + other (Any): Wires object or any iterable that can be interpreted like a Wires object + to perform the union with. See _process for details on the interpretation. + + Returns: + Wires: A new Wires object representing the difference of the two Wires objects. + + **Example** + + >>> from pennylane.wires import Wires + >>> wires1 = Wires([1, 2, 3]) + >>> wires2 = Wires([2, 3, 4]) + >>> wires1.difference(wires2) + Wires([1]) + + Alternatively, use the - operator: + >>> wires1 - wires2 + Wires([1]) + """ + return Wires((set(self.labels) - set(_process(other)))) + + def __sub__(self, other): + """Return the difference of the current Wires object and either another Wires object or + an iterable that can be interpreted like a Wires object e.g., List. + + Args: + other (Any): Wires or any iterable that can be interpreted like a Wires object + to perform the union with. See _process for details on the interpretation. + + Returns: + Wires: A new Wires object representing the difference of the two Wires objects. + + **Example** + + >>> from pennylane.wires import Wires + >>> wires1 = Wires([1, 2, 3]) + >>> wires2 = Wires([2, 3, 4]) + >>> wires1 - wires2 + Wires([1]) + """ + return self.difference(other) + + def __rsub__(self, other): + """Right-hand version of __sub__.""" + return Wires((set(_process(other)) - set(self.labels))) + + def symmetric_difference(self, other): + """Return the symmetric difference of the current Wires object and either another Wires + object or an iterable that can be interpreted like a Wires object e.g., List. + + Args: + other (Any): Wires or any iterable that can be interpreted like a Wires object + to perform the union with. See _process for details on the interpretation. + + Returns: + Wires: A new Wires object representing the symmetric difference of the two Wires objects. + + **Example** + + >>> from pennylane.wires import Wires + >>> wires1 = Wires([1, 2, 3]) + >>> wires2 = Wires([3, 4, 5]) + >>> wires1.symmetric_difference(wires2) + Wires([1, 2, 4, 5]) + + Alternatively, use the ^ operator: + >>> wires1 ^ wires2 + Wires([1, 2, 4, 5]) + """ + + return Wires((set(self.labels) ^ set(_process(other)))) + + def __xor__(self, other): + """Return the symmetric difference of the current Wires object and either another Wires + object or an iterable that can be interpreted like a Wires object e.g., List. + + Args: + other (Any): Wires or any iterable that can be interpreted like a Wires object + to perform the union with. See _process for details on the interpretation. + + Returns: + Wires: A new Wires object representing the symmetric difference of the two Wires objects. + + **Example** + + >>> from pennylane.wires import Wires + >>> wires1 = Wires([1, 2, 3]) + >>> wires2 = Wires([3, 4, 5]) + >>> wires1 ^ wires2 + Wires([1, 2, 4, 5]) + """ + return self.symmetric_difference(other) + + def __rxor__(self, other): + """Right-hand version of __xor__.""" + return Wires((set(_process(other)) ^ set(self.labels))) + # Register Wires as a PyTree-serializable class register_pytree(Wires, Wires._flatten, Wires._unflatten) # pylint: disable=protected-access diff --git a/tests/test_wires.py b/tests/test_wires.py index 4164e48b377..5ceb6475bb9 100644 --- a/tests/test_wires.py +++ b/tests/test_wires.py @@ -377,3 +377,113 @@ def test_wires_pytree(self, source): wires2 = tree_unflatten(tree, wires_flat) assert isinstance(wires2, Wires), f"{wires2} is not Wires" assert wires == wires2, f"{wires} != {wires2}" + + @pytest.mark.parametrize( + "wire_a, wire_b, expected", + [ + (Wires([0, 1]), Wires([2, 3]), Wires([0, 1, 2, 3])), + (Wires([0, 1]), [2, 3], Wires([0, 1, 2, 3])), + ([], Wires([1, 2, 3]), Wires([1, 2, 3])), + ({4, 5}, Wires([1, 2, 3]), Wires([1, 2, 3, 4, 5])), + (Wires([1, 2]), Wires([1, 2, 3]), Wires([1, 2, 3])), + ], + ) + def test_union(self, wire_a, wire_b, expected): + """ + Test the union operation (|) between two Wires objects. + """ + assert wire_a | wire_b == expected + assert wire_b | wire_a == expected + + @pytest.mark.parametrize( + "wire_a, wire_b, expected", + [ + (Wires([0, 1, 2]), Wires([2, 3, 4]), Wires([2])), + (Wires([1, 2, 3]), Wires([]), Wires([])), + (Wires([1, 2, 3]), [], Wires([])), + (Wires([1, 2, 3]), "2", Wires([])), + (Wires([1, 2, 3]), {3}, Wires([3])), + (Wires([1, 2, 3]), Wires([1, 2, 3, 4]), Wires([1, 2, 3])), + ], + ) + def test_intersection(self, wire_a, wire_b, expected): + """ + Test the intersection operation (&) between two Wires objects. + """ + assert wire_a & wire_b == expected + assert wire_b & wire_a == expected + + @pytest.mark.parametrize( + "wire_a, wire_b, expected", + [ + (Wires([0, 1, 2, 3]), Wires([2, 3, 4]), Wires([0, 1])), + (Wires([1, 2, 3]), Wires([]), Wires([1, 2, 3])), + (Wires([1, 2, 3]), Wires([1, 2, 3, 4]), Wires([])), + (Wires([1, 2, 3]), [], Wires([1, 2, 3])), + ([1, 2, 3], Wires([]), Wires([1, 2, 3])), + ([], Wires([]), Wires([])), + (Wires([]), [], Wires([])), + ], + ) + def test_difference(self, wire_a, wire_b, expected): + """ + Test the difference operation (-) between two Wires objects. + """ + assert wire_a - wire_b == expected + + @pytest.mark.parametrize( + "wire_a, wire_b, expected", + [ + (Wires([0, 1, 2]), Wires([2, 3, 4]), Wires([0, 1, 3, 4])), + ([0, 1, 2], Wires([2, 3, 4]), Wires([0, 1, 3, 4])), + (Wires([0, 1, 2]), [2, 3, 4], Wires([0, 1, 3, 4])), + (Wires([]), Wires([1, 2, 3]), Wires([1, 2, 3])), + (Wires([1, 2, 3]), Wires([1, 2, 3]), Wires([])), + ], + ) + def test_symmetric_difference(self, wire_a, wire_b, expected): + """ + Test the symmetric difference operation (^) between two Wires objects. + """ + assert wire_a ^ wire_b == expected + + @pytest.mark.parametrize( + "wire_a, wire_b, wire_c, wire_d, expected", + [ + ( + Wires([0, 1]), + Wires([2, 3]), + Wires([4, 5]), + Wires([6, 7]), + Wires([0, 1, 2, 3, 4, 5, 6, 7]), + ), + (Wires([0, 1]), Wires([1, 2]), Wires([2, 3]), Wires([3, 4]), Wires([0, 1, 2, 3, 4])), + (Wires([]), Wires([1, 2]), Wires([2, 3]), Wires([3, 4, 5]), Wires([1, 2, 3, 4, 5])), + ], + ) + # pylint: disable=too-many-arguments + def test_multiple_union(self, wire_a, wire_b, wire_c, wire_d, expected): + """ + Test the union operation (|) with multiple Wires objects. + """ + result = wire_a | wire_b | wire_c | wire_d + assert result == expected + assert wire_a.union(wire_b.union(wire_c.union(wire_d))) == expected + + def test_complex_operation(self): + """ + Test a complex operation involving multiple set operations. + This test combines union, intersection, difference, and symmetric difference operations. + """ + wire_a = Wires([0, 1, 2, 3]) + wire_b = Wires([2, 3, 4, 5]) + wire_c = Wires([4, 5, 6, 7]) + wire_d = Wires([6, 7, 8, 9]) + + # ((A ∪ B) ∩ (C ∪ D)) ^ ((A - D) ∪ (C - B)) + result = ((wire_a | wire_b) & (wire_c | wire_d)) ^ ((wire_a - wire_d) | (wire_c - wire_b)) + assert (wire_a | wire_b) & (wire_c | wire_d) == Wires([4, 5]) + assert (wire_a - wire_d) | (wire_c - wire_b) == Wires([0, 1, 2, 3, 6, 7]) + + expected = Wires([0, 1, 2, 3, 4, 5, 6, 7]) + assert result == expected