Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Dataset Attribute type for Pytrees #5732

Merged
merged 49 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 46 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
253515f
add tools for flattening and unflattening pytrees
albi3ro May 13, 2024
7180296
Merge branch 'master' into pytree-flatten-unflatten
albi3ro May 16, 2024
1a7fa19
adding coverage
albi3ro May 16, 2024
9858055
Apply suggestions from code review
albi3ro May 21, 2024
a618d21
responding to feedback, leaf is PyTreeStructure with no type
albi3ro May 21, 2024
dd5c7e4
Merge branch 'master' into pytree-flatten-unflatten
albi3ro May 21, 2024
c7ca2c1
pytree module
brownj85 May 22, 2024
45141b3
Update pennylane/pytrees.py
albi3ro May 23, 2024
8086b3f
Update pennylane/pytrees.py
albi3ro May 23, 2024
521e7fd
Apply suggestions from code review
albi3ro May 23, 2024
450de94
change repr, add str
albi3ro May 23, 2024
73ec5ed
add serialization
brownj85 May 23, 2024
6c71b57
Merge branch 'pytree-flatten-unflatten' into datasets-pytrees
brownj85 May 23, 2024
0757766
tests
brownj85 May 23, 2024
2882ce5
Merge branch 'master' into pytree-flatten-unflatten
albi3ro May 23, 2024
5a6d0e5
tests, docs
brownj85 May 23, 2024
dcf1bb9
tests
brownj85 May 23, 2024
4088fa6
Merge branch 'pytree-flatten-unflatten' into datasets-pytrees
brownj85 May 23, 2024
fafc2fe
update changelog
brownj85 May 23, 2024
703e3ab
Merge branch 'master' into datasets-pytrees
brownj85 May 24, 2024
73d1626
refactor json handling
brownj85 May 24, 2024
f5d209d
tests
brownj85 May 24, 2024
e7af0b3
Merge branch 'master' into datasets-pytrees
brownj85 May 24, 2024
5d63558
codefactor
brownj85 May 24, 2024
c499706
don't use | for union
brownj85 May 24, 2024
215f39c
remove dupe test file
brownj85 May 24, 2024
7f7d582
pylint
brownj85 May 24, 2024
0c13848
Merge branch 'master' into datasets-pytrees
brownj85 May 24, 2024
86bddda
add to contrib
brownj85 May 24, 2024
1c1c8e5
coverage
brownj85 May 24, 2024
a5007ef
use array for homogenous leaves
brownj85 May 24, 2024
f4d8fb5
comment
brownj85 May 24, 2024
f8458d1
Merge branch 'master' into datasets-pytrees
brownj85 May 24, 2024
cf851af
Merge branch 'master' into datasets-pytrees
brownj85 May 28, 2024
bc6ded6
suggestions from code review
brownj85 May 28, 2024
504e853
Apply suggestions from code review
brownj85 May 28, 2024
11aa983
docstrings
brownj85 May 28, 2024
8bfc3f2
Merge branch 'master' into datasets-pytrees
brownj85 May 30, 2024
001ada4
Merge branch 'master' into datasets-pytrees
brownj85 May 30, 2024
671023c
Merge branch 'master' into datasets-pytrees
brownj85 Jun 3, 2024
4803d12
Merge branch 'master' into datasets-pytrees
brownj85 Jun 3, 2024
136a331
Merge branch 'master' into datasets-pytrees
brownj85 Jun 5, 2024
2e71c62
Merge branch 'master' into datasets-pytrees
brownj85 Jun 6, 2024
b80543d
Merge branch 'master' into datasets-pytrees
brownj85 Jun 6, 2024
45b6006
Merge branch 'master' into datasets-pytrees
doctorperceptron Jun 11, 2024
1333128
Apply suggestions from code review
DSGuala Jun 19, 2024
2904d1e
Merge branch 'master' into datasets-pytrees
DSGuala Jun 20, 2024
e22b1dc
Merge branch 'master' into datasets-pytrees
DSGuala Jun 20, 2024
f8ce69e
Merge branch 'master' into datasets-pytrees
DSGuala Jun 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@
* `QuantumScript` properties are only calculated when needed, instead of on initialization. This decreases the classical overhead by >20%.
`par_info`, `obs_sharing_wires`, and `obs_sharing_wires_id` are now public attributes.
[(#5696)](https://github.com/PennyLaneAI/pennylane/pull/5696)

* The `qml.data` module now supports PyTree data types as dataset attributes
[(#5732)](https://github.com/PennyLaneAI/pennylane/pull/5732)


* `qml.ops.Conditional` now inherits from `qml.ops.SymbolicOp`, thus it inherits several useful common functionalities. Other properties such as adjoint and diagonalizing gates have been added using the `base` properties.
[(##5772)](https://github.com/PennyLaneAI/pennylane/pull/5772)
Expand Down Expand Up @@ -383,6 +387,7 @@ Guillermo Alonso-Linaje,
Utkarsh Azad,
Lillian M. A. Frederiksen,
Gabriel Bottrill,
Jack Brown,
Astral Cai,
Ahmed Darwish,
Isaac De Vlugt,
Expand Down
2 changes: 2 additions & 0 deletions pennylane/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ class QuantumOscillator(qml.data.Dataset, data_name="quantum_oscillator", identi
DatasetSparseArray,
DatasetString,
DatasetTuple,
DatasetPyTree,
)
from .base import DatasetNotWriteableError
from .base.attribute import AttributeInfo, DatasetAttribute, attribute
Expand All @@ -225,6 +226,7 @@ class QuantumOscillator(qml.data.Dataset, data_name="quantum_oscillator", identi
"DatasetAttribute",
"DatasetNotWriteableError",
"DatasetArray",
"DatasetPyTree",
"DatasetScalar",
"DatasetString",
"DatasetList",
Expand Down
2 changes: 2 additions & 0 deletions pennylane/data/attributes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .sparse_array import DatasetSparseArray
from .string import DatasetString
from .tuple import DatasetTuple
from .pytree import DatasetPyTree

__all__ = (
"DatasetArray",
Expand All @@ -32,6 +33,7 @@
"DatasetDict",
"DatasetList",
"DatasetOperator",
"DatasetPyTree",
"DatasetSparseArray",
"DatasetMolecule",
"DatasetNone",
Expand Down
8 changes: 5 additions & 3 deletions pennylane/data/attributes/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ class DatasetOperator(Generic[Op], DatasetAttribute[HDF5Group, Op, Op]):

@classmethod
@lru_cache(1)
def consumes_types(cls) -> FrozenSet[Type[Operator]]:
def supported_ops(cls) -> FrozenSet[Type[Operator]]:
"""Set of supported operators."""
return frozenset(
(
# pennylane/operation/Tensor
Expand Down Expand Up @@ -214,7 +215,7 @@ def _ops_to_hdf5(
op_key = f"op_{i}"
if isinstance(op, (qml.ops.Prod, qml.ops.SProd, qml.ops.Sum)):
op = op.simplify()
if type(op) not in self.consumes_types():
if type(op) not in self.supported_ops():
DSGuala marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError(
f"Serialization of operator type '{type(op).__name__}' is not supported."
)
Expand Down Expand Up @@ -254,6 +255,7 @@ def _hdf5_to_ops(self, bind: HDF5Group) -> List[Operator]:
wires_bind = bind["op_wire_labels"]
op_class_names = [] if names_bind.shape == (0,) else names_bind.asstr()
op_wire_labels = [] if wires_bind.shape == (0,) else wires_bind.asstr()

with qml.QueuingManager.stop_recording():
for i, op_class_name in enumerate(op_class_names):
op_key = f"op_{i}"
Expand Down Expand Up @@ -293,4 +295,4 @@ def _hdf5_to_ops(self, bind: HDF5Group) -> List[Operator]:
@lru_cache(1)
def _supported_ops_dict(cls) -> Dict[str, Type[Operator]]:
"""Returns a dict mapping ``Operator`` subclass names to the class."""
return {op.__name__: op for op in cls.consumes_types()}
return {op.__name__: op for op in cls.supported_ops()}
57 changes: 57 additions & 0 deletions pennylane/data/attributes/pytree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains DatasetAttribute definition for PyTree types."""


from typing import TypeVar

import numpy as np

from pennylane.data.attributes import DatasetArray, DatasetList
from pennylane.data.base.attribute import DatasetAttribute
from pennylane.data.base.hdf5 import HDF5Group
from pennylane.data.base.mapper import AttributeTypeMapper
from pennylane.pytrees import flatten, serialization, unflatten

T = TypeVar("T")


class DatasetPyTree(DatasetAttribute[HDF5Group, T, T]):
"""Attribute type for an object that can be converted to
a Pytree. This is the default serialization method for
all PennyLane Pytrees, including subclasses of ``Operator``.
"""

type_id = "pytree"

def hdf5_to_value(self, bind: HDF5Group) -> T:
return unflatten(
AttributeTypeMapper(bind)["leaves"].get_value(),
serialization.pytree_structure_load(bind["treedef"][()].tobytes()),
)

def value_to_hdf5(self, bind_parent: HDF5Group, key: str, value: T) -> HDF5Group:
bind = bind_parent.create_group(key)
leaves, treedef = flatten(value)

bind["treedef"] = np.void(serialization.pytree_structure_dump(treedef, decode=False))

try:
# Attempt to store leaves as an array, which will be more efficient
# but will fail if the leaves are not homogenous
DatasetArray(leaves, parent_and_key=(bind, "leaves"))
except (ValueError, TypeError):
DatasetList(leaves, parent_and_key=(bind, "leaves"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Catching exceptions might be expensive for large data. Is there a way to rather rely on if conditions here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to delegate to numpy here - otherwise we'd need to implement a check that leaves is homogenous and array-compatible, which would likely just be duplicating numpy's logic. This would also be slower in the ideal case that leaves is homogenous.

I had the same concern about performance, but datasets are read a lot more than they're written, and DatasetArray is a lot more compact and performant than DatasetList. So the tradeoff makes sense IMO


return bind
3 changes: 3 additions & 0 deletions pennylane/data/base/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from pennylane.data.base import hdf5
from pennylane.data.base.hdf5 import HDF5, HDF5Any, HDF5Group
from pennylane.data.base.typing_util import UNSET, get_type, get_type_str
from pennylane.pytrees import is_pytree

T = TypeVar("T")

Expand Down Expand Up @@ -492,5 +493,7 @@ def match_obj_type(
ret = DatasetAttribute.registry["list"]
elif issubclass(type_, Mapping):
ret = DatasetAttribute.registry["dict"]
elif is_pytree(type_):
ret = DatasetAttribute.registry["pytree"]

return ret
12 changes: 7 additions & 5 deletions pennylane/measurements/shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains the Shots class to hold shot-related information."""
from collections.abc import Sequence

# pylint:disable=inconsistent-return-statements
from typing import NamedTuple, Sequence, Tuple
from typing import NamedTuple
Comment on lines -15 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any chances of this causing issues with doing shot-based measurements? The tests seem to be passing, so maybe no problem?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's hard to be 100% sure but it shouldn't cause a problem - Sequence is less restrictive than tuple so anything that worked before will still work



class ShotCopies(NamedTuple):
Expand All @@ -39,7 +41,7 @@ def valid_int(s):

def valid_tuple(s):
"""Returns True if s is a tuple of the form (shots, copies)."""
return isinstance(s, tuple) and len(s) == 2 and valid_int(s[0]) and valid_int(s[1])
return isinstance(s, Sequence) and len(s) == 2 and valid_int(s[0]) and valid_int(s[1])


class Shots:
Expand Down Expand Up @@ -136,7 +138,7 @@ class Shots:
total_shots: int = None
"""The total number of shots to be executed."""

shot_vector: Tuple[ShotCopies] = None
shot_vector: tuple[ShotCopies] = None
"""The tuple of :class:`~ShotCopies` to be executed. Each element is of the form ``(shots, copies)``."""

_SHOT_ERROR = ValueError(
Expand Down Expand Up @@ -167,7 +169,7 @@ def __init__(self, shots=None):
elif isinstance(shots, Sequence):
if not all(valid_int(s) or valid_tuple(s) for s in shots):
raise self._SHOT_ERROR
self.__all_tuple_init__([s if isinstance(s, tuple) else (s, 1) for s in shots])
self.__all_tuple_init__([s if isinstance(s, Sequence) else (s, 1) for s in shots])
DSGuala marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(shots, self.__class__):
return # self already _is_ shots as defined by __new__
else:
Expand Down Expand Up @@ -211,7 +213,7 @@ def __iter__(self):
for _ in range(shot_copy.copies):
yield shot_copy.shots

def __all_tuple_init__(self, shots: Sequence[Tuple]):
def __all_tuple_init__(self, shots: Sequence[tuple]):
res = []
total_shots = 0
current_shots, current_copies = shots[0]
Expand Down
27 changes: 27 additions & 0 deletions pennylane/pytrees/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
An internal module for working with pytrees.
"""

from .pytrees import PyTreeStructure, flatten, is_pytree, leaf, register_pytree, unflatten

__all__ = [
"PyTreeStructure",
"flatten",
"is_pytree",
"leaf",
"register_pytree",
"unflatten",
]
Loading