From 253515fb51e048414f736696b4258ade830a5145 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 13 May 2024 17:56:34 -0400 Subject: [PATCH 01/28] add tools for flattening and unflattening pytrees --- doc/releases/changelog-dev.md | 2 + pennylane/pytrees.py | 145 +++++++++++++++++++++++++++++++++- tests/test_pytrees.py | 126 +++++++++++++++++++++++++++++ 3 files changed, 270 insertions(+), 3 deletions(-) create mode 100644 tests/test_pytrees.py diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index ed88f98fcb4..3db0c556599 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -82,6 +82,8 @@ allowing error types to be more consistent with the context the `decompose` function is used in. [(#5669)](https://github.com/PennyLaneAI/pennylane/pull/5669) +* The `qml.pytrees` module now has `flatten` and `unflatten` methods for serializing pytrees. +

Breaking changes 💔

* `qml.is_commuting` no longer accepts the `wire_map` argument, which does not bring any functionality. diff --git a/pennylane/pytrees.py b/pennylane/pytrees.py index ddc82a66e19..42221859be7 100644 --- a/pennylane/pytrees.py +++ b/pennylane/pytrees.py @@ -1,4 +1,4 @@ -# Copyright 2018-2023 Xanadu Quantum Technologies Inc. +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,8 +14,8 @@ """ An internal module for working with pytrees. """ - -from typing import Any, Callable, Tuple +from collections import namedtuple +from typing import Any, Callable, Dict, List, Tuple, Union has_jax = True try: @@ -30,6 +30,57 @@ UnflattenFn = Callable[[Leaves, Metadata], Any] +def flatten_list(obj: list): + """Flatten a list.""" + return obj, None + + +def flatten_tuple(obj: tuple): + """Flatten a tuple.""" + return obj, None + + +def flatten_dict(obj: dict): + """Flatten a dictionary.""" + return obj.values(), tuple(obj.keys()) + + +flatten_registrations: Dict[type, FlattenFn] = { + list: flatten_list, + tuple: flatten_tuple, + dict: flatten_dict, +} + + +def unflatten_list(data, _) -> list: + """Unflatten a list.""" + return data if isinstance(data, list) else list(data) + + +def unflatten_tuple(data, _) -> tuple: + """Unflatten a tuple.""" + return tuple(data) + + +def unflatten_dict(data, metadata) -> dict: + """Unflatten a dictionary.""" + return dict(zip(metadata, data)) + + +unflatten_registrations: Dict[type, UnflattenFn] = { + list: unflatten_list, + tuple: unflatten_tuple, + dict: unflatten_dict, +} + + +def _register_pytree_with_pennylane( + pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn +): + flatten_registrations[pytree_type] = flatten_fn + unflatten_registrations[pytree_type] = unflatten_fn + + def _register_pytree_with_jax(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn): def 
jax_unflatten(aux, parameters): return unflatten_fn(parameters, aux) @@ -54,5 +105,93 @@ def register_pytree(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: Unfl """ + _register_pytree_with_pennylane(pytree_type, flatten_fn, unflatten_fn) + if has_jax: _register_pytree_with_jax(pytree_type, flatten_fn, unflatten_fn) + + +class Structure(namedtuple("Structure", ["type", "metadata", "children"])): + """A pytree data structure, holding the type, metadata, and child pytree structures.""" + + def __repr__(self): + return f"PyTree({self.type.__name__}, {self.metadata}, {self.children})" + + +class Leaf: + """A terminal node in a pytree.""" + + def __repr__(self): + return "Leaf" + + def __eq__(self, other): + return isinstance(other, Leaf) + + def __hash__(self): + return hash(Leaf) + + +leaf = Leaf() + + +def flatten(obj) -> Tuple[List[Any], Union[Structure, Leaf]]: + """Flattens a pytree into leaves and a structure. + + Args: + obj (Any): any object + + Returns: + List[Any], Union[Structure, Leaf]: a list of leaves and a structure representing the object + + >>> op = qml.adjoint(qml.Rot(1.2, 2.3, 3.4, wires=0)) + >>> data, structure = flatten(op) + >>> data + [1.2, 2.3, 3.4] + >>> structure + , ()), (Leaf, Leaf, Leaf))>,))> + + See also :function:`~.unflatten`. + + """ + flatten_fn = flatten_registrations.get(type(obj), None) + if flatten_fn is None: + return [obj], leaf + leaves, metadata = flatten_fn(obj) + + flattened_leaves = [] + child_structures = [] + for l in leaves: + child_leaves, child_structure = flatten(l) + flattened_leaves += child_leaves + child_structures.append(child_structure) + + structure = Structure(type(obj), metadata, child_structures) + return flattened_leaves, structure + + +def unflatten(data: List[Any], structure: Union[Structure, Leaf]) -> Any: + """Bind data to a structure to reconstruct a pytree object. 
+ + Args: + data (Iterable): iterable of numbers and numeric arrays + structure (Structure, Leaf): The pytree structure object + + Returns: + A repacked pytree. + + .. see-also:: :function:`~.flatten` + + >>> op = qml.adjoint(qml.Rot(1.2, 2.3, 3.4, wires=0)) + >>> data, structure = flatten(op) + >>> unflatten([-2, -3, -4], structure) + Adjoint(Rot(-2, -3, -4, wires=[0])) + + """ + return _unflatten(iter(data), structure) + + +def _unflatten(new_data, structure): + if isinstance(structure, Leaf): + return next(new_data) + children = tuple(_unflatten(new_data, s) for s in structure[2]) + return unflatten_registrations[structure[0]](children, structure[1]) diff --git a/tests/test_pytrees.py b/tests/test_pytrees.py new file mode 100644 index 00000000000..fb429c38db8 --- /dev/null +++ b/tests/test_pytrees.py @@ -0,0 +1,126 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Tests for the pennylane pytrees module +""" +import pennylane as qml +from pennylane.pytrees import Structure, flatten, leaf, register_pytree, unflatten + + +def test_register_new_class(): + """Test that new objects can be registered, flattened, and unflattened.""" + + # pylint: disable=too-few-public-methods + class MyObj: + """a dummy object.""" + + def __init__(self, a): + self.a = a + + def obj_flatten(obj): + return (obj.a,), None + + def obj_unflatten(data, _): + return MyObj(data[0]) + + register_pytree(MyObj, obj_flatten, obj_unflatten) + + obj = MyObj(0.5) + + data, structure = flatten(obj) + assert data == [0.5] + assert structure == Structure(MyObj, None, [leaf]) + + new_obj = unflatten([1.0], structure) + assert isinstance(new_obj, MyObj) + assert new_obj.a == 1.0 + + +def test_list(): + """Test that pennylane treats list as a pytree.""" + + x = [1, 2, [3, 4]] + + data, structure = flatten(x) + assert data == [1, 2, 3, 4] + assert structure == Structure(list, None, [leaf, leaf, Structure(list, None, [leaf, leaf])]) + + new_x = unflatten([5, 6, 7, 8], structure) + assert new_x == [5, 6, [7, 8]] + + +def test_tuple(): + """Test that pennylane can handle tuples as pytrees.""" + x = (1, 2, (3, 4)) + + data, structure = flatten(x) + assert data == [1, 2, 3, 4] + assert structure == Structure(tuple, None, [leaf, leaf, Structure(tuple, None, [leaf, leaf])]) + + new_x = unflatten([5, 6, 7, 8], structure) + assert new_x == (5, 6, (7, 8)) + + +def test_dict(): + """Test that pennylane can handle dictionaries as pytrees.""" + + x = {"a": 1, "b": {"c": 2, "d": 3}} + + data, structure = flatten(x) + assert data == [1, 2, 3] + assert structure == Structure( + dict, ("a", "b"), [leaf, Structure(dict, ("c", "d"), [leaf, leaf])] + ) + new_x = unflatten([5, 6, 7], structure) + assert new_x == {"a": 5, "b": {"c": 6, "d": 7}} + + +def test_nested_pl_object(): + """Test that we can flatten and unflatten nested pennylane object.""" + + tape = qml.tape.QuantumScript( + 
[qml.adjoint(qml.RX(0.1, wires=0))], + [qml.expval(2 * qml.X(0))], + shots=50, + trainable_params=(0, 1), + ) + + data, structure = flatten(tape) + assert data == [0.1, 2, None] + + wires0 = qml.wires.Wires(0) + op_structure = Structure(tape[0].__class__, (), [Structure(qml.RX, (wires0, ()), [leaf])]) + list_op_struct = Structure(list, None, [op_structure]) + + sprod_structure = Structure(qml.ops.SProd, (), [leaf, Structure(qml.X, (wires0, ()), [])]) + meas_structure = Structure( + qml.measurements.ExpectationMP, (("wires", None),), [sprod_structure, leaf] + ) + list_meas_struct = Structure(list, None, [meas_structure]) + tape_structure = Structure( + qml.tape.QuantumScript, + (tape.shots, tape.trainable_params), + [list_op_struct, list_meas_struct], + ) + + assert structure == tape_structure + + new_tape = unflatten([3, 4, None], structure) + expected_new_tape = qml.tape.QuantumScript( + [qml.adjoint(qml.RX(3, wires=0))], + [qml.expval(4 * qml.X(0))], + shots=50, + trainable_params=(0, 1), + ) + assert qml.equal(new_tape, expected_new_tape) From 1a7fa1947a5fbea72a010a221dff4c8107652cb0 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 16 May 2024 16:46:54 -0400 Subject: [PATCH 02/28] adding coverage --- tests/test_pytrees.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/test_pytrees.py b/tests/test_pytrees.py index fb429c38db8..56f7c071179 100644 --- a/tests/test_pytrees.py +++ b/tests/test_pytrees.py @@ -15,7 +15,25 @@ Tests for the pennylane pytrees module """ import pennylane as qml -from pennylane.pytrees import Structure, flatten, leaf, register_pytree, unflatten +from pennylane.pytrees import Leaf, Structure, flatten, leaf, register_pytree, unflatten + + +def test_structure_repr(): + """Test the repr of the structure class.""" + op = qml.RX(0.1, wires=0) + _, structure = qml.pytrees.flatten(op) + expected = "PyTree(RX, (, ()), [Leaf])" + assert repr(structure) == expected + + +def test_leaf_class(): + """Test 
the dunder methods of the leaf class.""" + + assert repr(leaf) == "Leaf" + assert Leaf() == Leaf() + assert hash(leaf) == hash(Leaf) + + assert set((Leaf(), Leaf())) == set((Leaf(),)) def test_register_new_class(): From 985805533de5418632134381be7055466b278bec Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Tue, 21 May 2024 08:52:02 -0400 Subject: [PATCH 03/28] Apply suggestions from code review Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com> Co-authored-by: Jack Brown Co-authored-by: Mudit Pandey --- doc/releases/changelog-dev.md | 1 + pennylane/pytrees.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 44c15056c0e..03ded617374 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -83,6 +83,7 @@ [(#5669)](https://github.com/PennyLaneAI/pennylane/pull/5669) * The `qml.pytrees` module now has `flatten` and `unflatten` methods for serializing pytrees. + [(#5701)](https://github.com/PennyLaneAI/pennylane/pull/5701)

Community contributions 🥳

diff --git a/pennylane/pytrees.py b/pennylane/pytrees.py index 42221859be7..ca1755fae6d 100644 --- a/pennylane/pytrees.py +++ b/pennylane/pytrees.py @@ -111,9 +111,14 @@ def register_pytree(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: Unfl _register_pytree_with_jax(pytree_type, flatten_fn, unflatten_fn) -class Structure(namedtuple("Structure", ["type", "metadata", "children"])): +@dataclass(repr=False) +class Structure: """A pytree data structure, holding the type, metadata, and child pytree structures.""" - + + type: type + metadata: Metadata + children: list[Leaf, 'Structure'] + def __repr__(self): return f"PyTree({self.type.__name__}, {self.metadata}, {self.children})" @@ -148,7 +153,7 @@ def flatten(obj) -> Tuple[List[Any], Union[Structure, Leaf]]: >>> data [1.2, 2.3, 3.4] >>> structure - , ()), (Leaf, Leaf, Leaf))>,))> + , ()), (Leaf, Leaf, Leaf))>,))> See also :function:`~.unflatten`. From a618d21f02bef6dcfa2a2246f558e4f9207c0027 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 21 May 2024 14:36:49 -0400 Subject: [PATCH 04/28] responding to feedback, leaf is PyTreeStructure with no type --- pennylane/pytrees.py | 65 +++++++++++++++++++++++-------------------- tests/test_pytrees.py | 42 +++++++++++++--------------- 2 files changed, 55 insertions(+), 52 deletions(-) diff --git a/pennylane/pytrees.py b/pennylane/pytrees.py index ca1755fae6d..c5ce6ae8e28 100644 --- a/pennylane/pytrees.py +++ b/pennylane/pytrees.py @@ -14,8 +14,8 @@ """ An internal module for working with pytrees. 
""" -from collections import namedtuple -from typing import Any, Callable, Dict, List, Tuple, Union +from dataclasses import dataclass +from typing import Any, Callable, List, Optional, Tuple has_jax = True try: @@ -45,7 +45,7 @@ def flatten_dict(obj: dict): return obj.values(), tuple(obj.keys()) -flatten_registrations: Dict[type, FlattenFn] = { +flatten_registrations: dict[type, FlattenFn] = { list: flatten_list, tuple: flatten_tuple, dict: flatten_dict, @@ -67,7 +67,7 @@ def unflatten_dict(data, metadata) -> dict: return dict(zip(metadata, data)) -unflatten_registrations: Dict[type, UnflattenFn] = { +unflatten_registrations: dict[type, UnflattenFn] = { list: unflatten_list, tuple: unflatten_tuple, dict: unflatten_dict, @@ -91,7 +91,8 @@ def jax_unflatten(aux, parameters): def register_pytree(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn): """Register a type with all available pytree backends. - Current backends is jax. + Current backends are jax and pennylane. + Args: pytree_type (type): the type to register, such as ``qml.RX`` flatten_fn (Callable): a function that splits an object into trainable leaves and hashable metadata. @@ -103,6 +104,8 @@ def register_pytree(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: Unfl Side Effects: ``pytree`` type becomes registered with available backends. + .. seealso:: :func:`~.flatten`, :func:`~.unflatten`. 
+ """ _register_pytree_with_pennylane(pytree_type, flatten_fn, unflatten_fn) @@ -111,35 +114,37 @@ def register_pytree(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: Unfl _register_pytree_with_jax(pytree_type, flatten_fn, unflatten_fn) -@dataclass(repr=False) -class Structure: - """A pytree data structure, holding the type, metadata, and child pytree structures.""" - - type: type - metadata: Metadata - children: list[Leaf, 'Structure'] - - def __repr__(self): - return f"PyTree({self.type.__name__}, {self.metadata}, {self.children})" +@dataclass(repr=False, frozen=True) +class PyTreeStructure: + """A pytree data structure, holding the type, metadata, and child pytree structures. + + >>> op = qml.adjoint(qml.RX(0.1, 0)) + >>> data, structure = qml.pytrees.flatten(op) + >>> structure + PyTree(AdjointOperation, (), [PyTree(RX, (, ()), [Leaf])]) + A leaf is defined as just a ``PyTreeStructure`` with ``type=None``. + """ -class Leaf: - """A terminal node in a pytree.""" + type: Optional[type] + """The type corresponding to the node. If ``None``, then the structure is a leaf.""" - def __repr__(self): - return "Leaf" + metadata: Metadata + """Any metadata needed to reproduce the original object.""" - def __eq__(self, other): - return isinstance(other, Leaf) + children: list["PyTreeStructure"] + """The children of the pytree node. Can be either other structures or terminal leaves.""" - def __hash__(self): - return hash(Leaf) + def __repr__(self): + if self.type is None: + return "Leaf" + return f"PyTree({self.type.__name__}, {self.metadata}, {self.children})" -leaf = Leaf() +leaf = PyTreeStructure(None, (), []) -def flatten(obj) -> Tuple[List[Any], Union[Structure, Leaf]]: +def flatten(obj) -> Tuple[List[Any], PyTreeStructure]: """Flattens a pytree into leaves and a structure. 
Args: @@ -170,11 +175,11 @@ def flatten(obj) -> Tuple[List[Any], Union[Structure, Leaf]]: flattened_leaves += child_leaves child_structures.append(child_structure) - structure = Structure(type(obj), metadata, child_structures) + structure = PyTreeStructure(type(obj), metadata, child_structures) return flattened_leaves, structure -def unflatten(data: List[Any], structure: Union[Structure, Leaf]) -> Any: +def unflatten(data: List[Any], structure: PyTreeStructure) -> Any: """Bind data to a structure to reconstruct a pytree object. Args: @@ -196,7 +201,7 @@ def unflatten(data: List[Any], structure: Union[Structure, Leaf]) -> Any: def _unflatten(new_data, structure): - if isinstance(structure, Leaf): + if structure.type is None: # is leaf return next(new_data) - children = tuple(_unflatten(new_data, s) for s in structure[2]) - return unflatten_registrations[structure[0]](children, structure[1]) + children = tuple(_unflatten(new_data, s) for s in structure.children) + return unflatten_registrations[structure.type](children, structure.metadata) diff --git a/tests/test_pytrees.py b/tests/test_pytrees.py index 56f7c071179..17016c2ad07 100644 --- a/tests/test_pytrees.py +++ b/tests/test_pytrees.py @@ -15,7 +15,7 @@ Tests for the pennylane pytrees module """ import pennylane as qml -from pennylane.pytrees import Leaf, Structure, flatten, leaf, register_pytree, unflatten +from pennylane.pytrees import PyTreeStructure, flatten, leaf, register_pytree, unflatten def test_structure_repr(): @@ -26,16 +26,6 @@ def test_structure_repr(): assert repr(structure) == expected -def test_leaf_class(): - """Test the dunder methods of the leaf class.""" - - assert repr(leaf) == "Leaf" - assert Leaf() == Leaf() - assert hash(leaf) == hash(Leaf) - - assert set((Leaf(), Leaf())) == set((Leaf(),)) - - def test_register_new_class(): """Test that new objects can be registered, flattened, and unflattened.""" @@ -58,7 +48,7 @@ def obj_unflatten(data, _): data, structure = flatten(obj) assert data == 
[0.5] - assert structure == Structure(MyObj, None, [leaf]) + assert structure == PyTreeStructure(MyObj, None, [leaf]) new_obj = unflatten([1.0], structure) assert isinstance(new_obj, MyObj) @@ -72,7 +62,9 @@ def test_list(): data, structure = flatten(x) assert data == [1, 2, 3, 4] - assert structure == Structure(list, None, [leaf, leaf, Structure(list, None, [leaf, leaf])]) + assert structure == PyTreeStructure( + list, None, [leaf, leaf, PyTreeStructure(list, None, [leaf, leaf])] + ) new_x = unflatten([5, 6, 7, 8], structure) assert new_x == [5, 6, [7, 8]] @@ -84,7 +76,9 @@ def test_tuple(): data, structure = flatten(x) assert data == [1, 2, 3, 4] - assert structure == Structure(tuple, None, [leaf, leaf, Structure(tuple, None, [leaf, leaf])]) + assert structure == PyTreeStructure( + tuple, None, [leaf, leaf, PyTreeStructure(tuple, None, [leaf, leaf])] + ) new_x = unflatten([5, 6, 7, 8], structure) assert new_x == (5, 6, (7, 8)) @@ -97,8 +91,8 @@ def test_dict(): data, structure = flatten(x) assert data == [1, 2, 3] - assert structure == Structure( - dict, ("a", "b"), [leaf, Structure(dict, ("c", "d"), [leaf, leaf])] + assert structure == PyTreeStructure( + dict, ("a", "b"), [leaf, PyTreeStructure(dict, ("c", "d"), [leaf, leaf])] ) new_x = unflatten([5, 6, 7], structure) assert new_x == {"a": 5, "b": {"c": 6, "d": 7}} @@ -118,15 +112,19 @@ def test_nested_pl_object(): assert data == [0.1, 2, None] wires0 = qml.wires.Wires(0) - op_structure = Structure(tape[0].__class__, (), [Structure(qml.RX, (wires0, ()), [leaf])]) - list_op_struct = Structure(list, None, [op_structure]) + op_structure = PyTreeStructure( + tape[0].__class__, (), [PyTreeStructure(qml.RX, (wires0, ()), [leaf])] + ) + list_op_struct = PyTreeStructure(list, None, [op_structure]) - sprod_structure = Structure(qml.ops.SProd, (), [leaf, Structure(qml.X, (wires0, ()), [])]) - meas_structure = Structure( + sprod_structure = PyTreeStructure( + qml.ops.SProd, (), [leaf, PyTreeStructure(qml.X, (wires0, ()), 
[])] + ) + meas_structure = PyTreeStructure( qml.measurements.ExpectationMP, (("wires", None),), [sprod_structure, leaf] ) - list_meas_struct = Structure(list, None, [meas_structure]) - tape_structure = Structure( + list_meas_struct = PyTreeStructure(list, None, [meas_structure]) + tape_structure = PyTreeStructure( qml.tape.QuantumScript, (tape.shots, tape.trainable_params), [list_op_struct, list_meas_struct], From c7ca2c17b4003ea77059aec23b4802f305ed0bba Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Wed, 22 May 2024 14:08:17 -0400 Subject: [PATCH 05/28] pytree module --- pennylane/pytrees/__init__.py | 1 + pennylane/{ => pytrees}/pytrees.py | 0 pennylane/pytrees/serialization.py | 78 ++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+) create mode 100644 pennylane/pytrees/__init__.py rename pennylane/{ => pytrees}/pytrees.py (100%) create mode 100644 pennylane/pytrees/serialization.py diff --git a/pennylane/pytrees/__init__.py b/pennylane/pytrees/__init__.py new file mode 100644 index 00000000000..b98b371cb8e --- /dev/null +++ b/pennylane/pytrees/__init__.py @@ -0,0 +1 @@ +from .pytrees import is_pytree, leaf, PyTreeStructure, flatten, unflatten, register_pytree diff --git a/pennylane/pytrees.py b/pennylane/pytrees/pytrees.py similarity index 100% rename from pennylane/pytrees.py rename to pennylane/pytrees/pytrees.py diff --git a/pennylane/pytrees/serialization.py b/pennylane/pytrees/serialization.py new file mode 100644 index 00000000000..b27b82eaea1 --- /dev/null +++ b/pennylane/pytrees/serialization.py @@ -0,0 +1,78 @@ +import json +from typing import Any, Union, overload, Literal +from .pytrees import get_typename, Leaf, Structure, typenames_to_type +from pennylane.wires import Wires + + +@overload +def pytree_def_to_json(struct: Structure, *, pretty: bool = False, encode: Literal[True]) -> bytes: + ... + +@overload +def pytree_def_to_json(struct: Structure, *, pretty: bool = False, encode: Literal[False] = False) -> str: + ... 
+ +def pytree_def_to_json(struct: Structure, *, pretty: bool = False, encode: bool = False) -> Union[bytes, str]: + jsoned = _jsonify_struct(struct) + + if pretty: + data = json.dumps(jsoned, indent=2, default=_json_default) + else: + data = json.dumps(jsoned, separators=(",", ":"), default=_json_default) + + if encode: + return data.encode('utf-8') + + return data + +def _jsonify_struct(root: Structure) -> list[Any]: + jsoned: list[Any] = [get_typename(root.type), root.metadata, list(root.children)] + + todo: list[list[Union[Structure, Leaf]]] = [jsoned[2]] + + while todo: + curr = todo.pop() + + for i in range(len(curr)): + child = curr[i] + if isinstance(child, Leaf): + curr[i] = "Leaf" + continue + + child_list = list(child.children) + curr[i] = [get_typename(child.type), child.metadata, child_list] + todo.append(child_list) + + return jsoned + +def pytree_def_from_json(data: str | bytes | bytearray) -> Structure: + jsoned = json.loads(data) + + root = Structure( + typenames_to_type[jsoned[0]], jsoned[1], jsoned[2] + ) + + todo: list[list[Any]] = [root.children] + + while todo: + curr = todo.pop() + + for i in range(len(curr)): + child = curr[i] + if child == "Leaf": + curr[i] = Leaf() + continue + + curr[i] = Structure( + typenames_to_type[child[0]], child[1], child[2] + ) + todo.append(child[2]) + + return root + + +def _json_default(o: Any): + if isinstance(o, Wires): + return o.tolist() + + raise TypeError From 45141b3834a15cd95e739aae5d5ba6d0396f86b8 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Thu, 23 May 2024 08:31:42 -0400 Subject: [PATCH 06/28] Update pennylane/pytrees.py Co-authored-by: Jack Brown --- pennylane/pytrees.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/pytrees.py b/pennylane/pytrees.py index c5ce6ae8e28..50033ea2f5a 100644 --- a/pennylane/pytrees.py +++ b/pennylane/pytrees.py @@ -114,7 +114,7 @@ def register_pytree(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: Unfl 
_register_pytree_with_jax(pytree_type, flatten_fn, unflatten_fn) -@dataclass(repr=False, frozen=True) +@dataclass(repr=False) class PyTreeStructure: """A pytree data structure, holding the type, metadata, and child pytree structures. From 8086b3f1a9b94d17c7da65c394bfc82a1fbdb1aa Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Thu, 23 May 2024 08:32:00 -0400 Subject: [PATCH 07/28] Update pennylane/pytrees.py Co-authored-by: Jack Brown --- pennylane/pytrees.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pennylane/pytrees.py b/pennylane/pytrees.py index 50033ea2f5a..59faa37a0bb 100644 --- a/pennylane/pytrees.py +++ b/pennylane/pytrees.py @@ -126,13 +126,13 @@ class PyTreeStructure: A leaf is defined as just a ``PyTreeStructure`` with ``type=None``. """ - type: Optional[type] + type: Optional[type] = None """The type corresponding to the node. If ``None``, then the structure is a leaf.""" - metadata: Metadata + metadata: Metadata = () """Any metadata needed to reproduce the original object.""" - children: list["PyTreeStructure"] + children: list["PyTreeStructure"] = field(default_factory=list) """The children of the pytree node. Can be either other structures or terminal leaves.""" def __repr__(self): From 521e7fd6db746ed4564bcaac264cb33ed02a8985 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Thu, 23 May 2024 08:32:37 -0400 Subject: [PATCH 08/28] Apply suggestions from code review Co-authored-by: Jack Brown --- pennylane/pytrees.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pennylane/pytrees.py b/pennylane/pytrees.py index 59faa37a0bb..4f9cab9401c 100644 --- a/pennylane/pytrees.py +++ b/pennylane/pytrees.py @@ -134,7 +134,9 @@ class PyTreeStructure: children: list["PyTreeStructure"] = field(default_factory=list) """The children of the pytree node. 
Can be either other structures or terminal leaves.""" - + @property + def is_leaf(self) -> bool: + return self.type is None def __repr__(self): if self.type is None: return "Leaf" @@ -144,7 +146,7 @@ def __repr__(self): leaf = PyTreeStructure(None, (), []) -def flatten(obj) -> Tuple[List[Any], PyTreeStructure]: +def flatten(obj) -> tuple[list[Any], PyTreeStructure]: """Flattens a pytree into leaves and a structure. Args: From 450de944db92ffee1347cd080e9981efc909a03a Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 23 May 2024 10:19:24 -0400 Subject: [PATCH 09/28] change repr, add str --- pennylane/pytrees.py | 17 +++++++++++++---- tests/test_pytrees.py | 6 ++++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/pennylane/pytrees.py b/pennylane/pytrees.py index 4f9cab9401c..1952d2d0d91 100644 --- a/pennylane/pytrees.py +++ b/pennylane/pytrees.py @@ -14,7 +14,7 @@ """ An internal module for working with pytrees. """ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Callable, List, Optional, Tuple has_jax = True @@ -134,13 +134,22 @@ class PyTreeStructure: children: list["PyTreeStructure"] = field(default_factory=list) """The children of the pytree node. 
Can be either other structures or terminal leaves.""" + @property def is_leaf(self) -> bool: + """Whether or not the structure is a leaf.""" return self.type is None + def __repr__(self): - if self.type is None: + if self.is_leaf: + return "PyTreeStructure()" + return f"PyTreeStructure({self.type.__name__}, {self.metadata}, {self.children})" + + def __str__(self): + if self.is_leaf: return "Leaf" - return f"PyTree({self.type.__name__}, {self.metadata}, {self.children})" + children_string = ", ".join(str(c) for c in self.children) + return f"PyTree({self.type.__name__}, {self.metadata}, [{children_string}])" leaf = PyTreeStructure(None, (), []) @@ -203,7 +212,7 @@ def unflatten(data: List[Any], structure: PyTreeStructure) -> Any: def _unflatten(new_data, structure): - if structure.type is None: # is leaf + if structure.is_leaf: return next(new_data) children = tuple(_unflatten(new_data, s) for s in structure.children) return unflatten_registrations[structure.type](children, structure.metadata) diff --git a/tests/test_pytrees.py b/tests/test_pytrees.py index 17016c2ad07..c2829c4385f 100644 --- a/tests/test_pytrees.py +++ b/tests/test_pytrees.py @@ -18,12 +18,14 @@ from pennylane.pytrees import PyTreeStructure, flatten, leaf, register_pytree, unflatten -def test_structure_repr(): +def test_structure_repr_str(): """Test the repr of the structure class.""" op = qml.RX(0.1, wires=0) _, structure = qml.pytrees.flatten(op) - expected = "PyTree(RX, (, ()), [Leaf])" + expected = "PyTreeStructure(RX, (, ()), [PyTreeStructure()])" assert repr(structure) == expected + expected_str = "PyTree(RX, (, ()), [Leaf])" + assert str(structure) == expected_str def test_register_new_class(): From 73ec5ed5081917bbebe27f51088df352ad281024 Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Thu, 23 May 2024 10:45:56 -0400 Subject: [PATCH 10/28] add serialization --- pennylane/pytrees/__init__.py | 18 ++++- pennylane/pytrees/pytrees.py | 72 +++++++++++++++++--- pennylane/pytrees/serialization.py 
| 101 +++++++++++++++++++--------- pennylane/typing.py | 2 + tests/{ => pytrees}/test_pytrees.py | 0 tests/pytrees/test_serialization.py | 80 ++++++++++++++++++++++ 6 files changed, 234 insertions(+), 39 deletions(-) rename tests/{ => pytrees}/test_pytrees.py (100%) create mode 100644 tests/pytrees/test_serialization.py diff --git a/pennylane/pytrees/__init__.py b/pennylane/pytrees/__init__.py index b98b371cb8e..184b12752b8 100644 --- a/pennylane/pytrees/__init__.py +++ b/pennylane/pytrees/__init__.py @@ -1 +1,17 @@ -from .pytrees import is_pytree, leaf, PyTreeStructure, flatten, unflatten, register_pytree +from .pytrees import ( + PyTreeStructure, + flatten, + is_pytree, + leaf, + register_pytree, + unflatten, +) + +__all__ = [ + "PyTreeStructure", + "flatten", + "is_pytree", + "leaf", + "register_pytree", + "unflatten" +] diff --git a/pennylane/pytrees/pytrees.py b/pennylane/pytrees/pytrees.py index c5ce6ae8e28..ad00bdc9bcc 100644 --- a/pennylane/pytrees/pytrees.py +++ b/pennylane/pytrees/pytrees.py @@ -15,7 +15,7 @@ An internal module for working with pytrees. 
""" from dataclasses import dataclass -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Optional has_jax = True try: @@ -26,7 +26,7 @@ Leaves = Any Metadata = Any -FlattenFn = Callable[[Any], Tuple[Leaves, Metadata]] +FlattenFn = Callable[[Any], tuple[Leaves, Metadata]] UnflattenFn = Callable[[Leaves, Metadata], Any] @@ -73,10 +73,21 @@ def unflatten_dict(data, metadata) -> dict: dict: unflatten_dict, } +type_to_typenames: dict[type, str] = { + list: "builtins.list", + dict: "builtins.dict", + tuple: "builtins.tuple", +} + +typename_to_type: dict[str, type] = {name: type_ for type_, name in type_to_typenames.items()} + def _register_pytree_with_pennylane( - pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn + pytree_type: type, typename: str, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn ): + type_to_typenames[pytree_type] = typename + typename_to_type[typename] = pytree_type + flatten_registrations[pytree_type] = flatten_fn unflatten_registrations[pytree_type] = unflatten_fn @@ -88,7 +99,13 @@ def jax_unflatten(aux, parameters): jax_tree_util.register_pytree_node(pytree_type, flatten_fn, jax_unflatten) -def register_pytree(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn): +def register_pytree( + pytree_type: type, + flatten_fn: FlattenFn, + unflatten_fn: UnflattenFn, + *, + typename_prefix: str = "qml", +): """Register a type with all available pytree backends. Current backends are jax and pennylane. @@ -97,6 +114,7 @@ def register_pytree(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: Unfl pytree_type (type): the type to register, such as ``qml.RX`` flatten_fn (Callable): a function that splits an object into trainable leaves and hashable metadata. unflatten_fn (Callable): a function that reconstructs an object from its leaves and metadata. + typename_prefix (str): A prefix for the name under which this type will be registered. 
Returns: None @@ -108,12 +126,45 @@ def register_pytree(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: Unfl """ - _register_pytree_with_pennylane(pytree_type, flatten_fn, unflatten_fn) + typename = f"{typename_prefix}.{pytree_type.__qualname__}" + _register_pytree_with_pennylane(pytree_type, typename, flatten_fn, unflatten_fn) if has_jax: _register_pytree_with_jax(pytree_type, flatten_fn, unflatten_fn) +def is_pytree(type_: type[Any]) -> bool: + """Returns True if ``type_`` is a registered Pytree.""" + return type_ in type_to_typenames + + +def get_typename(pytree_type: type[Any]) -> str: + """Return the typename under which ``pytree_type`` + was registered. + + Raises: + TypeError: If ``pytree_type`` is not a Pytree""" + + try: + return type_to_typenames[pytree_type] + except KeyError as exc: + raise TypeError(f"{repr(pytree_type)} is not a Pytree type") from exc + + +def get_typename_type(typename: str) -> type[Any]: + """Return the Pytree type with given ``typename``. + + Raises: + ValueError: If ``typename`` is not the name of a + pytree type. + """ + + try: + return typename_to_type[typename] + except KeyError as exc: + raise ValueError(f"{repr(typename)} is not the name of a Pytree type.") from exc + + @dataclass(repr=False, frozen=True) class PyTreeStructure: """A pytree data structure, holding the type, metadata, and child pytree structures. @@ -126,7 +177,7 @@ class PyTreeStructure: A leaf is defined as just a ``PyTreeStructure`` with ``type=None``. """ - type: Optional[type] + type: Optional[type[Any]] """The type corresponding to the node. If ``None``, then the structure is a leaf.""" metadata: Metadata @@ -135,6 +186,11 @@ class PyTreeStructure: children: list["PyTreeStructure"] """The children of the pytree node. 
Can be either other structures or terminal leaves.""" + @property + def is_leaf(self) -> bool: + """Whether or not this represents a leaf.""" + return self.type is None + def __repr__(self): if self.type is None: return "Leaf" @@ -144,7 +200,7 @@ def __repr__(self): leaf = PyTreeStructure(None, (), []) -def flatten(obj) -> Tuple[List[Any], PyTreeStructure]: +def flatten(obj: Any) -> tuple[list[Any], PyTreeStructure]: """Flattens a pytree into leaves and a structure. Args: @@ -179,7 +235,7 @@ def flatten(obj) -> Tuple[List[Any], PyTreeStructure]: return flattened_leaves, structure -def unflatten(data: List[Any], structure: PyTreeStructure) -> Any: +def unflatten(data: list[Any], structure: PyTreeStructure) -> Any: """Bind data to a structure to reconstruct a pytree object. Args: diff --git a/pennylane/pytrees/serialization.py b/pennylane/pytrees/serialization.py index b27b82eaea1..76c38936ce6 100644 --- a/pennylane/pytrees/serialization.py +++ b/pennylane/pytrees/serialization.py @@ -1,56 +1,92 @@ import json -from typing import Any, Union, overload, Literal -from .pytrees import get_typename, Leaf, Structure, typenames_to_type +from collections.abc import Callable +from typing import Any, Literal, Optional, Union, overload + +from pennylane.typing import JSON from pennylane.wires import Wires +from .pytrees import PyTreeStructure, get_typename, get_typename_type, leaf -@overload -def pytree_def_to_json(struct: Structure, *, pretty: bool = False, encode: Literal[True]) -> bytes: - ... @overload -def pytree_def_to_json(struct: Structure, *, pretty: bool = False, encode: Literal[False] = False) -> str: - ... +def pytree_structure_dump_json( + root: PyTreeStructure, *, indent: Optional[int] = None, encode: Literal[True] +) -> bytes: ... 
-def pytree_def_to_json(struct: Structure, *, pretty: bool = False, encode: bool = False) -> Union[bytes, str]: - jsoned = _jsonify_struct(struct) - if pretty: - data = json.dumps(jsoned, indent=2, default=_json_default) +@overload +def pytree_structure_dump_json( + root: PyTreeStructure, *, indent: Optional[int] = None, encode: Literal[False] = False +) -> str: ... + + +def pytree_structure_dump_json( + root: PyTreeStructure, + *, + indent: Optional[int] = None, + encode: bool = False, + json_default: Optional[Callable[[Any], JSON]] = None, +) -> Union[bytes, str]: + """Convert Pytree structure ``root`` into JSON. + + Args: + root: Root of a Pytree structure + indent: If not None, the resulting JSON will be pretty-printed with the + given indent level. Otherwise, the output will use the most compact + possible representation + encode: Whether to return the output as bytes + + Returns: + bytes: If ``encode`` is True + str: If ``encode`` is False + + """ + jsoned = pytree_structure_dump(root) + + if indent: + data = json.dumps(jsoned, indent=indent, default=_json_default) else: data = json.dumps(jsoned, separators=(",", ":"), default=_json_default) if encode: - return data.encode('utf-8') - + return data.encode("utf-8") + return data -def _jsonify_struct(root: Structure) -> list[Any]: + +def pytree_structure_dump(root: PyTreeStructure) -> list[JSON]: + """Convert Pytree structure at ``root`` into a JSON-able representation.""" + if root.is_leaf: + raise ValueError("Cannot dump Pytree: root node may not be a leaf") + jsoned: list[Any] = [get_typename(root.type), root.metadata, list(root.children)] - todo: list[list[Union[Structure, Leaf]]] = [jsoned[2]] + todo: list[list[Union[PyTreeStructure, None]]] = [jsoned[2]] while todo: curr = todo.pop() for i in range(len(curr)): child = curr[i] - if isinstance(child, Leaf): - curr[i] = "Leaf" + if child.is_leaf: + curr[i] = None continue - + child_list = list(child.children) curr[i] = [get_typename(child.type), 
child.metadata, child_list] todo.append(child_list) return jsoned -def pytree_def_from_json(data: str | bytes | bytearray) -> Structure: - jsoned = json.loads(data) - root = Structure( - typenames_to_type[jsoned[0]], jsoned[1], jsoned[2] - ) +def pytree_structure_load(data: str | bytes | bytearray | list[JSON]) -> PyTreeStructure: + """Load a previously serialized Pytree structure.""" + if isinstance(data, (str, bytes, bytearray)): + jsoned = json.loads(data) + else: + jsoned = data + + root = PyTreeStructure(get_typename_type(jsoned[0]), jsoned[1], jsoned[2]) todo: list[list[Any]] = [root.children] @@ -59,20 +95,25 @@ def pytree_def_from_json(data: str | bytes | bytearray) -> Structure: for i in range(len(curr)): child = curr[i] - if child == "Leaf": - curr[i] = Leaf() + if child is None: + curr[i] = leaf continue - curr[i] = Structure( - typenames_to_type[child[0]], child[1], child[2] - ) + curr[i] = PyTreeStructure(get_typename_type(child[0]), child[1], child[2]) + todo.append(child[2]) return root -def _json_default(o: Any): +def _json_default(default: Callable[[Any], JSON]): + def default(o: Any): + if isinstance(o, Wires): + return o.tolist() + + return + if isinstance(o, Wires): return o.tolist() - + raise TypeError diff --git a/pennylane/typing.py b/pennylane/typing.py index e84ea280cea..0fb33a52e68 100644 --- a/pennylane/typing.py +++ b/pennylane/typing.py @@ -122,3 +122,5 @@ def _is_torch(other, subclass=False): Result = TypeVar("Result", Tuple, TensorLike) ResultBatch = Tuple[Result] + +JSON = Union[None, int, str, bool, list["JSON"], dict[str, "JSON"]] diff --git a/tests/test_pytrees.py b/tests/pytrees/test_pytrees.py similarity index 100% rename from tests/test_pytrees.py rename to tests/pytrees/test_pytrees.py diff --git a/tests/pytrees/test_serialization.py b/tests/pytrees/test_serialization.py new file mode 100644 index 00000000000..038c115612f --- /dev/null +++ b/tests/pytrees/test_serialization.py @@ -0,0 +1,80 @@ +import json + +import pytest + 
+from pennylane.pytrees import PyTreeStructure, flatten, leaf, register_pytree +from pennylane.pytrees.serialization import pytree_structure_dump, pytree_structure_load +from pennylane.wires import Wires + + +class CustomNode: + + def __init__(self, data, metadata): + self.data = data + self.metadata = metadata + + +def flatten_custom(node): + return (node.data, node.metadata) + + +def unflatten_custom(data, metadata): + return CustomNode(data, metadata) + + +register_pytree(CustomNode, flatten_custom, unflatten_custom, typename_prefix="test") + + +def test_structure_dump(): + _, struct = flatten( + { + "list": ["a", 1], + "dict": {"a": 1}, + "tuple": ("a", 1), + "custom": CustomNode([1, 5, 7], {"wires": Wires([1, "a", 3.4, None])}), + } + ) + + assert pytree_structure_dump(struct) == [ + "builtins.dict", + ("list", "dict", "tuple", "custom"), + [ + ["builtins.list", None, [None, None]], + ["builtins.dict", ("a",), [None]], + ["builtins.tuple", None, [None, None]], + ["test.CustomNode", {"wires": Wires([1, "a", 3.4, None])}, [None, None, None]], + ], + ] + + +@pytest.mark.parametrize("string", [True, False]) +def test_structure_load(string): + jsoned = [ + "builtins.dict", + ["list", "dict", "tuple", "custom"], + [ + ["builtins.list", None, [None, None]], + [ + "builtins.dict", + [ + "a", + ], + [None], + ], + ["builtins.tuple", None, [None, None]], + ["test.CustomNode", {"wires": [1, "a", 3.4, None]}, [None, None, None]], + ], + ] + if string: + jsoned = json.dumps(jsoned) + + assert pytree_structure_load(jsoned) == PyTreeStructure( + dict, + ["list", "dict", "tuple", "custom"], + [ + PyTreeStructure(list, None, [leaf, leaf]), + PyTreeStructure(dict, ["a"], [leaf]), + PyTreeStructure(tuple, None, [leaf, leaf]), + PyTreeStructure(CustomNode, {"wires": [1, "a", 3.4, None]}, [leaf, leaf, leaf]), + ], + ) From 0757766997bcf6ca2bcd62342515a2b7ad8bbf14 Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Thu, 23 May 2024 14:20:38 -0400 Subject: [PATCH 11/28] tests --- 
pennylane/data/attributes/__init__.py | 2 + .../data/attributes/operator/operator.py | 140 ------------------ pennylane/data/attributes/pytree.py | 46 ++++++ pennylane/pytrees/__init__.py | 11 +- pennylane/pytrees/pytrees.py | 63 +++++--- pennylane/pytrees/serialization.py | 57 ++++--- tests/pytrees/test_serialization.py | 73 ++++++--- 7 files changed, 187 insertions(+), 205 deletions(-) create mode 100644 pennylane/data/attributes/pytree.py diff --git a/pennylane/data/attributes/__init__.py b/pennylane/data/attributes/__init__.py index 387403c7e16..58f025a4c5f 100644 --- a/pennylane/data/attributes/__init__.py +++ b/pennylane/data/attributes/__init__.py @@ -24,6 +24,7 @@ from .sparse_array import DatasetSparseArray from .string import DatasetString from .tuple import DatasetTuple +from .pytree import DatasetPytree __all__ = ( "DatasetArray", @@ -32,6 +33,7 @@ "DatasetDict", "DatasetList", "DatasetOperator", + "DatasetPytree", "DatasetSparseArray", "DatasetMolecule", "DatasetNone", diff --git a/pennylane/data/attributes/operator/operator.py b/pennylane/data/attributes/operator/operator.py index 6a5c79d5ebf..1ba936c3cf4 100644 --- a/pennylane/data/attributes/operator/operator.py +++ b/pennylane/data/attributes/operator/operator.py @@ -50,146 +50,6 @@ class DatasetOperator(Generic[Op], DatasetAttribute[HDF5Group, Op, Op]): type_id = "operator" - @classmethod - @lru_cache(1) - def consumes_types(cls) -> FrozenSet[Type[Operator]]: - return frozenset( - ( - # pennylane/operation/Tensor - Tensor, - # pennylane/ops/qubit/arithmetic_qml.py - qml.QubitCarry, - qml.QubitSum, - # pennylane/ops/qubit/hamiltonian.py - qml.ops.Hamiltonian, - # pennylane/ops/op_math/linear_combination.py - qml.ops.LinearCombination, - # pennylane/ops/op_math - prod.py, s_prod.py, sum.py - qml.ops.Prod, - qml.ops.SProd, - qml.ops.Sum, - # pennylane/ops/qubit/matrix_qml.py - qml.QubitUnitary, - qml.DiagonalQubitUnitary, - # pennylane/ops/qubit/non_parametric_qml.py - qml.Hadamard, - qml.PauliX, - 
qml.PauliY, - qml.PauliZ, - qml.X, - qml.Y, - qml.Z, - qml.T, - qml.S, - qml.SX, - qml.CNOT, - qml.CH, - qml.SWAP, - qml.ECR, - qml.SISWAP, - qml.CSWAP, - qml.CCZ, - qml.Toffoli, - qml.WireCut, - # pennylane/ops/qubit/observables.py - qml.Hermitian, - qml.Projector, - # pennylane/ops/qubit/parametric_ops_multi_qubit.py - qml.MultiRZ, - qml.IsingXX, - qml.IsingYY, - qml.IsingZZ, - qml.IsingXY, - qml.PSWAP, - qml.CPhaseShift00, - qml.CPhaseShift01, - qml.CPhaseShift10, - # pennylane/ops/qubit/parametric_ops_single_qubit.py - qml.RX, - qml.RY, - qml.RZ, - qml.PhaseShift, - qml.Rot, - qml.U1, - qml.U2, - qml.U3, - # pennylane/ops/qubit/qchem_ops.py - qml.SingleExcitation, - qml.SingleExcitationMinus, - qml.SingleExcitationPlus, - qml.DoubleExcitation, - qml.DoubleExcitationMinus, - qml.DoubleExcitationPlus, - qml.OrbitalRotation, - qml.FermionicSWAP, - # pennylane/ops/special_unitary.py - qml.SpecialUnitary, - # pennylane/ops/state_preparation.py - qml.BasisState, - qml.QubitStateVector, - qml.StatePrep, - qml.QubitDensityMatrix, - # pennylane/ops/qutrit/matrix_obs.py - qml.QutritUnitary, - # pennylane/ops/qutrit/non_parametric_qml.py - qml.TShift, - qml.TClock, - qml.TAdd, - qml.TSWAP, - # pennylane/ops/qutrit/observables.py - qml.THermitian, - # pennylane/ops/channel.py - qml.AmplitudeDamping, - qml.GeneralizedAmplitudeDamping, - qml.PhaseDamping, - qml.DepolarizingChannel, - qml.BitFlip, - qml.ResetError, - qml.PauliError, - qml.PhaseFlip, - qml.ThermalRelaxationError, - # pennylane/ops/cv.py - qml.Rotation, - qml.Squeezing, - qml.Displacement, - qml.Beamsplitter, - qml.TwoModeSqueezing, - qml.QuadraticPhase, - qml.ControlledAddition, - qml.ControlledPhase, - qml.Kerr, - qml.CrossKerr, - qml.InterferometerUnitary, - qml.CoherentState, - qml.SqueezedState, - qml.DisplacedSqueezedState, - qml.ThermalState, - qml.GaussianState, - qml.FockState, - qml.FockStateVector, - qml.FockDensityMatrix, - qml.CatState, - qml.NumberOperator, - qml.TensorN, - qml.QuadX, - qml.QuadP, 
- qml.QuadOperator, - qml.PolyXP, - qml.FockStateProjector, - # pennylane/ops/identity.py - qml.Identity, - # pennylane/ops/op_math/controlled_ops.py - qml.ControlledQubitUnitary, - qml.ControlledPhaseShift, - qml.CRX, - qml.CRY, - qml.CRZ, - qml.CRot, - qml.CZ, - qml.CY, - ) - ) - def value_to_hdf5(self, bind_parent: HDF5Group, key: str, value: Op) -> HDF5Group: return self._ops_to_hdf5(bind_parent, key, [value]) diff --git a/pennylane/data/attributes/pytree.py b/pennylane/data/attributes/pytree.py new file mode 100644 index 00000000000..a67c94e9c9c --- /dev/null +++ b/pennylane/data/attributes/pytree.py @@ -0,0 +1,46 @@ +from functools import lru_cache +from typing import Any, TypeVar + +import numpy as np + +from pennylane.data.base.attribute import DatasetAttribute +from pennylane.data.base.hdf5 import HDF5Group +from pennylane.data.base.mapper import AttributeTypeMapper +from pennylane.pytrees import flatten, list_pytree_types, serialization, unflatten + +T = TypeVar("T") + + +class DatasetPytree(DatasetAttribute[HDF5Group, T, T]): + """Attribute type for an object that can be converted to + a Pytree. This is the default serialization method for + all Pennylane Pytrees, including sublcasses of ``Operator``. 
+ """ + + type_id = "pytree" + + @classmethod + @lru_cache(1) + def consumes_types(cls) -> frozenset[type[Any]]: + return frozenset(list_pytree_types("qml")) + + def hdf5_to_value(self, bind: HDF5Group) -> T: + mapper = AttributeTypeMapper(bind) + + return unflatten( + [mapper[str(i)].get_value() for i in range(bind["num_leaves"][()])], + serialization.pytree_structure_load(bind["treedef"][()].tobytes()), + ) + + def value_to_hdf5(self, bind_parent: HDF5Group, key: str, value: T) -> HDF5Group: + bind = bind_parent.create_group(key) + mapper = AttributeTypeMapper(bind) + + leaves, treedef = flatten(value) + + bind["treedef"] = np.void(serialization.pytree_structure_dump(treedef, encode=True)) + bind["num_leaves"] = len(leaves) + for i, leaf in enumerate(leaves): + mapper[str(i)] = leaf + + return bind diff --git a/pennylane/pytrees/__init__.py b/pennylane/pytrees/__init__.py index e24165ab4b4..35ec3059f65 100644 --- a/pennylane/pytrees/__init__.py +++ b/pennylane/pytrees/__init__.py @@ -5,6 +5,15 @@ leaf, register_pytree, unflatten, + list_pytree_types, ) -__all__ = ["PyTreeStructure", "flatten", "is_pytree", "leaf", "register_pytree", "unflatten"] +__all__ = [ + "PyTreeStructure", + "flatten", + "is_pytree", + "leaf", + "list_pytree_types", + "register_pytree", + "unflatten", +] diff --git a/pennylane/pytrees/pytrees.py b/pennylane/pytrees/pytrees.py index 52f1dc66e9a..9bdbaa39456 100644 --- a/pennylane/pytrees/pytrees.py +++ b/pennylane/pytrees/pytrees.py @@ -14,8 +14,9 @@ """ An internal module for working with pytrees. 
""" +from collections.abc import Callable, Iterator from dataclasses import dataclass, field -from typing import Any, Callable, Optional +from typing import Any, Optional has_jax = True try: @@ -73,19 +74,19 @@ def unflatten_dict(data, metadata) -> dict: dict: unflatten_dict, } -type_to_typenames: dict[type, str] = { +type_to_typename: dict[type, str] = { list: "builtins.list", dict: "builtins.dict", tuple: "builtins.tuple", } -typename_to_type: dict[str, type] = {name: type_ for type_, name in type_to_typenames.items()} +typename_to_type: dict[str, type] = {name: type_ for type_, name in type_to_typename.items()} def _register_pytree_with_pennylane( pytree_type: type, typename: str, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn ): - type_to_typenames[pytree_type] = typename + type_to_typename[pytree_type] = typename typename_to_type[typename] = pytree_type flatten_registrations[pytree_type] = flatten_fn @@ -100,11 +101,7 @@ def jax_unflatten(aux, parameters): def register_pytree( - pytree_type: type, - flatten_fn: FlattenFn, - unflatten_fn: UnflattenFn, - *, - typename_prefix: str = "qml", + pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn, *, namespace: str = "qml" ): """Register a type with all available pytree backends. @@ -114,7 +111,7 @@ def register_pytree( pytree_type (type): the type to register, such as ``qml.RX`` flatten_fn (Callable): a function that splits an object into trainable leaves and hashable metadata. unflatten_fn (Callable): a function that reconstructs an object from its leaves and metadata. - typename_prefix (str): A prefix for the name under which this type will be registered. + namespace (str): A prefix for the name under which this type will be registered. 
Returns: None @@ -126,7 +123,7 @@ def register_pytree( """ - typename = f"{typename_prefix}.{pytree_type.__qualname__}" + typename = f"{namespace}.{pytree_type.__qualname__}" _register_pytree_with_pennylane(pytree_type, typename, flatten_fn, unflatten_fn) if has_jax: @@ -135,7 +132,7 @@ def register_pytree( def is_pytree(type_: type[Any]) -> bool: """Returns True if ``type_`` is a registered Pytree.""" - return type_ in type_to_typenames + return type_ in type_to_typename def get_typename(pytree_type: type[Any]) -> str: @@ -143,10 +140,17 @@ def get_typename(pytree_type: type[Any]) -> str: was registered. Raises: - TypeError: If ``pytree_type`` is not a Pytree""" + TypeError: If ``pytree_type`` is not a Pytree. + + >>> get_typename(list) + 'builtins.list' + >>> import pennylane + >>> get_typename(pennylane.PauliX) + 'qml.PauliX' + """ try: - return type_to_typenames[pytree_type] + return type_to_typename[pytree_type] except KeyError as exc: raise TypeError(f"{repr(pytree_type)} is not a Pytree type") from exc @@ -157,15 +161,32 @@ def get_typename_type(typename: str) -> type[Any]: Raises: ValueError: If ``typename`` is not the name of a pytree type. - """ + >>> import pennylane + >>> get_typename_type("builtins.list") + + >>> get_typename_type("qml.PauliX") + + """ try: return typename_to_type[typename] except KeyError as exc: raise ValueError(f"{repr(typename)} is not the name of a Pytree type.") from exc -type_ = type +def list_pytree_types(namespace: Optional[str] = "qml") -> Iterator[type]: + """Return an iterator listing all registered Pytree types under + the given ``namespace``. + """ + if namespace: + namespace_filter = f"{namespace}." + return ( + type_ + for type_, typename in type_to_typename.items() + if typename.startswith(namespace_filter) + ) + + return (type_ for type_ in type_to_typename) @dataclass(repr=False) @@ -180,7 +201,7 @@ class PyTreeStructure: A leaf is defined as just a ``PyTreeStructure`` with ``type=None``. 
""" - type: Optional[type_[Any]] = None + type_: Optional[type[Any]] = None """The type corresponding to the node. If ``None``, then the structure is a leaf.""" metadata: Metadata = () @@ -192,18 +213,18 @@ class PyTreeStructure: @property def is_leaf(self) -> bool: """Whether or not the structure is a leaf.""" - return self.type is None + return self.type_ is None def __repr__(self): if self.is_leaf: return "PyTreeStructure()" - return f"PyTreeStructure({self.type.__name__}, {self.metadata}, {self.children})" + return f"PyTreeStructure({self.type_.__name__}, {self.metadata}, {self.children})" def __str__(self): if self.is_leaf: return "Leaf" children_string = ", ".join(str(c) for c in self.children) - return f"PyTree({self.type.__name__}, {self.metadata}, [{children_string}])" + return f"PyTree({self.type_.__name__}, {self.metadata}, [{children_string}])" leaf = PyTreeStructure(None, (), []) @@ -269,4 +290,4 @@ def _unflatten(new_data, structure): if structure.is_leaf: return next(new_data) children = tuple(_unflatten(new_data, s) for s in structure.children) - return unflatten_registrations[structure.type](children, structure.metadata) + return unflatten_registrations[structure.type_](children, structure.metadata) diff --git a/pennylane/pytrees/serialization.py b/pennylane/pytrees/serialization.py index 76c38936ce6..671bf78805e 100644 --- a/pennylane/pytrees/serialization.py +++ b/pennylane/pytrees/serialization.py @@ -9,18 +9,18 @@ @overload -def pytree_structure_dump_json( +def pytree_structure_dump( root: PyTreeStructure, *, indent: Optional[int] = None, encode: Literal[True] ) -> bytes: ... @overload -def pytree_structure_dump_json( +def pytree_structure_dump( root: PyTreeStructure, *, indent: Optional[int] = None, encode: Literal[False] = False ) -> str: ... 
-def pytree_structure_dump_json( +def pytree_structure_dump( root: PyTreeStructure, *, indent: Optional[int] = None, @@ -41,12 +41,14 @@ def pytree_structure_dump_json( str: If ``encode`` is False """ - jsoned = pytree_structure_dump(root) - - if indent: - data = json.dumps(jsoned, indent=indent, default=_json_default) + jsoned = _jsonify_pytree_structure(root) + dump_args = {"indent": indent} if indent else {"separators": (",", ":")} + if json_default: + dump_args["default"] = _wrap_user_json_default(json_default) else: - data = json.dumps(jsoned, separators=(",", ":"), default=_json_default) + dump_args["default"] = _json_default + + data = json.dumps(jsoned, **dump_args) if encode: return data.encode("utf-8") @@ -54,12 +56,12 @@ def pytree_structure_dump_json( return data -def pytree_structure_dump(root: PyTreeStructure) -> list[JSON]: +def _jsonify_pytree_structure(root: PyTreeStructure) -> list[JSON]: """Convert Pytree structure at ``root`` into a JSON-able representation.""" if root.is_leaf: raise ValueError("Cannot dump Pytree: root node may not be a leaf") - jsoned: list[Any] = [get_typename(root.type), root.metadata, list(root.children)] + jsoned: list[Any] = [get_typename(root.type_), root.metadata, list(root.children)] todo: list[list[Union[PyTreeStructure, None]]] = [jsoned[2]] @@ -73,19 +75,16 @@ def pytree_structure_dump(root: PyTreeStructure) -> list[JSON]: continue child_list = list(child.children) - curr[i] = [get_typename(child.type), child.metadata, child_list] + curr[i] = [get_typename(child.type_), child.metadata, child_list] todo.append(child_list) return jsoned -def pytree_structure_load(data: str | bytes | bytearray | list[JSON]) -> PyTreeStructure: +def pytree_structure_load(data: str | bytes | bytearray) -> PyTreeStructure: """Load a previously serialized Pytree structure.""" - if isinstance(data, (str, bytes, bytearray)): - jsoned = json.loads(data) - else: - jsoned = data + jsoned = json.loads(data) root = 
PyTreeStructure(get_typename_type(jsoned[0]), jsoned[1], jsoned[2]) todo: list[list[Any]] = [root.children] @@ -106,14 +105,24 @@ def pytree_structure_load(data: str | bytes | bytearray | list[JSON]) -> PyTreeS return root -def _json_default(default: Callable[[Any], JSON]): - def default(o: Any): - if isinstance(o, Wires): - return o.tolist() +def _json_default(obj: Any) -> JSON: + """Default function for ``json.dump()``. Adds handling for the following types: + - ``pennylane.wires.Wires`` + """ + if isinstance(obj, Wires): + return obj.tolist() + + raise TypeError + - return +def _wrap_user_json_default(user_default: Callable[[Any], JSON]) -> Callable[[Any], JSON]: + """Wraps a user-provided JSON default function. If ``user_default`` raises a TypeError, + calls ``_json_default``.""" - if isinstance(o, Wires): - return o.tolist() + def _default_wrapped(obj: Any) -> JSON: + try: + return user_default(obj) + except TypeError: + return _json_default(obj) - raise TypeError + return _default_wrapped diff --git a/tests/pytrees/test_serialization.py b/tests/pytrees/test_serialization.py index 038c115612f..ad63a4e3f57 100644 --- a/tests/pytrees/test_serialization.py +++ b/tests/pytrees/test_serialization.py @@ -2,7 +2,15 @@ import pytest -from pennylane.pytrees import PyTreeStructure, flatten, leaf, register_pytree +from pennylane.ops import Hadamard, PauliX, Prod, Sum +from pennylane.pytrees import ( + PyTreeStructure, + flatten, + is_pytree, + leaf, + list_pytree_types, + register_pytree, +) from pennylane.pytrees.serialization import pytree_structure_dump, pytree_structure_load from pennylane.wires import Wires @@ -22,7 +30,29 @@ def unflatten_custom(data, metadata): return CustomNode(data, metadata) -register_pytree(CustomNode, flatten_custom, unflatten_custom, typename_prefix="test") +register_pytree(CustomNode, flatten_custom, unflatten_custom, namespace="test") + + +def test_list_pytree_types(): + """Test for ``list_pytree_types()``.""" + assert 
list(list_pytree_types("test")) == [CustomNode] + + +@pytest.mark.parametrize( + "cls, result", + [ + (CustomNode, True), + (list, True), + (tuple, True), + (Sum, True), + (Prod, True), + (PauliX, True), + (int, False), + ], +) +def test_is_pytree(cls, result): + """Tests for ``is_pytree()``.""" + assert is_pytree(cls) is result def test_structure_dump(): @@ -35,21 +65,7 @@ def test_structure_dump(): } ) - assert pytree_structure_dump(struct) == [ - "builtins.dict", - ("list", "dict", "tuple", "custom"), - [ - ["builtins.list", None, [None, None]], - ["builtins.dict", ("a",), [None]], - ["builtins.tuple", None, [None, None]], - ["test.CustomNode", {"wires": Wires([1, "a", 3.4, None])}, [None, None, None]], - ], - ] - - -@pytest.mark.parametrize("string", [True, False]) -def test_structure_load(string): - jsoned = [ + assert json.loads(pytree_structure_dump(struct)) == [ "builtins.dict", ["list", "dict", "tuple", "custom"], [ @@ -65,8 +81,27 @@ def test_structure_load(string): ["test.CustomNode", {"wires": [1, "a", 3.4, None]}, [None, None, None]], ], ] - if string: - jsoned = json.dumps(jsoned) + + +def test_structure_load(): + jsoned = json.dumps( + [ + "builtins.dict", + ["list", "dict", "tuple", "custom"], + [ + ["builtins.list", None, [None, None]], + [ + "builtins.dict", + [ + "a", + ], + [None], + ], + ["builtins.tuple", None, [None, None]], + ["test.CustomNode", {"wires": [1, "a", 3.4, None]}, [None, None, None]], + ], + ] + ) assert pytree_structure_load(jsoned) == PyTreeStructure( dict, From 5a6d0e54aefbdbff7c92096e432ca82f1dcaffda Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Thu, 23 May 2024 15:34:38 -0400 Subject: [PATCH 12/28] tests, docs --- pennylane/pytrees/serialization.py | 75 ++++++++++++++++------------- tests/pytrees/test_serialization.py | 13 +++-- 2 files changed, 50 insertions(+), 38 deletions(-) diff --git a/pennylane/pytrees/serialization.py b/pennylane/pytrees/serialization.py index 671bf78805e..5a4914042ca 100644 --- 
a/pennylane/pytrees/serialization.py +++ b/pennylane/pytrees/serialization.py @@ -10,13 +10,13 @@ @overload def pytree_structure_dump( - root: PyTreeStructure, *, indent: Optional[int] = None, encode: Literal[True] + root: PyTreeStructure, *, indent: Optional[int] = None, decode: Literal[False] = False ) -> bytes: ... @overload def pytree_structure_dump( - root: PyTreeStructure, *, indent: Optional[int] = None, encode: Literal[False] = False + root: PyTreeStructure, *, indent: Optional[int] = None, decode: Literal[True] ) -> str: ... @@ -24,66 +24,66 @@ def pytree_structure_dump( root: PyTreeStructure, *, indent: Optional[int] = None, - encode: bool = False, + decode: bool = False, json_default: Optional[Callable[[Any], JSON]] = None, ) -> Union[bytes, str]: """Convert Pytree structure ``root`` into JSON. + A non-leaf structure is represented as a 3-element list. The first element will + be the type name, the second element metadata, and the third element is + the list of children. + + A leaf structure is represented by `null`. + + Metadata can only contain ``pennylane.Wires`` objects, JSON-serializable + data or objects that can be handled by ``json_default`` if provided. + + >>> from pennylane.pytrees import PyTreeStructure, leaf, flatten + >>> from pennylane.pytrees.serialization import pytree_structure_dump + + >>> _, struct = flatten([{"a": 1}, 2]) + >>> struct + 'PyTreeStructure(, None, [PyTreeStructure(, ("a",), [PyTreeStructure()]), PyTreeStructure()])' + + >>> pytree_structure_dump(struct) + b'["builtins.list",null,[["builtins.dict",["a"],[null]],null]' + Args: root: Root of a Pytree structure indent: If not None, the resulting JSON will be pretty-printed with the given indent level. Otherwise, the output will use the most compact possible representation - encode: Whether to return the output as bytes + decode: If True, return a string instead of bytes + json_default: Handler for objects that can't otherwise be serialized. 
Should + return a JSON-compatible value or raise a ``TypeError`` if the value + can't be handled Returns: bytes: If ``encode`` is True str: If ``encode`` is False - """ - jsoned = _jsonify_pytree_structure(root) dump_args = {"indent": indent} if indent else {"separators": (",", ":")} if json_default: dump_args["default"] = _wrap_user_json_default(json_default) else: dump_args["default"] = _json_default - data = json.dumps(jsoned, **dump_args) + data = json.dumps(root, **dump_args) - if encode: + if not decode: return data.encode("utf-8") return data -def _jsonify_pytree_structure(root: PyTreeStructure) -> list[JSON]: - """Convert Pytree structure at ``root`` into a JSON-able representation.""" - if root.is_leaf: - raise ValueError("Cannot dump Pytree: root node may not be a leaf") - - jsoned: list[Any] = [get_typename(root.type_), root.metadata, list(root.children)] - - todo: list[list[Union[PyTreeStructure, None]]] = [jsoned[2]] - - while todo: - curr = todo.pop() - - for i in range(len(curr)): - child = curr[i] - if child.is_leaf: - curr[i] = None - continue - - child_list = list(child.children) - curr[i] = [get_typename(child.type_), child.metadata, child_list] - todo.append(child_list) - - return jsoned - - def pytree_structure_load(data: str | bytes | bytearray) -> PyTreeStructure: - """Load a previously serialized Pytree structure.""" + """Load a previously serialized Pytree structure. 
+ + >>> from pennylane.pytrees.serialization import pytree_structure_dump + >>> pytree_structure_load('["builtins.list",null,[["builtins.dict",["a"],[null]],null]') + 'PyTreeStructure(, None, [PyTreeStructure(, ["a"], [PyTreeStructure()]), PyTreeStructure()])' + """ jsoned = json.loads(data) root = PyTreeStructure(get_typename_type(jsoned[0]), jsoned[1], jsoned[2]) @@ -100,6 +100,7 @@ def pytree_structure_load(data: str | bytes | bytearray) -> PyTreeStructure: curr[i] = PyTreeStructure(get_typename_type(child[0]), child[1], child[2]) + # Child structures will be converted in place todo.append(child[2]) return root @@ -107,8 +108,14 @@ def pytree_structure_load(data: str | bytes | bytearray) -> PyTreeStructure: def _json_default(obj: Any) -> JSON: """Default function for ``json.dump()``. Adds handling for the following types: + - ``pennylane.pytrees.PyTreeStructure`` - ``pennylane.wires.Wires`` """ + if isinstance(obj, PyTreeStructure): + if obj.is_leaf: + return None + return [get_typename(obj.type_), obj.metadata, obj.children] + if isinstance(obj, Wires): return obj.tolist() diff --git a/tests/pytrees/test_serialization.py b/tests/pytrees/test_serialization.py index ad63a4e3f57..11d8c938b8f 100644 --- a/tests/pytrees/test_serialization.py +++ b/tests/pytrees/test_serialization.py @@ -2,7 +2,7 @@ import pytest -from pennylane.ops import Hadamard, PauliX, Prod, Sum +from pennylane.ops import PauliX, Prod, Sum from pennylane.pytrees import ( PyTreeStructure, flatten, @@ -16,6 +16,7 @@ class CustomNode: + """Example Pytree for testing.""" def __init__(self, data, metadata): self.data = data @@ -51,11 +52,14 @@ def test_list_pytree_types(): ], ) def test_is_pytree(cls, result): - """Tests for ``is_pytree()``.""" + """Test for ``is_pytree()``.""" assert is_pytree(cls) is result -def test_structure_dump(): +@pytest.mark.parametrize("decode", [True, False]) +def test_pytree_structure_dump(decode): + """Test that ``pytree_structure_dump()`` creates JSON in the expected + 
format.""" _, struct = flatten( { "list": ["a", 1], @@ -65,7 +69,7 @@ def test_structure_dump(): } ) - assert json.loads(pytree_structure_dump(struct)) == [ + assert json.loads(pytree_structure_dump(struct, decode=decode)) == [ "builtins.dict", ["list", "dict", "tuple", "custom"], [ @@ -84,6 +88,7 @@ def test_structure_dump(): def test_structure_load(): + """Test that ``pytree_structure_load()`` can parse a JSON-serialized PyTree.""" jsoned = json.dumps( [ "builtins.dict", From dcf1bb96c81be98f378e9891c12b657bdb9d5628 Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Thu, 23 May 2024 17:17:06 -0400 Subject: [PATCH 13/28] tests --- pennylane/data/__init__.py | 2 + pennylane/data/attributes/__init__.py | 4 +- .../data/attributes/operator/operator.py | 145 +++++++++++++++++- pennylane/data/attributes/pytree.py | 9 +- pennylane/data/base/attribute.py | 3 + pennylane/measurements/shots.py | 12 +- pennylane/pytrees/pytrees.py | 5 +- pennylane/pytrees/serialization.py | 7 +- .../data/attributes/operator/test_operator.py | 110 +++++++------ tests/data/attributes/test_pytree.py | 88 +++++++++++ tests/pytrees/test_serialization.py | 38 ++++- 11 files changed, 357 insertions(+), 66 deletions(-) create mode 100644 tests/data/attributes/test_pytree.py diff --git a/pennylane/data/__init__.py b/pennylane/data/__init__.py index d9b300eb227..126c37dd31a 100644 --- a/pennylane/data/__init__.py +++ b/pennylane/data/__init__.py @@ -211,6 +211,7 @@ class QuantumOscillator(qml.data.Dataset, data_name="quantum_oscillator", identi DatasetSparseArray, DatasetString, DatasetTuple, + DatasetPyTree, ) from .base import DatasetNotWriteableError from .base.attribute import AttributeInfo, DatasetAttribute, attribute @@ -225,6 +226,7 @@ class QuantumOscillator(qml.data.Dataset, data_name="quantum_oscillator", identi "DatasetAttribute", "DatasetNotWriteableError", "DatasetArray", + "DatasetPyTree", "DatasetScalar", "DatasetString", "DatasetList", diff --git a/pennylane/data/attributes/__init__.py 
b/pennylane/data/attributes/__init__.py index 58f025a4c5f..a699d6b5a05 100644 --- a/pennylane/data/attributes/__init__.py +++ b/pennylane/data/attributes/__init__.py @@ -24,7 +24,7 @@ from .sparse_array import DatasetSparseArray from .string import DatasetString from .tuple import DatasetTuple -from .pytree import DatasetPytree +from .pytree import DatasetPyTree __all__ = ( "DatasetArray", @@ -33,7 +33,7 @@ "DatasetDict", "DatasetList", "DatasetOperator", - "DatasetPytree", + "DatasetPyTree", "DatasetSparseArray", "DatasetMolecule", "DatasetNone", diff --git a/pennylane/data/attributes/operator/operator.py b/pennylane/data/attributes/operator/operator.py index 1ba936c3cf4..76fca0681fa 100644 --- a/pennylane/data/attributes/operator/operator.py +++ b/pennylane/data/attributes/operator/operator.py @@ -48,6 +48,146 @@ class DatasetOperator(Generic[Op], DatasetAttribute[HDF5Group, Op, Op]): ``Hamiltonian`` and ``Tensor`` operators. """ + @classmethod + @lru_cache(1) + def supported_ops(cls) -> FrozenSet[Type[Operator]]: + return frozenset( + ( + # pennylane/operation/Tensor + Tensor, + # pennylane/ops/qubit/arithmetic_qml.py + qml.QubitCarry, + qml.QubitSum, + # pennylane/ops/qubit/hamiltonian.py + qml.ops.Hamiltonian, + # pennylane/ops/op_math/linear_combination.py + qml.ops.LinearCombination, + # pennylane/ops/op_math - prod.py, s_prod.py, sum.py + qml.ops.Prod, + qml.ops.SProd, + qml.ops.Sum, + # pennylane/ops/qubit/matrix_qml.py + qml.QubitUnitary, + qml.DiagonalQubitUnitary, + # pennylane/ops/qubit/non_parametric_qml.py + qml.Hadamard, + qml.PauliX, + qml.PauliY, + qml.PauliZ, + qml.X, + qml.Y, + qml.Z, + qml.T, + qml.S, + qml.SX, + qml.CNOT, + qml.CH, + qml.SWAP, + qml.ECR, + qml.SISWAP, + qml.CSWAP, + qml.CCZ, + qml.Toffoli, + qml.WireCut, + # pennylane/ops/qubit/observables.py + qml.Hermitian, + qml.Projector, + # pennylane/ops/qubit/parametric_ops_multi_qubit.py + qml.MultiRZ, + qml.IsingXX, + qml.IsingYY, + qml.IsingZZ, + qml.IsingXY, + qml.PSWAP, + 
qml.CPhaseShift00, + qml.CPhaseShift01, + qml.CPhaseShift10, + # pennylane/ops/qubit/parametric_ops_single_qubit.py + qml.RX, + qml.RY, + qml.RZ, + qml.PhaseShift, + qml.Rot, + qml.U1, + qml.U2, + qml.U3, + # pennylane/ops/qubit/qchem_ops.py + qml.SingleExcitation, + qml.SingleExcitationMinus, + qml.SingleExcitationPlus, + qml.DoubleExcitation, + qml.DoubleExcitationMinus, + qml.DoubleExcitationPlus, + qml.OrbitalRotation, + qml.FermionicSWAP, + # pennylane/ops/special_unitary.py + qml.SpecialUnitary, + # pennylane/ops/state_preparation.py + qml.BasisState, + qml.QubitStateVector, + qml.StatePrep, + qml.QubitDensityMatrix, + # pennylane/ops/qutrit/matrix_obs.py + qml.QutritUnitary, + # pennylane/ops/qutrit/non_parametric_qml.py + qml.TShift, + qml.TClock, + qml.TAdd, + qml.TSWAP, + # pennylane/ops/qutrit/observables.py + qml.THermitian, + # pennylane/ops/channel.py + qml.AmplitudeDamping, + qml.GeneralizedAmplitudeDamping, + qml.PhaseDamping, + qml.DepolarizingChannel, + qml.BitFlip, + qml.ResetError, + qml.PauliError, + qml.PhaseFlip, + qml.ThermalRelaxationError, + # pennylane/ops/cv.py + qml.Rotation, + qml.Squeezing, + qml.Displacement, + qml.Beamsplitter, + qml.TwoModeSqueezing, + qml.QuadraticPhase, + qml.ControlledAddition, + qml.ControlledPhase, + qml.Kerr, + qml.CrossKerr, + qml.InterferometerUnitary, + qml.CoherentState, + qml.SqueezedState, + qml.DisplacedSqueezedState, + qml.ThermalState, + qml.GaussianState, + qml.FockState, + qml.FockStateVector, + qml.FockDensityMatrix, + qml.CatState, + qml.NumberOperator, + qml.TensorN, + qml.QuadX, + qml.QuadP, + qml.QuadOperator, + qml.PolyXP, + qml.FockStateProjector, + # pennylane/ops/identity.py + qml.Identity, + # pennylane/ops/op_math/controlled_ops.py + qml.ControlledQubitUnitary, + qml.ControlledPhaseShift, + qml.CRX, + qml.CRY, + qml.CRZ, + qml.CRot, + qml.CZ, + qml.CY, + ) + ) + type_id = "operator" def value_to_hdf5(self, bind_parent: HDF5Group, key: str, value: Op) -> HDF5Group: @@ -74,7 +214,7 @@ def 
_ops_to_hdf5( op_key = f"op_{i}" if isinstance(op, (qml.ops.Prod, qml.ops.SProd, qml.ops.Sum)): op = op.simplify() - if type(op) not in self.consumes_types(): + if type(op) not in self.supported_ops(): raise TypeError( f"Serialization of operator type '{type(op).__name__}' is not supported." ) @@ -114,6 +254,7 @@ def _hdf5_to_ops(self, bind: HDF5Group) -> List[Operator]: wires_bind = bind["op_wire_labels"] op_class_names = [] if names_bind.shape == (0,) else names_bind.asstr() op_wire_labels = [] if wires_bind.shape == (0,) else wires_bind.asstr() + with qml.QueuingManager.stop_recording(): for i, op_class_name in enumerate(op_class_names): op_key = f"op_{i}" @@ -153,4 +294,4 @@ def _hdf5_to_ops(self, bind: HDF5Group) -> List[Operator]: @lru_cache(1) def _supported_ops_dict(cls) -> Dict[str, Type[Operator]]: """Returns a dict mapping ``Operator`` subclass names to the class.""" - return {op.__name__: op for op in cls.consumes_types()} + return {op.__name__: op for op in cls.supported_ops()} diff --git a/pennylane/data/attributes/pytree.py b/pennylane/data/attributes/pytree.py index a67c94e9c9c..96df3e485fa 100644 --- a/pennylane/data/attributes/pytree.py +++ b/pennylane/data/attributes/pytree.py @@ -11,7 +11,7 @@ T = TypeVar("T") -class DatasetPytree(DatasetAttribute[HDF5Group, T, T]): +class DatasetPyTree(DatasetAttribute[HDF5Group, T, T]): """Attribute type for an object that can be converted to a Pytree. This is the default serialization method for all Pennylane Pytrees, including sublcasses of ``Operator``. 
@@ -19,11 +19,6 @@ class DatasetPytree(DatasetAttribute[HDF5Group, T, T]): type_id = "pytree" - @classmethod - @lru_cache(1) - def consumes_types(cls) -> frozenset[type[Any]]: - return frozenset(list_pytree_types("qml")) - def hdf5_to_value(self, bind: HDF5Group) -> T: mapper = AttributeTypeMapper(bind) @@ -38,7 +33,7 @@ def value_to_hdf5(self, bind_parent: HDF5Group, key: str, value: T) -> HDF5Group leaves, treedef = flatten(value) - bind["treedef"] = np.void(serialization.pytree_structure_dump(treedef, encode=True)) + bind["treedef"] = np.void(serialization.pytree_structure_dump(treedef, decode=False)) bind["num_leaves"] = len(leaves) for i, leaf in enumerate(leaves): mapper[str(i)] = leaf diff --git a/pennylane/data/base/attribute.py b/pennylane/data/base/attribute.py index 06b4a19c4c7..2196ca5098f 100644 --- a/pennylane/data/base/attribute.py +++ b/pennylane/data/base/attribute.py @@ -38,6 +38,7 @@ from pennylane.data.base import hdf5 from pennylane.data.base.hdf5 import HDF5, HDF5Any, HDF5Group from pennylane.data.base.typing_util import UNSET, get_type, get_type_str +from pennylane.pytrees import is_pytree T = TypeVar("T") @@ -492,5 +493,7 @@ def match_obj_type( ret = DatasetAttribute.registry["list"] elif issubclass(type_, Mapping): ret = DatasetAttribute.registry["dict"] + elif is_pytree(type_): + ret = DatasetAttribute.registry["pytree"] return ret diff --git a/pennylane/measurements/shots.py b/pennylane/measurements/shots.py index 51d66a41b03..77ab2cf9e2d 100644 --- a/pennylane/measurements/shots.py +++ b/pennylane/measurements/shots.py @@ -11,8 +11,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""This module contains the Shots class to hold shot-related information.""" +from collections.abc import Sequence + # pylint:disable=inconsistent-return-statements -from typing import NamedTuple, Sequence, Tuple +from typing import NamedTuple class ShotCopies(NamedTuple): @@ -39,7 +41,7 @@ def valid_int(s): def valid_tuple(s): """Returns True if s is a tuple of the form (shots, copies).""" - return isinstance(s, tuple) and len(s) == 2 and valid_int(s[0]) and valid_int(s[1]) + return isinstance(s, Sequence) and len(s) == 2 and valid_int(s[0]) and valid_int(s[1]) class Shots: @@ -136,7 +138,7 @@ class Shots: total_shots: int = None """The total number of shots to be executed.""" - shot_vector: Tuple[ShotCopies] = None + shot_vector: tuple[ShotCopies] = None """The tuple of :class:`~ShotCopies` to be executed. Each element is of the form ``(shots, copies)``.""" _SHOT_ERROR = ValueError( @@ -167,7 +169,7 @@ def __init__(self, shots=None): elif isinstance(shots, Sequence): if not all(valid_int(s) or valid_tuple(s) for s in shots): raise self._SHOT_ERROR - self.__all_tuple_init__([s if isinstance(s, tuple) else (s, 1) for s in shots]) + self.__all_tuple_init__([s if isinstance(s, Sequence) else (s, 1) for s in shots]) elif isinstance(shots, self.__class__): return # self already _is_ shots as defined by __new__ else: @@ -211,7 +213,7 @@ def __iter__(self): for _ in range(shot_copy.copies): yield shot_copy.shots - def __all_tuple_init__(self, shots: Sequence[Tuple]): + def __all_tuple_init__(self, shots: Sequence[tuple]): res = [] total_shots = 0 current_shots, current_copies = shots[0] diff --git a/pennylane/pytrees/pytrees.py b/pennylane/pytrees/pytrees.py index 9bdbaa39456..d14a0b2bc44 100644 --- a/pennylane/pytrees/pytrees.py +++ b/pennylane/pytrees/pytrees.py @@ -18,6 +18,8 @@ from dataclasses import dataclass, field from typing import Any, Optional +import pennylane.queuing + has_jax = True try: import jax.tree_util as jax_tree_util @@ -283,7 +285,8 @@ def 
unflatten(data: list[Any], structure: PyTreeStructure) -> Any: Adjoint(Rot(-2, -3, -4, wires=[0])) """ - return _unflatten(iter(data), structure) + with pennylane.queuing.QueuingManager.stop_recording(): + return _unflatten(iter(data), structure) def _unflatten(new_data, structure): diff --git a/pennylane/pytrees/serialization.py b/pennylane/pytrees/serialization.py index 5a4914042ca..3bb5380c776 100644 --- a/pennylane/pytrees/serialization.py +++ b/pennylane/pytrees/serialization.py @@ -2,6 +2,7 @@ from collections.abc import Callable from typing import Any, Literal, Optional, Union, overload +from pennylane.measurements.shots import Shots from pennylane.typing import JSON from pennylane.wires import Wires @@ -110,6 +111,7 @@ def _json_default(obj: Any) -> JSON: """Default function for ``json.dump()``. Adds handling for the following types: - ``pennylane.pytrees.PyTreeStructure`` - ``pennylane.wires.Wires`` + - ``pennylane.measurements.shots.Shots`` """ if isinstance(obj, PyTreeStructure): if obj.is_leaf: @@ -119,7 +121,10 @@ def _json_default(obj: Any) -> JSON: if isinstance(obj, Wires): return obj.tolist() - raise TypeError + if isinstance(obj, Shots): + return obj.shot_vector + + raise TypeError(obj) def _wrap_user_json_default(user_default: Callable[[Any], JSON]) -> Callable[[Any], JSON]: diff --git a/tests/data/attributes/operator/test_operator.py b/tests/data/attributes/operator/test_operator.py index 0f7d9db085d..005389d64a2 100644 --- a/tests/data/attributes/operator/test_operator.py +++ b/tests/data/attributes/operator/test_operator.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Tests for the ``DatasetOperator`` attribute type. +Tests for the serializing operators using the ``DatasetOperator`` and ``DatasetPyTree`` +attribute types. 
""" import itertools @@ -21,7 +22,7 @@ import pytest import pennylane as qml -from pennylane.data.attributes import DatasetOperator +from pennylane.data.attributes import DatasetOperator, DatasetPyTree from pennylane.data.base.typing_util import get_type_str from pennylane.operation import Operator, Tensor @@ -75,37 +76,38 @@ tensors = [Tensor(qml.PauliX(1), qml.PauliY(2))] +@pytest.mark.parametrize("attribute_cls", [DatasetOperator, DatasetPyTree]) @pytest.mark.parametrize("obs_in", [*hermitian_ops, *pauli_ops, *identity, *hamiltonians, *tensors]) class TestDatasetOperatorObservable: """Tests serializing Observable operators using the ``compare()`` method.""" - def test_value_init(self, obs_in): + def test_value_init(self, attribute_cls, obs_in): """Test that a DatasetOperator can be value-initialized from an observable, and that the deserialized operator is equivalent.""" if not qml.operation.active_new_opmath() and isinstance(obs_in, qml.ops.LinearCombination): obs_in = qml.operation.convert_to_legacy_H(obs_in) - dset_op = DatasetOperator(obs_in) + dset_op = attribute_cls(obs_in) - assert dset_op.info["type_id"] == "operator" + assert dset_op.info["type_id"] == attribute_cls.type_id assert dset_op.info["py_type"] == get_type_str(type(obs_in)) obs_out = dset_op.get_value() assert qml.equal(obs_out, obs_in) assert obs_in.compare(obs_out) - def test_bind_init(self, obs_in): + def test_bind_init(self, attribute_cls, obs_in): """Test that DatasetOperator can be initialized from a HDF5 group that contains a operator attribute.""" if not qml.operation.active_new_opmath() and isinstance(obs_in, qml.ops.LinearCombination): obs_in = qml.operation.convert_to_legacy_H(obs_in) - bind = DatasetOperator(obs_in).bind + bind = attribute_cls(obs_in).bind - dset_op = DatasetOperator(bind=bind) + dset_op = attribute_cls(bind=bind) - assert dset_op.info["type_id"] == "operator" + assert dset_op.info["type_id"] == attribute_cls.type_id assert dset_op.info["py_type"] == 
get_type_str(type(obs_in)) obs_out = dset_op.get_value() @@ -113,6 +115,7 @@ def test_bind_init(self, obs_in): assert obs_in.compare(obs_out) +@pytest.mark.parametrize("attribute_cls", [DatasetOperator, DatasetPyTree]) @pytest.mark.parametrize( "obs_in", [ @@ -126,38 +129,39 @@ def test_bind_init(self, obs_in): class TestDatasetArithmeticOperators: """Tests serializing Observable operators using the ``qml.equal()`` method.""" - def test_value_init(self, obs_in): + def test_value_init(self, attribute_cls, obs_in): """Test that a DatasetOperator can be value-initialized from an observable, and that the deserialized operator is equivalent.""" if not qml.operation.active_new_opmath() and isinstance(obs_in, qml.ops.LinearCombination): obs_in = qml.operation.convert_to_legacy_H(obs_in) - dset_op = DatasetOperator(obs_in) + dset_op = attribute_cls(obs_in) - assert dset_op.info["type_id"] == "operator" + assert dset_op.info["type_id"] == attribute_cls.type_id assert dset_op.info["py_type"] == get_type_str(type(obs_in)) obs_out = dset_op.get_value() assert qml.equal(obs_out, obs_in) - def test_bind_init(self, obs_in): + def test_bind_init(self, attribute_cls, obs_in): """Test that DatasetOperator can be initialized from a HDF5 group that contains an operator attribute.""" if not qml.operation.active_new_opmath() and isinstance(obs_in, qml.ops.LinearCombination): obs_in = qml.operation.convert_to_legacy_H(obs_in) - bind = DatasetOperator(obs_in).bind + bind = attribute_cls(obs_in).bind - dset_op = DatasetOperator(bind=bind) + dset_op = attribute_cls(bind=bind) - assert dset_op.info["type_id"] == "operator" + assert dset_op.info["type_id"] == attribute_cls.type_id assert dset_op.info["py_type"] == get_type_str(type(obs_in)) obs_out = dset_op.get_value() assert qml.equal(obs_out, obs_in) +@pytest.mark.parametrize("attribute_cls", [DatasetOperator, DatasetPyTree]) class TestDatasetOperator: @pytest.mark.parametrize( "op_in", @@ -168,37 +172,22 @@ class TestDatasetOperator: 
qml.Hamiltonian([], []), ], ) - def test_value_init(self, op_in): + def test_value_init(self, attribute_cls, op_in): """Test that a DatasetOperator can be value-initialized from an operator, and that the deserialized operator is equivalent.""" if not qml.operation.active_new_opmath() and isinstance(op_in, qml.ops.LinearCombination): op_in = qml.operation.convert_to_legacy_H(op_in) - dset_op = DatasetOperator(op_in) + dset_op = attribute_cls(op_in) - assert dset_op.info["type_id"] == "operator" + assert dset_op.info["type_id"] == attribute_cls.type_id assert dset_op.info["py_type"] == get_type_str(type(op_in)) op_out = dset_op.get_value() assert repr(op_out) == repr(op_in) assert op_in.data == op_out.data - def test_value_init_not_supported(self): - """Test that a ValueError is raised if attempting to serialize an unsupported operator.""" - - class NotSupported( - Operator - ): # pylint: disable=too-few-public-methods, unnecessary-ellipsis - """An operator.""" - - ... - - with pytest.raises( - TypeError, match="Serialization of operator type 'NotSupported' is not supported" - ): - DatasetOperator(NotSupported(1)) - @pytest.mark.parametrize( "op_in", [ @@ -208,18 +197,18 @@ class NotSupported( qml.Hamiltonian([], []), ], ) - def test_bind_init(self, op_in): + def test_bind_init(self, attribute_cls, op_in): """Test that a DatasetOperator can be bind-initialized from an operator, and that the deserialized operator is equivalent.""" if not qml.operation.active_new_opmath() and isinstance(op_in, qml.ops.LinearCombination): op_in = qml.operation.convert_to_legacy_H(op_in) - bind = DatasetOperator(op_in).bind + bind = attribute_cls(op_in).bind - dset_op = DatasetOperator(bind=bind) + dset_op = attribute_cls(bind=bind) - assert dset_op.info["type_id"] == "operator" + assert dset_op.info["type_id"] == attribute_cls.type_id assert dset_op.info["py_type"] == get_type_str(type(op_in)) op_out = dset_op.get_value() @@ -228,15 +217,42 @@ def test_bind_init(self, op_in): assert 
op_in.wires == op_out.wires assert repr(op_in) == repr(op_out) - def test_op_not_queued_on_deserialization(self): - """Tests that ops are not queued upon deserialization.""" - d = qml.data.Dataset(op=qml.PauliX(0)) - with qml.queuing.AnnotatedQueue() as q: - _ = d.op - assert len(q) == 0 +@pytest.mark.parametrize("attribute_cls", [DatasetOperator, DatasetPyTree]) +def test_op_not_queued_on_deserialization(attribute_cls): + """Tests that ops are not queued upon deserialization.""" + d = qml.data.Dataset(op=attribute_cls(qml.PauliX(0))) + with qml.queuing.AnnotatedQueue() as q: + _ = d.op + + assert len(q) == 0 + + with qml.queuing.AnnotatedQueue() as q2: + qml.apply(d.op) + + assert len(q2) == 1 + + +def test_consumed_by_pytree(): + """Test that ops are consumed by the ``DatasetPyTree`` type by default.""" + + d = qml.data.Dataset() + + d.op = qml.PauliX(0) + + assert isinstance(d.attrs["op"], DatasetPyTree) + + +def test_value_init_not_supported(): + """Test that a ValueError is raised if attempting to serialize an unsupported operator + using the ``DatasetOperator`` attribute type.""" + + class NotSupported(Operator): # pylint: disable=too-few-public-methods, unnecessary-ellipsis + """An operator.""" - with qml.queuing.AnnotatedQueue() as q2: - qml.apply(d.op) + ... 
- assert len(q2) == 1 + with pytest.raises( + TypeError, match="Serialization of operator type 'NotSupported' is not supported" + ): + DatasetOperator(NotSupported(1)) diff --git a/tests/data/attributes/test_pytree.py b/tests/data/attributes/test_pytree.py new file mode 100644 index 00000000000..7276ad66643 --- /dev/null +++ b/tests/data/attributes/test_pytree.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass + +import pytest + +from pennylane.data import Dataset, DatasetPyTree +from pennylane.pytrees import register_pytree +from pennylane.pytrees.pytrees import ( + flatten_registrations, + type_to_typename, + typename_to_type, + unflatten_registrations, +) + + +@dataclass +class CustomNode: + """Example Pytree for testing.""" + + def __init__(self, data, metadata): + self.data = data + self.metadata = metadata + + +def flatten_custom(node): + return (node.data, node.metadata) + + +def unflatten_custom(data, metadata): + return CustomNode(data, metadata) + + +@pytest.fixture(autouse=True) +def register_test_node(): + """Fixture that temporarily registers the ``CustomNode`` class as + a Pytree.""" + register_pytree(CustomNode, flatten_custom, unflatten_custom) + + yield + + del flatten_registrations[CustomNode] + del unflatten_registrations[CustomNode] + del typename_to_type[type_to_typename[CustomNode]] + del type_to_typename[CustomNode] + + +class TestDatasetPyTree: + """Tests for ``DatasetPyTree``.""" + + def test_consumes_type(self): + """Test that PyTree-compatible types that is not a builtin are + consumed by ``DatasetPyTree``.""" + dset = Dataset() + dset.attr = CustomNode([1, 2, 3, 4], {"meta": "data"}) + + assert isinstance(dset.attrs["attr"], DatasetPyTree) + + @pytest.mark.parametrize("obj", [[1, 2], {"a": 1}, (1, 2)]) + def test_builtins_not_consumed(self, obj): + """Test that built-in containers like dict, list and tuple are + not consumed by the ``DatasetPyTree`` type.""" + + dset = Dataset() + dset.attr = obj + + assert not 
isinstance(dset.attrs["attr"], DatasetPyTree) + + def test_value_init(self): + """Test that ``DatasetPyTree`` can be initialized from a value.""" + + value = CustomNode( + [{"a": 1}, (3, 5), [7, 9, {"x": CustomNode("data", None)}]], {"meta": "data"} + ) + attr = DatasetPyTree(value) + + assert attr.type_id == "pytree" + assert attr.get_value() == value + + def test_bind_init(self): + """Test that a ``DatasetPyTree`` can be bind-initialized.""" + + value = CustomNode( + [{"a": 1}, (3, 5), [7, 9, {"x": CustomNode("data", None)}]], {"meta": "data"} + ) + bind = DatasetPyTree(value).bind + + attr = DatasetPyTree(bind=bind) + + assert attr == value diff --git a/tests/pytrees/test_serialization.py b/tests/pytrees/test_serialization.py index 11d8c938b8f..271cd39de55 100644 --- a/tests/pytrees/test_serialization.py +++ b/tests/pytrees/test_serialization.py @@ -2,6 +2,8 @@ import pytest +import pennylane as qml +from pennylane.measurements.shots import Shots from pennylane.ops import PauliX, Prod, Sum from pennylane.pytrees import ( PyTreeStructure, @@ -10,6 +12,13 @@ leaf, list_pytree_types, register_pytree, + unflatten, +) +from pennylane.pytrees.pytrees import ( + flatten_registrations, + type_to_typename, + typename_to_type, + unflatten_registrations, ) from pennylane.pytrees.serialization import pytree_structure_dump, pytree_structure_load from pennylane.wires import Wires @@ -31,7 +40,18 @@ def unflatten_custom(data, metadata): return CustomNode(data, metadata) -register_pytree(CustomNode, flatten_custom, unflatten_custom, namespace="test") +@pytest.fixture(autouse=True) +def register_test_node(): + """Fixture that temporarily registers the ``CustomNode`` class as + a Pytree.""" + register_pytree(CustomNode, flatten_custom, unflatten_custom, namespace="test") + + yield + + del flatten_registrations[CustomNode] + del unflatten_registrations[CustomNode] + del typename_to_type[type_to_typename[CustomNode]] + del type_to_typename[CustomNode] def test_list_pytree_types(): 
@@ -118,3 +138,19 @@ def test_structure_load(): PyTreeStructure(CustomNode, {"wires": [1, "a", 3.4, None]}, [leaf, leaf, leaf]), ], ) + + +def test_nested_pl_object_roundtrip(): + tape_in = qml.tape.QuantumScript( + [qml.adjoint(qml.RX(0.1, wires=0))], + [qml.expval(2 * qml.X(0))], + shots=50, + trainable_params=(0, 1), + ) + + data, struct = flatten(tape_in) + tape_out = unflatten(data, pytree_structure_load(pytree_structure_dump(struct))) + + assert type(tape_out) == type(tape_in) + assert repr(tape_out) == repr(tape_in) + assert list(tape_out) == list(tape_in) From fafc2fe253fff9dee3e1bbfb02c1a0f68920e186 Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Thu, 23 May 2024 17:22:23 -0400 Subject: [PATCH 14/28] update changelog --- doc/releases/changelog-dev.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 81e843a579a..4c22252c500 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -93,6 +93,9 @@ * Empty initialization of `PauliVSpace` is permitted. [(#5675)](https://github.com/PennyLaneAI/pennylane/pull/5675) +* The `qml.data` module can now serialize pytrees + [(#5732)](https://github.com/PennyLaneAI/pennylane/pull/5732) +

Community contributions 🥳

* Implemented kwargs (`check_interface`, `check_trainability`, `rtol` and `atol`) support in `qml.equal` for the operators `Pow`, `Adjoint`, `Exp`, and `SProd`. From 73d162615626ac414191185c0f5006bbd0da50c4 Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Fri, 24 May 2024 11:21:06 -0400 Subject: [PATCH 15/28] refactor json handling --- pennylane/pytrees/serialization.py | 69 +++++++++++++----------------- 1 file changed, 30 insertions(+), 39 deletions(-) diff --git a/pennylane/pytrees/serialization.py b/pennylane/pytrees/serialization.py index 3bb5380c776..0b19f078e8a 100644 --- a/pennylane/pytrees/serialization.py +++ b/pennylane/pytrees/serialization.py @@ -22,11 +22,7 @@ def pytree_structure_dump( def pytree_structure_dump( - root: PyTreeStructure, - *, - indent: Optional[int] = None, - decode: bool = False, - json_default: Optional[Callable[[Any], JSON]] = None, + root: PyTreeStructure, *, indent: Optional[int] = None, decode: bool = False ) -> Union[bytes, str]: """Convert Pytree structure ``root`` into JSON. @@ -36,8 +32,8 @@ def pytree_structure_dump( A leaf structure is represented by `null`. - Metadata can only contain ``pennylane.Wires`` objects, JSON-serializable - data or objects that can be handled by ``json_default`` if provided. + Metadata may contain ``pennylane.Shots`` and ``pennylane.Wires`` objects, + as well as any JSON-serializable data. >>> from pennylane.pytrees import PyTreeStructure, leaf, flatten >>> from pennylane.pytrees.serialization import pytree_structure_dump @@ -55,21 +51,14 @@ def pytree_structure_dump( given indent level. Otherwise, the output will use the most compact possible representation decode: If True, return a string instead of bytes - json_default: Handler for objects that can't otherwise be serialized. 
Should - return a JSON-compatible value or raise a ``TypeError`` if the value - can't be handled Returns: bytes: If ``encode`` is True str: If ``encode`` is False """ dump_args = {"indent": indent} if indent else {"separators": (",", ":")} - if json_default: - dump_args["default"] = _wrap_user_json_default(json_default) - else: - dump_args["default"] = _json_default - data = json.dumps(root, **dump_args) + data = json.dumps(root, default=_json_default, **dump_args) if not decode: return data.encode("utf-8") @@ -107,34 +96,36 @@ def pytree_structure_load(data: str | bytes | bytearray) -> PyTreeStructure: return root -def _json_default(obj: Any) -> JSON: - """Default function for ``json.dump()``. Adds handling for the following types: - - ``pennylane.pytrees.PyTreeStructure`` - - ``pennylane.wires.Wires`` - - ``pennylane.measurements.shots.Shots`` - """ - if isinstance(obj, PyTreeStructure): - if obj.is_leaf: - return None - return [get_typename(obj.type_), obj.metadata, obj.children] +def _pytree_structure_to_json(obj: PyTreeStructure) -> JSON: + """JSON handler for serializating ``PyTreeStructure``.""" + if obj.is_leaf: + return None + + return [get_typename(obj.type_), obj.metadata, obj.children] - if isinstance(obj, Wires): - return obj.tolist() - if isinstance(obj, Shots): - return obj.shot_vector +def _wires_to_json(obj: Wires) -> JSON: + """JSON handler for serializing ``Wires``.""" + return obj.tolist() - raise TypeError(obj) +def _shots_to_json(obj: Shots) -> JSON: + """JSON handler for serializing ``Shots``.""" + return obj.shot_vector -def _wrap_user_json_default(user_default: Callable[[Any], JSON]) -> Callable[[Any], JSON]: - """Wraps a user-provided JSON default function. 
If ``user_default`` raises a TypeError, - calls ``_json_default``.""" - def _default_wrapped(obj: Any) -> JSON: - try: - return user_default(obj) - except TypeError: - return _json_default(obj) +_json_handlers: dict[type, Callable[[Any], JSON]] = { + PyTreeStructure: _pytree_structure_to_json, + Wires: _wires_to_json, + Shots: _shots_to_json, +} - return _default_wrapped + +def _json_default(obj: Any) -> JSON: + """Default function for ``json.dump()``. Calls the handler for the type of ``obj`` + in ``_json_handlers``. Raises ``TypeError`` if ``obj`` cannot be handled. + """ + try: + return _json_handlers[type(obj)](obj) + except KeyError as exc: + raise TypeError(obj) from exc From f5d209d071e08e35df03750b5ae0b12835fa7472 Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Fri, 24 May 2024 11:55:30 -0400 Subject: [PATCH 16/28] tests --- pennylane/data/attributes/pytree.py | 21 ++++++-- pennylane/pytrees/__init__.py | 11 +---- pennylane/pytrees/pytrees.py | 15 ------ pennylane/pytrees/serialization.py | 17 +++++++ tests/pytrees/test_pytrees.py | 2 +- tests/pytrees/test_serialization.py | 77 +++++++++++++++++++---------- 6 files changed, 87 insertions(+), 56 deletions(-) diff --git a/pennylane/data/attributes/pytree.py b/pennylane/data/attributes/pytree.py index 96df3e485fa..3eeb027d506 100644 --- a/pennylane/data/attributes/pytree.py +++ b/pennylane/data/attributes/pytree.py @@ -1,12 +1,27 @@ -from functools import lru_cache -from typing import Any, TypeVar +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains DatasetAttribute definition for PyTree types.""" + + +from typing import TypeVar import numpy as np from pennylane.data.base.attribute import DatasetAttribute from pennylane.data.base.hdf5 import HDF5Group from pennylane.data.base.mapper import AttributeTypeMapper -from pennylane.pytrees import flatten, list_pytree_types, serialization, unflatten +from pennylane.pytrees import flatten, serialization, unflatten T = TypeVar("T") diff --git a/pennylane/pytrees/__init__.py b/pennylane/pytrees/__init__.py index 35ec3059f65..76523bceffd 100644 --- a/pennylane/pytrees/__init__.py +++ b/pennylane/pytrees/__init__.py @@ -1,19 +1,10 @@ -from .pytrees import ( - PyTreeStructure, - flatten, - is_pytree, - leaf, - register_pytree, - unflatten, - list_pytree_types, -) +from .pytrees import PyTreeStructure, flatten, is_pytree, leaf, register_pytree, unflatten __all__ = [ "PyTreeStructure", "flatten", "is_pytree", "leaf", - "list_pytree_types", "register_pytree", "unflatten", ] diff --git a/pennylane/pytrees/pytrees.py b/pennylane/pytrees/pytrees.py index 683f224cf86..8cab911fe65 100644 --- a/pennylane/pytrees/pytrees.py +++ b/pennylane/pytrees/pytrees.py @@ -176,21 +176,6 @@ def get_typename_type(typename: str) -> type[Any]: raise ValueError(f"{repr(typename)} is not the name of a Pytree type.") from exc -def list_pytree_types(namespace: Optional[str] = "qml") -> Iterator[type]: - """Return an iterator listing all registered Pytree types under - the given ``namespace``. - """ - if namespace: - namespace_filter = f"{namespace}." - return ( - type_ - for type_, typename in type_to_typename.items() - if typename.startswith(namespace_filter) - ) - - return (type_ for type_ in type_to_typename) - - @dataclass(repr=False) class PyTreeStructure: """A pytree data structure, holding the type, metadata, and child pytree structures. 
diff --git a/pennylane/pytrees/serialization.py b/pennylane/pytrees/serialization.py index 0b19f078e8a..52003a0b42e 100644 --- a/pennylane/pytrees/serialization.py +++ b/pennylane/pytrees/serialization.py @@ -1,3 +1,20 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +An internal module for serializing and deserializing Pennylane pytrees. +""" + import json from collections.abc import Callable from typing import Any, Literal, Optional, Union, overload diff --git a/tests/pytrees/test_pytrees.py b/tests/pytrees/test_pytrees.py index c2829c4385f..66233981d05 100644 --- a/tests/pytrees/test_pytrees.py +++ b/tests/pytrees/test_pytrees.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Tests for the pennylane pytrees module +Tests for the pennylane pytrees module. """ import pennylane as qml from pennylane.pytrees import PyTreeStructure, flatten, leaf, register_pytree, unflatten diff --git a/tests/pytrees/test_serialization.py b/tests/pytrees/test_serialization.py index 271cd39de55..f729413c62d 100644 --- a/tests/pytrees/test_serialization.py +++ b/tests/pytrees/test_serialization.py @@ -1,19 +1,29 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for the pytrees serialization module. +""" + import json +from typing import Any +import numpy as np import pytest import pennylane as qml -from pennylane.measurements.shots import Shots from pennylane.ops import PauliX, Prod, Sum -from pennylane.pytrees import ( - PyTreeStructure, - flatten, - is_pytree, - leaf, - list_pytree_types, - register_pytree, - unflatten, -) +from pennylane.pytrees import PyTreeStructure, flatten, is_pytree, leaf, register_pytree, unflatten from pennylane.pytrees.pytrees import ( flatten_registrations, type_to_typename, @@ -54,11 +64,6 @@ def register_test_node(): del type_to_typename[CustomNode] -def test_list_pytree_types(): - """Test for ``list_pytree_types()``.""" - assert list(list_pytree_types("test")) == [CustomNode] - - @pytest.mark.parametrize( "cls, result", [ @@ -69,6 +74,7 @@ def test_list_pytree_types(): (Prod, True), (PauliX, True), (int, False), + (set, False), ], ) def test_is_pytree(cls, result): @@ -140,17 +146,34 @@ def test_structure_load(): ) -def test_nested_pl_object_roundtrip(): - tape_in = qml.tape.QuantumScript( - [qml.adjoint(qml.RX(0.1, wires=0))], - [qml.expval(2 * qml.X(0))], - shots=50, - trainable_params=(0, 1), - ) +H_ONE_QUBIT = np.array([[1.0, 0.5j], [-0.5j, 2.5]]) +H_TWO_QUBITS = np.array( + [[0.5, 1.0j, 0.0, -3j], [-1.0j, -1.1, 0.0, -0.1], [0.0, 0.0, -0.9, 12.0], [3j, -0.1, 12.0, 0.0]] +) - data, struct = flatten(tape_in) - tape_out = unflatten(data, pytree_structure_load(pytree_structure_dump(struct))) - assert type(tape_out) == type(tape_in) - assert repr(tape_out) == repr(tape_in) 
- assert list(tape_out) == list(tape_in) +@pytest.mark.parametrize( + "obj_in", + [ + qml.tape.QuantumScript( + [qml.adjoint(qml.RX(0.1, wires=0))], + [qml.expval(2 * qml.X(0))], + shots=50, + trainable_params=[0, 1], + ), + Prod(qml.X(0), qml.RX(0.1, wires=0), qml.X(1), id="id"), + Sum( + qml.Hermitian(H_ONE_QUBIT, 2), + qml.Hermitian(H_TWO_QUBITS, [0, 1]), + qml.PauliX(1), + qml.Identity("a"), + ), + ], +) +def test_pennylane_pytree_roundtrip(obj_in: Any): + """Test that Pennylane Pytree objects are requal to themselves after + a serialization roundtrip.""" + data, struct = flatten(obj_in) + obj_out = unflatten(data, pytree_structure_load(pytree_structure_dump(struct))) + + assert qml.equal(obj_in, obj_out) From 5d63558104df4d79189424912c17be50d7803ebb Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Fri, 24 May 2024 12:03:11 -0400 Subject: [PATCH 17/28] codefactor --- pennylane/data/attributes/operator/operator.py | 1 + pennylane/pytrees/__init__.py | 17 +++++++++++++++++ pennylane/pytrees/pytrees.py | 2 +- pennylane/pytrees/serialization.py | 3 +-- 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/pennylane/data/attributes/operator/operator.py b/pennylane/data/attributes/operator/operator.py index 76fca0681fa..a30b19a49bc 100644 --- a/pennylane/data/attributes/operator/operator.py +++ b/pennylane/data/attributes/operator/operator.py @@ -51,6 +51,7 @@ class DatasetOperator(Generic[Op], DatasetAttribute[HDF5Group, Op, Op]): @classmethod @lru_cache(1) def supported_ops(cls) -> FrozenSet[Type[Operator]]: + """Set of supported operators.""" return frozenset( ( # pennylane/operation/Tensor diff --git a/pennylane/pytrees/__init__.py b/pennylane/pytrees/__init__.py index 76523bceffd..6180de919f5 100644 --- a/pennylane/pytrees/__init__.py +++ b/pennylane/pytrees/__init__.py @@ -1,3 +1,20 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. 
+ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +An internal module for working with pytrees. +""" + from .pytrees import PyTreeStructure, flatten, is_pytree, leaf, register_pytree, unflatten __all__ = [ diff --git a/pennylane/pytrees/pytrees.py b/pennylane/pytrees/pytrees.py index 8cab911fe65..f023c6c4877 100644 --- a/pennylane/pytrees/pytrees.py +++ b/pennylane/pytrees/pytrees.py @@ -14,7 +14,7 @@ """ An internal module for working with pytrees. """ -from collections.abc import Callable, Iterator +from collections.abc import Callable from dataclasses import dataclass, field from typing import Any, Optional diff --git a/pennylane/pytrees/serialization.py b/pennylane/pytrees/serialization.py index 52003a0b42e..ead17fb7183 100644 --- a/pennylane/pytrees/serialization.py +++ b/pennylane/pytrees/serialization.py @@ -99,8 +99,7 @@ def pytree_structure_load(data: str | bytes | bytearray) -> PyTreeStructure: while todo: curr = todo.pop() - for i in range(len(curr)): - child = curr[i] + for i, child in enumerate(curr): if child is None: curr[i] = leaf continue From c4997068260618b292ebfde84159a9ac52517558 Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Fri, 24 May 2024 12:04:51 -0400 Subject: [PATCH 18/28] don't use | for union --- pennylane/pytrees/serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/pytrees/serialization.py b/pennylane/pytrees/serialization.py index ead17fb7183..d189b3f1346 100644 --- 
a/pennylane/pytrees/serialization.py +++ b/pennylane/pytrees/serialization.py @@ -83,7 +83,7 @@ def pytree_structure_dump( return data -def pytree_structure_load(data: str | bytes | bytearray) -> PyTreeStructure: +def pytree_structure_load(data: Union[str, bytes, bytearray]) -> PyTreeStructure: """Load a previously serialized Pytree structure. >>> from pennylane.pytrees.serialization import pytree_structure_dump From 215f39c359b703e1ad7533820ce64721b2b2e199 Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Fri, 24 May 2024 12:15:15 -0400 Subject: [PATCH 19/28] remove dupe test file --- .../data/attributes/operator/operator.py | 4 +- tests/pytrees/test_serialization.py | 6 + tests/test_pytrees.py | 144 ------------------ 3 files changed, 8 insertions(+), 146 deletions(-) delete mode 100644 tests/test_pytrees.py diff --git a/pennylane/data/attributes/operator/operator.py b/pennylane/data/attributes/operator/operator.py index a30b19a49bc..28604b2cfc1 100644 --- a/pennylane/data/attributes/operator/operator.py +++ b/pennylane/data/attributes/operator/operator.py @@ -48,6 +48,8 @@ class DatasetOperator(Generic[Op], DatasetAttribute[HDF5Group, Op, Op]): ``Hamiltonian`` and ``Tensor`` operators. 
""" + type_id = "operator" + @classmethod @lru_cache(1) def supported_ops(cls) -> FrozenSet[Type[Operator]]: @@ -189,8 +191,6 @@ def supported_ops(cls) -> FrozenSet[Type[Operator]]: ) ) - type_id = "operator" - def value_to_hdf5(self, bind_parent: HDF5Group, key: str, value: Op) -> HDF5Group: return self._ops_to_hdf5(bind_parent, key, [value]) diff --git a/tests/pytrees/test_serialization.py b/tests/pytrees/test_serialization.py index f729413c62d..3df33973579 100644 --- a/tests/pytrees/test_serialization.py +++ b/tests/pytrees/test_serialization.py @@ -168,6 +168,12 @@ def test_structure_load(): qml.PauliX(1), qml.Identity("a"), ), + qml.Hamiltonian( + (1.1, -0.4, 0.333), (qml.PauliX(0), qml.Hermitian(H_ONE_QUBIT, 2), qml.PauliZ(2)) + ), + qml.Hamiltonian( + np.array([-0.1, 0.5]), [qml.Hermitian(H_TWO_QUBITS, [0, 1]), qml.PauliY(0)] + ), ], ) def test_pennylane_pytree_roundtrip(obj_in: Any): diff --git a/tests/test_pytrees.py b/tests/test_pytrees.py deleted file mode 100644 index c2829c4385f..00000000000 --- a/tests/test_pytrees.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2018-2024 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Tests for the pennylane pytrees module -""" -import pennylane as qml -from pennylane.pytrees import PyTreeStructure, flatten, leaf, register_pytree, unflatten - - -def test_structure_repr_str(): - """Test the repr of the structure class.""" - op = qml.RX(0.1, wires=0) - _, structure = qml.pytrees.flatten(op) - expected = "PyTreeStructure(RX, (, ()), [PyTreeStructure()])" - assert repr(structure) == expected - expected_str = "PyTree(RX, (, ()), [Leaf])" - assert str(structure) == expected_str - - -def test_register_new_class(): - """Test that new objects can be registered, flattened, and unflattened.""" - - # pylint: disable=too-few-public-methods - class MyObj: - """a dummy object.""" - - def __init__(self, a): - self.a = a - - def obj_flatten(obj): - return (obj.a,), None - - def obj_unflatten(data, _): - return MyObj(data[0]) - - register_pytree(MyObj, obj_flatten, obj_unflatten) - - obj = MyObj(0.5) - - data, structure = flatten(obj) - assert data == [0.5] - assert structure == PyTreeStructure(MyObj, None, [leaf]) - - new_obj = unflatten([1.0], structure) - assert isinstance(new_obj, MyObj) - assert new_obj.a == 1.0 - - -def test_list(): - """Test that pennylane treats list as a pytree.""" - - x = [1, 2, [3, 4]] - - data, structure = flatten(x) - assert data == [1, 2, 3, 4] - assert structure == PyTreeStructure( - list, None, [leaf, leaf, PyTreeStructure(list, None, [leaf, leaf])] - ) - - new_x = unflatten([5, 6, 7, 8], structure) - assert new_x == [5, 6, [7, 8]] - - -def test_tuple(): - """Test that pennylane can handle tuples as pytrees.""" - x = (1, 2, (3, 4)) - - data, structure = flatten(x) - assert data == [1, 2, 3, 4] - assert structure == PyTreeStructure( - tuple, None, [leaf, leaf, PyTreeStructure(tuple, None, [leaf, leaf])] - ) - - new_x = unflatten([5, 6, 7, 8], structure) - assert new_x == (5, 6, (7, 8)) - - -def test_dict(): - """Test that pennylane can handle dictionaries as pytees.""" - - x = {"a": 1, "b": {"c": 2, "d": 3}} - - data, 
structure = flatten(x) - assert data == [1, 2, 3] - assert structure == PyTreeStructure( - dict, ("a", "b"), [leaf, PyTreeStructure(dict, ("c", "d"), [leaf, leaf])] - ) - new_x = unflatten([5, 6, 7], structure) - assert new_x == {"a": 5, "b": {"c": 6, "d": 7}} - - -def test_nested_pl_object(): - """Test that we can flatten and unflatten nested pennylane object.""" - - tape = qml.tape.QuantumScript( - [qml.adjoint(qml.RX(0.1, wires=0))], - [qml.expval(2 * qml.X(0))], - shots=50, - trainable_params=(0, 1), - ) - - data, structure = flatten(tape) - assert data == [0.1, 2, None] - - wires0 = qml.wires.Wires(0) - op_structure = PyTreeStructure( - tape[0].__class__, (), [PyTreeStructure(qml.RX, (wires0, ()), [leaf])] - ) - list_op_struct = PyTreeStructure(list, None, [op_structure]) - - sprod_structure = PyTreeStructure( - qml.ops.SProd, (), [leaf, PyTreeStructure(qml.X, (wires0, ()), [])] - ) - meas_structure = PyTreeStructure( - qml.measurements.ExpectationMP, (("wires", None),), [sprod_structure, leaf] - ) - list_meas_struct = PyTreeStructure(list, None, [meas_structure]) - tape_structure = PyTreeStructure( - qml.tape.QuantumScript, - (tape.shots, tape.trainable_params), - [list_op_struct, list_meas_struct], - ) - - assert structure == tape_structure - - new_tape = unflatten([3, 4, None], structure) - expected_new_tape = qml.tape.QuantumScript( - [qml.adjoint(qml.RX(3, wires=0))], - [qml.expval(4 * qml.X(0))], - shots=50, - trainable_params=(0, 1), - ) - assert qml.equal(new_tape, expected_new_tape) From 7f7d5823dfd59711cca0c0eb76111ad5b7948f95 Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Fri, 24 May 2024 12:32:03 -0400 Subject: [PATCH 20/28] pylint --- pennylane/pytrees/serialization.py | 3 +-- tests/data/attributes/test_pytree.py | 19 ++++++++++++++++++- tests/pytrees/test_serialization.py | 2 ++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/pennylane/pytrees/serialization.py b/pennylane/pytrees/serialization.py index d189b3f1346..72c835c4470 
100644 --- a/pennylane/pytrees/serialization.py +++ b/pennylane/pytrees/serialization.py @@ -94,8 +94,8 @@ def pytree_structure_load(data: Union[str, bytes, bytearray]) -> PyTreeStructure jsoned = json.loads(data) root = PyTreeStructure(get_typename_type(jsoned[0]), jsoned[1], jsoned[2]) + # List of serialized child structures that will be de-serialized in place todo: list[list[Any]] = [root.children] - while todo: curr = todo.pop() @@ -106,7 +106,6 @@ def pytree_structure_load(data: Union[str, bytes, bytearray]) -> PyTreeStructure curr[i] = PyTreeStructure(get_typename_type(child[0]), child[1], child[2]) - # Child structures will be converted in place todo.append(child[2]) return root diff --git a/tests/data/attributes/test_pytree.py b/tests/data/attributes/test_pytree.py index 7276ad66643..33d81f07bb1 100644 --- a/tests/data/attributes/test_pytree.py +++ b/tests/data/attributes/test_pytree.py @@ -1,3 +1,20 @@ +# Copyright 2018-2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for the ``DatasetPyTree`` attribute type. 
+""" + from dataclasses import dataclass import pytest @@ -47,7 +64,7 @@ class TestDatasetPyTree: """Tests for ``DatasetPyTree``.""" def test_consumes_type(self): - """Test that PyTree-compatible types that is not a builtin are + """Test that PyTree-compatible types that are not builtin are consumed by ``DatasetPyTree``.""" dset = Dataset() dset.attr = CustomNode([1, 2, 3, 4], {"meta": "data"}) diff --git a/tests/pytrees/test_serialization.py b/tests/pytrees/test_serialization.py index 3df33973579..1c22d39e57b 100644 --- a/tests/pytrees/test_serialization.py +++ b/tests/pytrees/test_serialization.py @@ -37,6 +37,8 @@ class CustomNode: """Example Pytree for testing.""" + # pylint: disable=too-few-public-methods + def __init__(self, data, metadata): self.data = data self.metadata = metadata From 86bddda5cd052b8c15fa950eb687c6cb6b8aab77 Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Fri, 24 May 2024 12:35:21 -0400 Subject: [PATCH 21/28] add to contrib --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 576315e168e..118405cbe24 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -201,6 +201,7 @@ This release contains contributions from (in alphabetical order): Lillian M. A. 
Frederiksen, Gabriel Bottrill, +Jack Brown, Astral Cai, Ahmed Darwish, Isaac De Vlugt, From 1c1c8e553df4b38bbd80849ae689a99f8d654a67 Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Fri, 24 May 2024 13:26:00 -0400 Subject: [PATCH 22/28] coverage --- doc/releases/changelog-dev.md | 2 +- pennylane/pytrees/serialization.py | 2 +- tests/pytrees/test_pytrees.py | 40 +++++++++++++++++++++++++++++ tests/pytrees/test_serialization.py | 10 +++++++- 4 files changed, 51 insertions(+), 3 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 118405cbe24..cf30dba0058 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -101,7 +101,7 @@ * Empty initialization of `PauliVSpace` is permitted. [(#5675)](https://github.com/PennyLaneAI/pennylane/pull/5675) - + * `QuantumScript` properties are only calculated when needed, instead of on initialization. This decreases the classical overhead by >20%. `par_info`, `obs_sharing_wires`, and `obs_sharing_wires_id` are now public attributes. [(#5696)](https://github.com/PennyLaneAI/pennylane/pull/5696) diff --git a/pennylane/pytrees/serialization.py b/pennylane/pytrees/serialization.py index 72c835c4470..1d72b445a0a 100644 --- a/pennylane/pytrees/serialization.py +++ b/pennylane/pytrees/serialization.py @@ -143,4 +143,4 @@ def _json_default(obj: Any) -> JSON: try: return _json_handlers[type(obj)](obj) except KeyError as exc: - raise TypeError(obj) from exc + raise TypeError(f"Could not serialize metadata object: {repr(obj)}") from exc diff --git a/tests/pytrees/test_pytrees.py b/tests/pytrees/test_pytrees.py index 66233981d05..de1327ac596 100644 --- a/tests/pytrees/test_pytrees.py +++ b/tests/pytrees/test_pytrees.py @@ -14,8 +14,13 @@ """ Tests for the pennylane pytrees module. 
""" +import re + +import pytest + import pennylane as qml from pennylane.pytrees import PyTreeStructure, flatten, leaf, register_pytree, unflatten +from pennylane.pytrees.pytrees import get_typename, get_typename_type def test_structure_repr_str(): @@ -142,3 +147,38 @@ def test_nested_pl_object(): trainable_params=(0, 1), ) assert qml.equal(new_tape, expected_new_tape) + + +@pytest.mark.parametrize( + "type_,typename", [(list, "builtins.list"), (qml.Hadamard, "qml.Hadamard")] +) +def test_get_typename(type_, typename): + """Test for ``get_typename()``.""" + + assert get_typename(type_) == typename + + +def test_get_typename_invalid(): + """Tests that a ``TypeError`` is raised when passing an non-pytree + type to ``get_typename()``.""" + + with pytest.raises(TypeError, match=" is not a Pytree type"): + get_typename(int) + + +@pytest.mark.parametrize( + "type_,typename", [(list, "builtins.list"), (qml.Hadamard, "qml.Hadamard")] +) +def test_get_typename_type(type_, typename): + """Tests for ``get_typename_type()``.""" + assert get_typename_type(typename) is type_ + + +def test_get_typename_type_invalid(): + """Tests that a ``ValueError`` is raised when passing an invalid + typename to ``get_typename_type()``.""" + + with pytest.raises( + ValueError, match=re.escape("'not.a.typename' is not the name of a Pytree type.") + ): + get_typename_type("not.a.typename") diff --git a/tests/pytrees/test_serialization.py b/tests/pytrees/test_serialization.py index 1c22d39e57b..284c44146be 100644 --- a/tests/pytrees/test_serialization.py +++ b/tests/pytrees/test_serialization.py @@ -115,7 +115,15 @@ def test_pytree_structure_dump(decode): ] -def test_structure_load(): +def test_pytree_structure_dump_unserializable_metadata(): + """Test that a ``TypeError`` is raised if a Pytree has unserializable metadata.""" + _, struct = flatten(CustomNode([1, 2, 4], {"operator": qml.PauliX(0)})) + + with pytest.raises(TypeError, match=r"Could not serialize metadata object: X\(0\)"): + 
pytree_structure_dump(struct) + + +def test_pytree_structure_load(): """Test that ``pytree_structure_load()`` can parse a JSON-serialized PyTree.""" jsoned = json.dumps( [ From a5007efd40863a625ebd2cf2653bc1925de42e08 Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Fri, 24 May 2024 14:20:23 -0400 Subject: [PATCH 23/28] use array for homogenous leaves --- pennylane/data/attributes/pytree.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/pennylane/data/attributes/pytree.py b/pennylane/data/attributes/pytree.py index 3eeb027d506..2a7a591f2da 100644 --- a/pennylane/data/attributes/pytree.py +++ b/pennylane/data/attributes/pytree.py @@ -18,6 +18,7 @@ import numpy as np +from pennylane.data.attributes import DatasetArray, DatasetList from pennylane.data.base.attribute import DatasetAttribute from pennylane.data.base.hdf5 import HDF5Group from pennylane.data.base.mapper import AttributeTypeMapper @@ -35,22 +36,20 @@ class DatasetPyTree(DatasetAttribute[HDF5Group, T, T]): type_id = "pytree" def hdf5_to_value(self, bind: HDF5Group) -> T: - mapper = AttributeTypeMapper(bind) - return unflatten( - [mapper[str(i)].get_value() for i in range(bind["num_leaves"][()])], + AttributeTypeMapper(bind)["leaves"].get_value(), serialization.pytree_structure_load(bind["treedef"][()].tobytes()), ) def value_to_hdf5(self, bind_parent: HDF5Group, key: str, value: T) -> HDF5Group: bind = bind_parent.create_group(key) - mapper = AttributeTypeMapper(bind) - leaves, treedef = flatten(value) bind["treedef"] = np.void(serialization.pytree_structure_dump(treedef, decode=False)) - bind["num_leaves"] = len(leaves) - for i, leaf in enumerate(leaves): - mapper[str(i)] = leaf + + try: + DatasetArray(leaves, parent_and_key=(bind, "leaves")) + except (ValueError, TypeError): + DatasetList(leaves, parent_and_key=(bind, "leaves")) return bind From f4d8fb54690d92eb3bb25b32dd43cab0a99bf579 Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Fri, 24 May 2024 14:42:34 -0400 
Subject: [PATCH 24/28] comment --- pennylane/data/attributes/pytree.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pennylane/data/attributes/pytree.py b/pennylane/data/attributes/pytree.py index 2a7a591f2da..048ba31e71b 100644 --- a/pennylane/data/attributes/pytree.py +++ b/pennylane/data/attributes/pytree.py @@ -48,6 +48,8 @@ def value_to_hdf5(self, bind_parent: HDF5Group, key: str, value: T) -> HDF5Group bind["treedef"] = np.void(serialization.pytree_structure_dump(treedef, decode=False)) try: + # Attempt to store leaves as an array, which will be more efficient + # but will fail if the leaves are not homogenous DatasetArray(leaves, parent_and_key=(bind, "leaves")) except (ValueError, TypeError): DatasetList(leaves, parent_and_key=(bind, "leaves")) From bc6ded63b9658c6ea5b39d1e3b42ad20aaf22b28 Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Tue, 28 May 2024 10:12:40 -0400 Subject: [PATCH 25/28] suggestions from code review --- tests/data/attributes/test_pytree.py | 6 ++++-- tests/pytrees/test_serialization.py | 7 +++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/data/attributes/test_pytree.py b/tests/data/attributes/test_pytree.py index 33d81f07bb1..680b4f3f08e 100644 --- a/tests/data/attributes/test_pytree.py +++ b/tests/data/attributes/test_pytree.py @@ -20,8 +20,8 @@ import pytest from pennylane.data import Dataset, DatasetPyTree -from pennylane.pytrees import register_pytree from pennylane.pytrees.pytrees import ( + _register_pytree_with_pennylane, flatten_registrations, type_to_typename, typename_to_type, @@ -50,7 +50,9 @@ def unflatten_custom(data, metadata): def register_test_node(): """Fixture that temporarily registers the ``CustomNode`` class as a Pytree.""" - register_pytree(CustomNode, flatten_custom, unflatten_custom) + # Use this instead of ``register_pytree()`` so that ``CustomNode`` will not + # be registered with jax. 
+ _register_pytree_with_pennylane(CustomNode, "test.CustomNode", flatten_custom, unflatten_custom) yield diff --git a/tests/pytrees/test_serialization.py b/tests/pytrees/test_serialization.py index 284c44146be..e5829627aea 100644 --- a/tests/pytrees/test_serialization.py +++ b/tests/pytrees/test_serialization.py @@ -23,8 +23,9 @@ import pennylane as qml from pennylane.ops import PauliX, Prod, Sum -from pennylane.pytrees import PyTreeStructure, flatten, is_pytree, leaf, register_pytree, unflatten +from pennylane.pytrees import PyTreeStructure, flatten, is_pytree, leaf, unflatten from pennylane.pytrees.pytrees import ( + _register_pytree_with_pennylane, flatten_registrations, type_to_typename, typename_to_type, @@ -56,7 +57,9 @@ def unflatten_custom(data, metadata): def register_test_node(): """Fixture that temporarily registers the ``CustomNode`` class as a Pytree.""" - register_pytree(CustomNode, flatten_custom, unflatten_custom, namespace="test") + # Use this instead of ``register_pytree()`` so that ``CustomNode`` will not + # be registered with jax. + _register_pytree_with_pennylane(CustomNode, "test.CustomNode", flatten_custom, unflatten_custom) yield From 504e853d7b667cff0f418785361fbd9bbde9c9a7 Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Tue, 28 May 2024 10:16:01 -0400 Subject: [PATCH 26/28] Apply suggestions from code review Co-authored-by: Mudit Pandey --- pennylane/data/attributes/pytree.py | 2 +- pennylane/pytrees/serialization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pennylane/data/attributes/pytree.py b/pennylane/data/attributes/pytree.py index 048ba31e71b..d790210b081 100644 --- a/pennylane/data/attributes/pytree.py +++ b/pennylane/data/attributes/pytree.py @@ -30,7 +30,7 @@ class DatasetPyTree(DatasetAttribute[HDF5Group, T, T]): """Attribute type for an object that can be converted to a Pytree. This is the default serialization method for - all Pennylane Pytrees, including sublcasses of ``Operator``. 
+ all Pennylane Pytrees, including subclasses of ``Operator``. """ type_id = "pytree" diff --git a/pennylane/pytrees/serialization.py b/pennylane/pytrees/serialization.py index 1d72b445a0a..49cfd5bfdc0 100644 --- a/pennylane/pytrees/serialization.py +++ b/pennylane/pytrees/serialization.py @@ -60,7 +60,7 @@ def pytree_structure_dump( 'PyTreeStructure(, None, [PyTreeStructure(, ("a",), [PyTreeStructure()]), PyTreeStructure()])' >>> pytree_structure_dump(struct) - b'["builtins.list",null,[["builtins.dict",["a"],[null]],null]' + b'["builtins.list",null,[["builtins.dict",["a"],[null]],null]]' Args: root: Root of a Pytree structure From 11aa98301e73aad68e3ca1e46e76b721e3ab785b Mon Sep 17 00:00:00 2001 From: Jack Brown Date: Tue, 28 May 2024 10:18:57 -0400 Subject: [PATCH 27/28] docstrings --- pennylane/pytrees/serialization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pennylane/pytrees/serialization.py b/pennylane/pytrees/serialization.py index 49cfd5bfdc0..1c53c092b4d 100644 --- a/pennylane/pytrees/serialization.py +++ b/pennylane/pytrees/serialization.py @@ -57,7 +57,7 @@ def pytree_structure_dump( >>> _, struct = flatten([{"a": 1}, 2]) >>> struct - 'PyTreeStructure(, None, [PyTreeStructure(, ("a",), [PyTreeStructure()]), PyTreeStructure()])' + PyTreeStructure(list, None, [dict, ("a",), [PyTreeStructure()]), PyTreeStructure()])' >>> pytree_structure_dump(struct) b'["builtins.list",null,[["builtins.dict",["a"],[null]],null]]' @@ -89,7 +89,7 @@ def pytree_structure_load(data: Union[str, bytes, bytearray]) -> PyTreeStructure >>> from pennylane.pytrees.serialization import pytree_structure_dump >>> pytree_structure_load('["builtins.list",null,[["builtins.dict",["a"],[null]],null]') - 'PyTreeStructure(, None, [PyTreeStructure(, ["a"], [PyTreeStructure()]), PyTreeStructure()])' + PyTreeStructure(list, None, [PyTreeStructure(dict, ["a"], [PyTreeStructure()]), PyTreeStructure()])' """ jsoned = json.loads(data) root = 
PyTreeStructure(get_typename_type(jsoned[0]), jsoned[1], jsoned[2]) From 13331287781c42f860f24be9c4cd4edd136f4184 Mon Sep 17 00:00:00 2001 From: Diego <67476785+DSGuala@users.noreply.github.com> Date: Wed, 19 Jun 2024 15:11:24 -0400 Subject: [PATCH 28/28] Apply suggestions from code review Co-authored-by: Utkarsh --- doc/releases/changelog-dev.md | 2 +- pennylane/data/attributes/pytree.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index dc985e79f98..dac6f0af5e5 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -195,7 +195,7 @@ `par_info`, `obs_sharing_wires`, and `obs_sharing_wires_id` are now public attributes. [(#5696)](https://github.com/PennyLaneAI/pennylane/pull/5696) -* The `qml.data` module now supports PyTree types as dataset attributes +* The `qml.data` module now supports PyTree data types as dataset attributes [(#5732)](https://github.com/PennyLaneAI/pennylane/pull/5732) diff --git a/pennylane/data/attributes/pytree.py b/pennylane/data/attributes/pytree.py index d790210b081..712a2a19452 100644 --- a/pennylane/data/attributes/pytree.py +++ b/pennylane/data/attributes/pytree.py @@ -30,7 +30,7 @@ class DatasetPyTree(DatasetAttribute[HDF5Group, T, T]): """Attribute type for an object that can be converted to a Pytree. This is the default serialization method for - all Pennylane Pytrees, including subclasses of ``Operator``. + all PennyLane Pytrees, including subclasses of ``Operator``. """ type_id = "pytree"