Skip to content

Commit

Permalink
Add native tools for flattening and unflattening pytrees (#5701)
Browse files Browse the repository at this point in the history
**Context:**

Pytrees are nested data structures. Pytree tools can make it easier to
handle said data structure.

Jax, optree, and other packages provide tools for handling pytrees
already. To take advantage of those tools in core pennylane, we would
need to make one of those packages a dependency. Instead of adding an
extra dependency, we can have our own stripped-down version of pytree
tools.

**Description of the Change:**

This PR adds `tree_flatten`, `tree_unflatten`, `Structure`, and `Leaf`
to the pytrees module. It also updates `qml.pytrees.register_pytree` to
register the type with pennylane's pytree setup as well.

```pycon
>>> op = qml.adjoint(qml.Rot(1.2, 2.3, 3.4, wires=0))
>>> data, structure = flatten(op)
>>> data
[1.2, 2.3, 3.4]
>>> structure
<Tree(AdjointOperation, (), (<Tree(Rot, (<Wires = [0]>, ()), (Leaf, Leaf, Leaf))>,))>
>>> unflatten([-2, -3, -4], structure)
Adjoint(Rot(-2, -3, -4, wires=[0]))
```

**Benefits:**

Use in datasets module to serialize and de-serialize pennylane objects.

Potential future use in `bind_new_parameters`.

Easy extraction and reset of parameters.

**Possible Drawbacks:**

Not as performant as the C++ bound versions used by jax and optree.

**Related GitHub Issues:**

[sc-46349]

---------

Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
Co-authored-by: Jack Brown <jack@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
  • Loading branch information
4 people authored May 24, 2024
1 parent 0900c84 commit 960204f
Show file tree
Hide file tree
Showing 3 changed files with 311 additions and 4 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@
allowing error types to be more consistent with the context the `decompose` function is used in.
[(#5669)](https://github.com/PennyLaneAI/pennylane/pull/5669)

* The `qml.pytrees` module now has `flatten` and `unflatten` methods for serializing pytrees.
[(#5701)](https://github.com/PennyLaneAI/pennylane/pull/5701)

* Empty initialization of `PauliVSpace` is permitted.
[(#5675)](https://github.com/PennyLaneAI/pennylane/pull/5675)

Expand Down
168 changes: 164 additions & 4 deletions pennylane/pytrees.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2018-2023 Xanadu Quantum Technologies Inc.
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,8 +14,8 @@
"""
An internal module for working with pytrees.
"""

from typing import Any, Callable, Tuple
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Tuple

has_jax = True
try:
Expand All @@ -30,6 +30,57 @@
UnflattenFn = Callable[[Leaves, Metadata], Any]


def flatten_list(obj: list):
"""Flatten a list."""
return obj, None


def flatten_tuple(obj: tuple):
"""Flatten a tuple."""
return obj, None


def flatten_dict(obj: dict):
"""Flatten a dictionary."""
return obj.values(), tuple(obj.keys())


flatten_registrations: dict[type, FlattenFn] = {
list: flatten_list,
tuple: flatten_tuple,
dict: flatten_dict,
}


def unflatten_list(data, _) -> list:
"""Unflatten a list."""
return data if isinstance(data, list) else list(data)


def unflatten_tuple(data, _) -> tuple:
"""Unflatten a tuple."""
return tuple(data)


def unflatten_dict(data, metadata) -> dict:
"""Unflatten a dictinoary."""
return dict(zip(metadata, data))


unflatten_registrations: dict[type, UnflattenFn] = {
list: unflatten_list,
tuple: unflatten_tuple,
dict: unflatten_dict,
}


def _register_pytree_with_pennylane(
pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn
):
flatten_registrations[pytree_type] = flatten_fn
unflatten_registrations[pytree_type] = unflatten_fn


def _register_pytree_with_jax(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn):
def jax_unflatten(aux, parameters):
return unflatten_fn(parameters, aux)
Expand All @@ -40,7 +91,8 @@ def jax_unflatten(aux, parameters):
def register_pytree(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn):
"""Register a type with all available pytree backends.
Current backends is jax.
Current backends are jax and pennylane.
Args:
pytree_type (type): the type to register, such as ``qml.RX``
flatten_fn (Callable): a function that splits an object into trainable leaves and hashable metadata.
Expand All @@ -52,7 +104,115 @@ def register_pytree(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: Unfl
Side Effects:
``pytree`` type becomes registered with available backends.
.. seealso:: :func:`~.flatten`, :func:`~.unflatten`.
"""

_register_pytree_with_pennylane(pytree_type, flatten_fn, unflatten_fn)

if has_jax:
_register_pytree_with_jax(pytree_type, flatten_fn, unflatten_fn)


@dataclass(repr=False)
class PyTreeStructure:
"""A pytree data structure, holding the type, metadata, and child pytree structures.
>>> op = qml.adjoint(qml.RX(0.1, 0))
>>> data, structure = qml.pytrees.flatten(op)
>>> structure
PyTree(AdjointOperation, (), [PyTree(RX, (<Wires = [0]>, ()), [Leaf])])
A leaf is defined as just a ``PyTreeStructure`` with ``type=None``.
"""

type: Optional[type] = None
"""The type corresponding to the node. If ``None``, then the structure is a leaf."""

metadata: Metadata = ()
"""Any metadata needed to reproduce the original object."""

children: list["PyTreeStructure"] = field(default_factory=list)
"""The children of the pytree node. Can be either other structures or terminal leaves."""

@property
def is_leaf(self) -> bool:
"""Whether or not the structure is a leaf."""
return self.type is None

def __repr__(self):
if self.is_leaf:
return "PyTreeStructure()"
return f"PyTreeStructure({self.type.__name__}, {self.metadata}, {self.children})"

def __str__(self):
if self.is_leaf:
return "Leaf"
children_string = ", ".join(str(c) for c in self.children)
return f"PyTree({self.type.__name__}, {self.metadata}, [{children_string}])"


leaf = PyTreeStructure(None, (), [])


def flatten(obj) -> tuple[list[Any], PyTreeStructure]:
"""Flattens a pytree into leaves and a structure.
Args:
obj (Any): any object
Returns:
List[Any], Union[Structure, Leaf]: a list of leaves and a structure representing the object
>>> op = qml.adjoint(qml.Rot(1.2, 2.3, 3.4, wires=0))
>>> data, structure = flatten(op)
>>> data
[1.2, 2.3, 3.4]
>>> structure
<PyTree(AdjointOperation, (), (<PyTree(Rot, (<Wires = [0]>, ()), (Leaf, Leaf, Leaf))>,))>
See also :function:`~.unflatten`.
"""
flatten_fn = flatten_registrations.get(type(obj), None)
if flatten_fn is None:
return [obj], leaf
leaves, metadata = flatten_fn(obj)

flattened_leaves = []
child_structures = []
for l in leaves:
child_leaves, child_structure = flatten(l)
flattened_leaves += child_leaves
child_structures.append(child_structure)

structure = PyTreeStructure(type(obj), metadata, child_structures)
return flattened_leaves, structure


def unflatten(data: List[Any], structure: PyTreeStructure) -> Any:
"""Bind data to a structure to reconstruct a pytree object.
Args:
data (Iterable): iterable of numbers and numeric arrays
structure (Structure, Leaf): The pytree structure object
Returns:
A repacked pytree.
.. see-also:: :function:`~.flatten`
>>> op = qml.adjoint(qml.Rot(1.2, 2.3, 3.4, wires=0))
>>> data, structure = flatten(op)
>>> unflatten([-2, -3, -4], structure)
Adjoint(Rot(-2, -3, -4, wires=[0]))
"""
return _unflatten(iter(data), structure)


def _unflatten(new_data, structure):
if structure.is_leaf:
return next(new_data)
children = tuple(_unflatten(new_data, s) for s in structure.children)
return unflatten_registrations[structure.type](children, structure.metadata)
144 changes: 144 additions & 0 deletions tests/test_pytrees.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for the pennylane pytrees module
"""
import pennylane as qml
from pennylane.pytrees import PyTreeStructure, flatten, leaf, register_pytree, unflatten


def test_structure_repr_str():
"""Test the repr of the structure class."""
op = qml.RX(0.1, wires=0)
_, structure = qml.pytrees.flatten(op)
expected = "PyTreeStructure(RX, (<Wires = [0]>, ()), [PyTreeStructure()])"
assert repr(structure) == expected
expected_str = "PyTree(RX, (<Wires = [0]>, ()), [Leaf])"
assert str(structure) == expected_str


def test_register_new_class():
"""Test that new objects can be registered, flattened, and unflattened."""

# pylint: disable=too-few-public-methods
class MyObj:
"""a dummy object."""

def __init__(self, a):
self.a = a

def obj_flatten(obj):
return (obj.a,), None

def obj_unflatten(data, _):
return MyObj(data[0])

register_pytree(MyObj, obj_flatten, obj_unflatten)

obj = MyObj(0.5)

data, structure = flatten(obj)
assert data == [0.5]
assert structure == PyTreeStructure(MyObj, None, [leaf])

new_obj = unflatten([1.0], structure)
assert isinstance(new_obj, MyObj)
assert new_obj.a == 1.0


def test_list():
"""Test that pennylane treats list as a pytree."""

x = [1, 2, [3, 4]]

data, structure = flatten(x)
assert data == [1, 2, 3, 4]
assert structure == PyTreeStructure(
list, None, [leaf, leaf, PyTreeStructure(list, None, [leaf, leaf])]
)

new_x = unflatten([5, 6, 7, 8], structure)
assert new_x == [5, 6, [7, 8]]


def test_tuple():
"""Test that pennylane can handle tuples as pytrees."""
x = (1, 2, (3, 4))

data, structure = flatten(x)
assert data == [1, 2, 3, 4]
assert structure == PyTreeStructure(
tuple, None, [leaf, leaf, PyTreeStructure(tuple, None, [leaf, leaf])]
)

new_x = unflatten([5, 6, 7, 8], structure)
assert new_x == (5, 6, (7, 8))


def test_dict():
"""Test that pennylane can handle dictionaries as pytees."""

x = {"a": 1, "b": {"c": 2, "d": 3}}

data, structure = flatten(x)
assert data == [1, 2, 3]
assert structure == PyTreeStructure(
dict, ("a", "b"), [leaf, PyTreeStructure(dict, ("c", "d"), [leaf, leaf])]
)
new_x = unflatten([5, 6, 7], structure)
assert new_x == {"a": 5, "b": {"c": 6, "d": 7}}


def test_nested_pl_object():
"""Test that we can flatten and unflatten nested pennylane object."""

tape = qml.tape.QuantumScript(
[qml.adjoint(qml.RX(0.1, wires=0))],
[qml.expval(2 * qml.X(0))],
shots=50,
trainable_params=(0, 1),
)

data, structure = flatten(tape)
assert data == [0.1, 2, None]

wires0 = qml.wires.Wires(0)
op_structure = PyTreeStructure(
tape[0].__class__, (), [PyTreeStructure(qml.RX, (wires0, ()), [leaf])]
)
list_op_struct = PyTreeStructure(list, None, [op_structure])

sprod_structure = PyTreeStructure(
qml.ops.SProd, (), [leaf, PyTreeStructure(qml.X, (wires0, ()), [])]
)
meas_structure = PyTreeStructure(
qml.measurements.ExpectationMP, (("wires", None),), [sprod_structure, leaf]
)
list_meas_struct = PyTreeStructure(list, None, [meas_structure])
tape_structure = PyTreeStructure(
qml.tape.QuantumScript,
(tape.shots, tape.trainable_params),
[list_op_struct, list_meas_struct],
)

assert structure == tape_structure

new_tape = unflatten([3, 4, None], structure)
expected_new_tape = qml.tape.QuantumScript(
[qml.adjoint(qml.RX(3, wires=0))],
[qml.expval(4 * qml.X(0))],
shots=50,
trainable_params=(0, 1),
)
assert qml.equal(new_tape, expected_new_tape)

0 comments on commit 960204f

Please sign in to comment.