Skip to content

Commit

Permalink
Plxpr integrates well with all templates (#5523)
Browse files Browse the repository at this point in the history
**To do**
- [x] Actually implement the custom binding functions
- [x] Write the custom binding tests
- [x] Write a test that asserts that all modified templates as
determined automatically have their own test.

**Context:**
Introducing the plxpr capturing mechanism requires us to pay attention
to operations that have special call signatures, such as templates.

**Description of the Change:**
This PR adds a test suite for templates.
For templates with special initialization signatures, a custom
`_primitive_bind_call` method is added, together with dedicated tests
for these custom methods.

**Benefits:**
Compatibility of Plxpr with the code stack

**Possible Drawbacks:**
More than 1 second of testing time added.
A lot of test cases
Future template additions _will_ have to write additional tests in
`tests/capture/test_templates.py` in the current test suite design.

**Related GitHub Issues:**


[sc-61425]

---------

Co-authored-by: albi3ro <chrissie.c.l@gmail.com>
Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
Co-authored-by: Jay Soni <jbsoni@uwaterloo.ca>
  • Loading branch information
5 people authored May 31, 2024
1 parent 8adf29a commit cd52843
Show file tree
Hide file tree
Showing 25 changed files with 889 additions and 5 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
* PennyLane operators and measurements can now automatically be captured as instructions in JAXPR.
[(#5564)](https://github.com/PennyLaneAI/pennylane/pull/5564)
[(#5511)](https://github.com/PennyLaneAI/pennylane/pull/5511)
[(#5523)](https://github.com/PennyLaneAI/pennylane/pull/5523)

* The `decompose` transform has an `error` kwarg to specify the type of error that should be raised,
allowing error types to be more consistent with the context the `decompose` function is used in.
Expand Down
2 changes: 1 addition & 1 deletion pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def qfunc(a):
(where ``cls`` indicates the class) if:
* The operator does not accept wires, like :class:`~.SymbolicOp` or :class:`~.CompositeOp`.
* The operator needs to enforce a data/ metadata distinction, like :class:`~.PauliRot`.
* The operator needs to enforce a data / metadata distinction, like :class:`~.PauliRot`.
In such cases, the operator developer can override ``cls._primitive_bind_call``, which
will be called when constructing a new class instance instead of ``type.__call__``. For example,
Expand Down
2 changes: 1 addition & 1 deletion pennylane/templates/layers/gate_fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def shape(n_layers, n_wires):

if n_wires < 4:
raise ValueError(
f"This template requires the number of qubits to be greater than four; got 'n_wires' = {n_wires}"
f"This template requires the number of qubits to be at least four; got 'n_wires' = {n_wires}"
)

if n_wires % 2:
Expand Down
2 changes: 1 addition & 1 deletion pennylane/templates/state_preparations/mottonen.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def __init__(self, state_vector, wires, id=None):
norm = qml.math.sum(qml.math.abs(state) ** 2)
if not qml.math.allclose(norm, 1.0, atol=1e-3):
raise ValueError(
f"State vectors have to be of norm 1.0, vector {i} has norm {norm}"
f"State vectors have to be of norm 1.0, vector {i} has squared norm {norm}"
)

super().__init__(state_vector, wires=wires, id=id)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/templates/subroutines/all_singles_doubles.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(self, weights, wires, hf_state, singles=None, doubles=None, id=None
raise ValueError(f"'weights' tensor must be of shape {exp_shape}; got {weights_shape}.")

if hf_state[0].dtype != np.dtype("int"):
raise ValueError(f"Elements of 'hf_state' must be integers; got {hf_state.dtype}")
raise ValueError(f"Elements of 'hf_state' must be integers; got {hf_state[0].dtype}")

singles = tuple(tuple(s) for s in singles)
doubles = tuple(tuple(d) for d in doubles)
Expand Down
4 changes: 4 additions & 0 deletions pennylane/templates/subroutines/amplitude_amplification.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ def _flatten(self):
metadata = tuple(item for item in self.hyperparameters.items() if item[0] not in ["O", "U"])
return data, metadata

@classmethod
def _primitive_bind_call(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)

@classmethod
def _unflatten(cls, data, metadata):
return cls(*data, **dict(metadata))
Expand Down
4 changes: 4 additions & 0 deletions pennylane/templates/subroutines/approx_time_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ def _flatten(self):
data = (h, self.data[-1])
return data, (self.hyperparameters["n"],)

@classmethod
def _primitive_bind_call(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)

@classmethod
def _unflatten(cls, data, metadata):
return cls(data[0], data[1], n=metadata[0])
Expand Down
24 changes: 24 additions & 0 deletions pennylane/templates/subroutines/basis_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ class BasisRotation(Operation):
num_wires = AnyWires
grad_method = None

@classmethod
def _primitive_bind_call(cls, wires, unitary_matrix, check=False, id=None):
# pylint: disable=arguments-differ
if cls._primitive is None:
# guard against this being called when primitive is not defined.
return type.__call__(cls, wires, unitary_matrix, check=check, id=id) # pragma: no cover

return cls._primitive.bind(*wires, unitary_matrix, check=check, id=id)

def __init__(self, wires, unitary_matrix, check=False, id=None):
M, N = unitary_matrix.shape
if M != N:
Expand Down Expand Up @@ -176,3 +185,18 @@ def compute_decomposition(
op_list.append(qml.PhaseShift(phi, wires=wires[indices[0]]))

return op_list


# Program capture needs to unpack and re-pack the wires to support dynamic wires. For
# BasisRotation, the unconventional argument ordering requires custom def_impl code.
# See capture module for more information on primitives
# If None, jax isn't installed so the class never got a primitive.
if BasisRotation._primitive is not None: # pylint: disable=protected-access

@BasisRotation._primitive.def_impl # pylint: disable=protected-access
def _(*args, **kwargs):
# If there are more than two args, we are calling with unpacked wires, so that
# we have to repack them. This replaces the n_wires logic in the general case.
if len(args) != 2:
args = (args[:-1], args[-1])
return type.__call__(BasisRotation, *args, **kwargs)
4 changes: 4 additions & 0 deletions pennylane/templates/subroutines/commuting_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def _flatten(self):
data = (self.data[0], h)
return data, (self.hyperparameters["frequencies"], self.hyperparameters["shifts"])

@classmethod
def _primitive_bind_call(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)

@classmethod
def _unflatten(cls, data, metadata) -> "CommutingEvolution":
return cls(data[1], data[0], frequencies=metadata[0], shifts=metadata[1])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,10 @@ def circuit(weight, wires1=None, wires2=None):
def _flatten(self):
return self.data, (self.hyperparameters["wires1"], self.hyperparameters["wires2"])

@classmethod
def _primitive_bind_call(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)

@classmethod
def _unflatten(cls, data, metadata) -> "FermionicDoubleExcitation":
return cls(data[0], wires1=metadata[0], wires2=metadata[1])
Expand Down
6 changes: 6 additions & 0 deletions pennylane/templates/subroutines/hilbert_schmidt.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def _flatten(self):
)
return self.data, metadata

@classmethod
def _primitive_bind_call(cls, *params, v_function, v_wires, u_tape, id=None):
# pylint: disable=arguments-differ
kwargs = {"v_function": v_function, "v_wires": v_wires, "u_tape": u_tape, "id": id}
return cls._primitive.bind(*params, **kwargs)

@classmethod
def _unflatten(cls, data, metadata):
return cls(*data, **dict(metadata))
Expand Down
4 changes: 4 additions & 0 deletions pennylane/templates/subroutines/qdrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ def my_circ(time):
"""

@classmethod
def _primitive_bind_call(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)

def _flatten(self):
h = self.hyperparameters["base"]
hashable_hyperparameters = tuple(
Expand Down
4 changes: 4 additions & 0 deletions pennylane/templates/subroutines/qmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,10 @@ def circuit():
num_wires = AnyWires
grad_method = None

@classmethod
def _primitive_bind_call(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)

@classmethod
def _unflatten(cls, data, metadata):
new_op = cls.__new__(cls)
Expand Down
4 changes: 4 additions & 0 deletions pennylane/templates/subroutines/qpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ def _flatten(self):
metadata = (self.hyperparameters["estimation_wires"],)
return data, metadata

@classmethod
def _primitive_bind_call(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)

@classmethod
def _unflatten(cls, data, metadata) -> "QuantumPhaseEstimation":
return cls(data[0], estimation_wires=metadata[0])
Expand Down
4 changes: 4 additions & 0 deletions pennylane/templates/subroutines/qsvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ def _flatten(self):
data = (self.hyperparameters["UA"], self.hyperparameters["projectors"])
return data, tuple()

@classmethod
def _primitive_bind_call(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)

@classmethod
def _unflatten(cls, data, _) -> "QSVT":
return cls(*data)
Expand Down
4 changes: 4 additions & 0 deletions pennylane/templates/subroutines/qubitization.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def circuit():
eigenvalue: 0.7
"""

@classmethod
def _primitive_bind_call(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)

def __init__(self, hamiltonian, control, id=None):
wires = hamiltonian.wires + qml.wires.Wires(control)

Expand Down
4 changes: 4 additions & 0 deletions pennylane/templates/subroutines/reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def circuit():
"""

@classmethod
def _primitive_bind_call(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)

def _flatten(self):
data = (self.hyperparameters["base"], self.parameters[0])
return data, (self.hyperparameters["reflection_wires"],)
Expand Down
4 changes: 4 additions & 0 deletions pennylane/templates/subroutines/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ class Select(Operation):
def _flatten(self):
return (self.ops), (self.control)

@classmethod
def _primitive_bind_call(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)

@classmethod
def _unflatten(cls, data, metadata) -> "Select":
return cls(data, metadata)
Expand Down
13 changes: 13 additions & 0 deletions pennylane/templates/tensornetworks/mera.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,19 @@ def circuit(template_weights):
def num_params(self):
return 1

@classmethod
def _primitive_bind_call(
cls, wires, n_block_wires, block, n_params_block, template_weights=None, id=None
): # pylint: disable=arguments-differ
return super()._primitive_bind_call(
wires=wires,
n_block_wires=n_block_wires,
block=block,
n_params_block=n_params_block,
template_weights=template_weights,
id=id,
)

@classmethod
def _unflatten(cls, data, metadata):
new_op = cls.__new__(cls)
Expand Down
23 changes: 23 additions & 0 deletions pennylane/templates/tensornetworks/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,29 @@ def circuit():
num_wires = AnyWires
par_domain = "A"

@classmethod
def _primitive_bind_call(
cls,
wires,
n_block_wires,
block,
n_params_block,
template_weights=None,
offset=None,
id=None,
**kwargs,
): # pylint: disable=arguments-differ
return super()._primitive_bind_call(
wires=wires,
n_block_wires=n_block_wires,
block=block,
n_params_block=n_params_block,
template_weights=template_weights,
id=id,
offset=offset,
**kwargs,
)

@classmethod
def _unflatten(cls, data, metadata):
new_op = cls.__new__(cls)
Expand Down
13 changes: 13 additions & 0 deletions pennylane/templates/tensornetworks/ttn.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,19 @@ def circuit(template_weights):
def num_params(self):
return 1

@classmethod
def _primitive_bind_call(
cls, wires, n_block_wires, block, n_params_block, template_weights=None, id=None
): # pylint: disable=arguments-differ
return super()._primitive_bind_call(
wires=wires,
n_block_wires=n_block_wires,
block=block,
n_params_block=n_params_block,
template_weights=template_weights,
id=id,
)

@classmethod
def _unflatten(cls, data, metadata):
new_op = cls.__new__(cls)
Expand Down
Loading

0 comments on commit cd52843

Please sign in to comment.