Skip to content

Commit

Permalink
Make templates valid Pytrees (#5698)
Browse files Browse the repository at this point in the history
**Context:**
Most templates are operations themselves. Therefore they should be valid
PyTrees, but some are not.

**Description of the Change:**
This PR adapts the `parameters` and/or `hyperparameters`, as well as the
methods `_flatten` and/or `_unflatten` for a number of parameters.
It adds `standard_validity` tests to all template test files where
`assert_valid` is not being called on the respective template yet.
Explicit flatten/unflatten tests are removed, because they are contained
in `assert_valid`.

**Benefits:**
Code quality/self-consistency and consistency of templates with PL
functionality.
Test suite quality.

**Possible Drawbacks:**

**Related GitHub Issues:**
This helps us go forward with capturing of templates, because
`primitive_bind_call` can be created in a way that is consistent with
`__init__` and `_unflatten`.

[sc-63810]

---------

Co-authored-by: Jay Soni <jbsoni@uwaterloo.ca>
  • Loading branch information
dwierichs and Jaybsoni authored May 30, 2024
1 parent 41dd7ab commit 856b14e
Show file tree
Hide file tree
Showing 43 changed files with 581 additions and 547 deletions.
13 changes: 13 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

<h3>Improvements 🛠</h3>

* A number of templates have been updated to be valid pytrees and PennyLane operations.
[(#5698)](https://github.com/PennyLaneAI/pennylane/pull/5698)

* `ctrl` now works with tuple-valued `control_values` when applied to any already controlled operation.
[(#5725)](https://github.com/PennyLaneAI/pennylane/pull/5725)

Expand Down Expand Up @@ -137,6 +140,10 @@

<h3>Breaking changes 💔</h3>

* A custom decomposition can no longer be provided to `QDrift`. Instead, apply the operations in your custom
operation directly with `qml.apply`.
[(#5698)](https://github.com/PennyLaneAI/pennylane/pull/5698)

* Sampling observables composed of `X`, `Y`, `Z` and `Hadamard` now returns values of type `float` instead of `int`.
[(#5607)](https://github.com/PennyLaneAI/pennylane/pull/5607)

Expand Down Expand Up @@ -179,6 +186,12 @@

<h3>Bug fixes 🐛</h3>

* `QuantumPhaseEstimation.map_wires` on longer modifies the original operation instance.
[(#5698)](https://github.com/PennyLaneAI/pennylane/pull/5698)

* The decomposition of `AmplitudeAmplification` now correctly queues all operations.
[(#5698)](https://github.com/PennyLaneAI/pennylane/pull/5698)

* Replaced `semantic_version` with `packaging.version.Version`, since the former cannot
handle the metadata `.post` in the version string.
[(#5754)](https://github.com/PennyLaneAI/pennylane/pull/5754)
Expand Down
24 changes: 19 additions & 5 deletions pennylane/ops/functions/assert_valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def inner_func(*args, **kwargs):
return inner_func


def _check_decomposition(op):
def _check_decomposition(op, skip_wire_mapping):
"""Checks involving the decomposition."""
if op.has_decomposition:
decomp = op.decomposition()
Expand All @@ -64,6 +64,18 @@ def _check_decomposition(op):
assert o1 == o3, "decomposition must match queued operations"
assert o1 == o4, "decomposition must match expansion"
assert isinstance(o1, qml.operation.Operator), "decomposition must contain operators"

if skip_wire_mapping:
return
# Check that mapping wires transitions to the decomposition
wire_map = {w: ascii_lowercase[i] for i, w in enumerate(op.wires)}
mapped_op = op.map_wires(wire_map)
mapped_decomp = mapped_op.decomposition()
orig_decomp = op.decomposition()
for mapped_op, orig_op in zip(mapped_decomp, orig_decomp):
assert (
mapped_op.wires == qml.map_wires(orig_op, wire_map).wires
), "Operators in decomposition of wire-mapped operator must have mapped wires."
else:
failure_comment = "If has_decomposition is False, then decomposition must raise a ``DecompositionUndefinedError``."
_assert_error_raised(
Expand Down Expand Up @@ -216,17 +228,19 @@ def _check_bind_new_parameters(op):
assert qml.math.allclose(d1, d2), failure_comment


def _check_wires(op):
def _check_wires(op, skip_wire_mapping):
"""Check that wires are a ``Wires`` class and can be mapped."""
assert isinstance(op.wires, qml.wires.Wires), "wires must be a wires instance"

if skip_wire_mapping:
return
wire_map = {w: ascii_lowercase[i] for i, w in enumerate(op.wires)}
mapped_op = op.map_wires(wire_map)
new_wires = qml.wires.Wires(list(ascii_lowercase[: len(op.wires)]))
assert mapped_op.wires == new_wires, "wires must be mappable with map_wires"


def assert_valid(op: qml.operation.Operator, skip_pickle=False) -> None:
def assert_valid(op: qml.operation.Operator, skip_pickle=False, skip_wire_mapping=False) -> None:
"""Runs basic validation checks on an :class:`~.operation.Operator` to make
sure it has been correctly defined.
Expand Down Expand Up @@ -278,14 +292,14 @@ def __init__(self, wires):
assert qml.math.allclose(d, p), "data and parameters must match."

if len(op.wires) <= 26:
_check_wires(op)
_check_wires(op, skip_wire_mapping)
_check_copy(op)
_check_pytree(op)
if not skip_pickle:
_check_pickle(op)
_check_bind_new_parameters(op)

_check_decomposition(op)
_check_decomposition(op, skip_wire_mapping)
_check_matrix(op)
_check_matrix_matches_decomp(op)
_check_eigendecomposition(op)
10 changes: 10 additions & 0 deletions pennylane/ops/functions/bind_new_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ def bind_new_parameters_commuting_evolution(
return qml.CommutingEvolution(new_hamiltonian, time, frequencies=freq, shifts=shifts)


@bind_new_parameters.register
def bind_new_parameters_qdrift(op: qml.QDrift, params: Sequence[TensorLike]):
new_hamiltonian = bind_new_parameters(op.hyperparameters["base"], params[:-1])
time = params[-1]
n = op.hyperparameters["n"]
seed = op.hyperparameters["seed"]

return qml.QDrift(new_hamiltonian, time, n=n, seed=seed)


@bind_new_parameters.register
def bind_new_parameters_fermionic_double_excitation(
op: qml.FermionicDoubleExcitation, params: Sequence[TensorLike]
Expand Down
2 changes: 1 addition & 1 deletion pennylane/ops/qubit/arithmetic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def compute_decomposition(value, geq=True, wires=None, work_wires=None, **kwargs
small_val = not geq and value == 0
large_val = geq and value > 2 ** len(control_wires) - 1
if small_val or large_val:
gates = [Identity(0)]
gates = [Identity(wires[0])]

else:
values = range(value, 2 ** (len(control_wires))) if geq else range(value)
Expand Down
7 changes: 5 additions & 2 deletions pennylane/ops/qubit/matrix_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,13 @@ def compute_decomposition(D, wires):
ops = [QubitUnitary(qml.math.tensordot(global_phase, qml.math.eye(2), axes=0), wires[0])]
for wire0 in range(n):
# Single PauliZ generators correspond to the coeffs at powers of two
ops.append(qml.RZ(coeffs[1 << wire0], n - 1 - wire0))
ops.append(qml.RZ(coeffs[1 << wire0], wires[n - 1 - wire0]))
# Double PauliZ generators correspond to the coeffs at the sum of two powers of two
ops.extend(
qml.IsingZZ(coeffs[(1 << wire0) + (1 << wire1)], [n - 1 - wire0, n - 1 - wire1])
qml.IsingZZ(
coeffs[(1 << wire0) + (1 << wire1)],
[wires[n - 1 - wire0], wires[n - 1 - wire1]],
)
for wire1 in range(wire0)
)

Expand Down
11 changes: 11 additions & 0 deletions pennylane/templates/embeddings/iqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
Contains the IQPEmbedding template.
"""
# pylint: disable-msg=too-many-branches,too-many-arguments,protected-access
import copy
from itertools import combinations

import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.wires import Wires


class IQPEmbedding(Operation):
Expand Down Expand Up @@ -186,6 +188,15 @@ def __init__(self, features, wires, n_repeats=1, pattern=None, id=None):

super().__init__(features, wires=wires, id=id)

def map_wires(self, wire_map):
# pylint: disable=protected-access
new_op = copy.deepcopy(self)
new_op._wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
new_op._hyperparameters["pattern"] = [
[wire_map.get(w, w) for w in wires] for wires in new_op._hyperparameters["pattern"]
]
return new_op

@property
def num_params(self):
return 1
Expand Down
12 changes: 12 additions & 0 deletions pennylane/templates/subroutines/all_singles_doubles.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
Contains the AllSinglesDoubles template.
"""
# pylint: disable-msg=too-many-branches,too-many-arguments,protected-access
import copy

import numpy as np

import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.ops import BasisState
from pennylane.wires import Wires


class AllSinglesDoubles(Operation):
Expand Down Expand Up @@ -155,6 +158,15 @@ def __init__(self, weights, wires, hf_state, singles=None, doubles=None, id=None

super().__init__(weights, wires=wires, id=id)

def map_wires(self, wire_map: dict):
new_op = copy.deepcopy(self)
new_op._wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
for key in ["singles", "doubles"]:
new_op._hyperparameters[key] = tuple(
tuple(wire_map[w] for w in wires) for wires in new_op._hyperparameters[key]
)
return new_op

@property
def num_params(self):
return 1
Expand Down
27 changes: 21 additions & 6 deletions pennylane/templates/subroutines/amplitude_amplification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
"""

# pylint: disable-msg=too-many-arguments
import copy

import numpy as np

import pennylane as qml
from pennylane.operation import Operation
from pennylane.wires import Wires


def _get_fixed_point_angles(iters, p_min):
Expand Down Expand Up @@ -101,16 +104,12 @@ def circuit():

def _flatten(self):
data = (self.hyperparameters["U"], self.hyperparameters["O"])
metadata = tuple(
(key, value) for key, value in self.hyperparameters.items() if key not in ["O", "U"]
)
metadata = tuple(item for item in self.hyperparameters.items() if item[0] not in ["O", "U"])
return data, metadata

@classmethod
def _unflatten(cls, data, metadata):
U, O = (data[0], data[1])
hyperparams_dict = dict(metadata)
return cls(U, O, **hyperparams_dict)
return cls(*data, **dict(metadata))

def __init__(
self, U, O, iters=1, fixed_point=False, work_wire=None, p_min=0.9, reflection_wires=None
Expand Down Expand Up @@ -169,10 +168,26 @@ def compute_decomposition(**kwargs):
else:
for _ in range(iters):
ops.append(O)
if qml.QueuingManager.recording():
qml.apply(O)
ops.append(qml.Reflection(U, np.pi, reflection_wires=reflection_wires))

return ops

def map_wires(self, wire_map: dict):
# pylint: disable=protected-access
new_op = copy.deepcopy(self)
new_op._wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
new_op._hyperparameters["U"] = new_op._hyperparameters["U"].map_wires(wire_map)
new_op._hyperparameters["O"] = new_op._hyperparameters["O"].map_wires(wire_map)
new_op._hyperparameters["reflection_wires"] = Wires(
[wire_map.get(wire, wire) for wire in new_op._hyperparameters["reflection_wires"]]
)
new_op._hyperparameters["work_wire"] = wire_map.get(
w := new_op._hyperparameters["work_wire"], w
)
return new_op

def queue(self, context=qml.QueuingManager):
for op in [self.hyperparameters["U"], self.hyperparameters["O"]]:
context.remove(op)
Expand Down
11 changes: 11 additions & 0 deletions pennylane/templates/subroutines/approx_time_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
Contains the ApproxTimeEvolution template.
"""
# pylint: disable-msg=too-many-branches,too-many-arguments,protected-access
import copy

import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.ops import PauliRot
from pennylane.wires import Wires


class ApproxTimeEvolution(Operation):
Expand Down Expand Up @@ -139,6 +142,14 @@ def __init__(self, hamiltonian, time, n, id=None):
# trainable parameters are passed to the base init method
super().__init__(*hamiltonian.data, time, wires=wires, id=id)

def map_wires(self, wire_map: dict):
new_op = copy.deepcopy(self)
new_op._wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
new_op._hyperparameters["hamiltonian"] = qml.map_wires(
new_op._hyperparameters["hamiltonian"], wire_map
)
return new_op

def queue(self, context=qml.QueuingManager):
context.remove(self.hyperparameters["hamiltonian"])
context.append(self)
Expand Down
12 changes: 12 additions & 0 deletions pennylane/templates/subroutines/commuting_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
Contains the CommutingEvolution template.
"""
# pylint: disable-msg=too-many-arguments,import-outside-toplevel
import copy

import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.wires import Wires


class CommutingEvolution(Operation):
Expand Down Expand Up @@ -139,6 +142,15 @@ def __init__(self, hamiltonian, time, frequencies=None, shifts=None, id=None):

super().__init__(time, *hamiltonian.parameters, wires=hamiltonian.wires, id=id)

def map_wires(self, wire_map: dict):
# pylint: disable=protected-access
new_op = copy.deepcopy(self)
new_op._wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
new_op._hyperparameters["hamiltonian"] = qml.map_wires(
new_op._hyperparameters["hamiltonian"], wire_map
)
return new_op

def queue(self, context=qml.QueuingManager):
context.remove(self.hyperparameters["hamiltonian"])
context.append(self)
Expand Down
12 changes: 12 additions & 0 deletions pennylane/templates/subroutines/fermionic_double_excitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
Contains the FermionicDoubleExcitation template.
"""
# pylint: disable-msg=too-many-branches,too-many-arguments,protected-access
import copy

import numpy as np

import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.ops import CNOT, RX, RZ, Hadamard
from pennylane.wires import Wires


def _layer1(weight, s, r, q, p, set_cnot_wires):
Expand Down Expand Up @@ -532,6 +535,15 @@ def __init__(self, weight, wires1=None, wires2=None, id=None):
wires = wires1 + wires2
super().__init__(weight, wires=wires, id=id)

def map_wires(self, wire_map: dict):
new_op = copy.deepcopy(self)
new_op._wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
for key in ["wires1", "wires2"]:
new_op._hyperparameters[key] = Wires(
[wire_map.get(wire, wire) for wire in self._hyperparameters[key]]
)
return new_op

@property
def num_params(self):
return 1
Expand Down
3 changes: 3 additions & 0 deletions pennylane/templates/subroutines/hilbert_schmidt.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ def __init__(self, *params, v_function, v_wires, u_tape, id=None):

super().__init__(*params, wires=wires, id=id)

def map_wires(self, wire_map: dict):
raise NotImplementedError("Mapping the wires of HilbertSchmidt is not implemented.")

@property
def num_params(self):
return self._num_params
Expand Down
Loading

0 comments on commit 856b14e

Please sign in to comment.