Skip to content

Commit

Permalink
Support custom QNode transforms in dispatcher (#4466)
Browse files Browse the repository at this point in the history
* Support custom qnode transforms

* pylint

* changelog

* pylint

* Add to argument

* Revert "Add to argument"

This reverts commit 7d4e2c3.

* remove if empty check

---------

Co-authored-by: Romain Moyard <rmoyard@gmail.com>
  • Loading branch information
eddddddy and rmoyard committed Aug 17, 2023
1 parent 58db9aa commit edefd31
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 14 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ array([False, False])
and return a new batch of circuits and a single post processing function.
[(#4364)](https://github.com/PennyLaneAI/pennylane/pull/4364)

* `TransformDispatcher` now allows registration of custom `QNode` transforms.
[(#4466)](https://github.com/PennyLaneAI/pennylane/pull/4466)

* `HardwareHamiltonian`s can now be summed with `int` or `float`.
A sequence of `HardwareHamiltonian`s can now be summed via the builtin `sum`.
[(#4343)](https://github.com/PennyLaneAI/pennylane/pull/4343)
Expand Down
65 changes: 51 additions & 14 deletions pennylane/transforms/core/transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
This module contains the transform function, the transform dispatcher and the transform container.
"""
import copy
import types

import pennylane as qml


Expand Down Expand Up @@ -48,6 +50,8 @@ def __init__(
self._classical_cotransform = classical_cotransform
self._is_informative = is_informative

self._qnode_transform = self.default_qnode_transform

def __call__(self, *targs, **tkwargs):
obj = None

Expand Down Expand Up @@ -92,6 +96,53 @@ def is_informative(self):
"""Return True is the transform does not need to be executed."""
return self._is_informative

def custom_qnode_transform(self, fn):
"""Register a custom QNode execution wrapper function
for the batch transform.
**Example**
.. code-block:: python
@transform
def my_transform(tape, *targs, **tkwargs):
...
return tapes, processing_fn
@my_transform.custom_qnode_transform
def my_custom_qnode_wrapper(self, qnode, targs, tkwargs):
tkwargs = {**tkwargs, shots=100}
return self.default_qnode_transform(qnode, targs, tkwargs)
The custom QNode execution wrapper must have arguments
``self`` (the batch transform object), ``qnode`` (the input QNode
to transform and execute), ``targs`` and ``tkwargs`` (the transform
arguments and keyword arguments respectively).
It should return a QNode that accepts the *same* arguments as the
input QNode with the transform applied.
The default :meth:`~.default_qnode_transform` method may be called
if only pre- or post-processing dependent on QNode arguments is required.
"""
self._qnode_transform = types.MethodType(fn, self)

def default_qnode_transform(self, qnode, targs, tkwargs):
"""
The default method that takes in a QNode and returns another QNode
with the transform applied.
"""
qnode = copy.deepcopy(qnode)

if self.expand_transform:
qnode.add_transform(TransformContainer(self._expand_transform))
qnode.add_transform(
TransformContainer(
self._transform, targs, tkwargs, self._classical_cotransform, self._is_informative
)
)
return qnode

def _qfunc_transform(self, qfunc, targs, tkwargs):
"""Apply the transform on a quantum function."""

Expand All @@ -114,20 +165,6 @@ def qfunc_transformed(*args, **kwargs):

return qfunc_transformed

def _qnode_transform(self, qnode, targs, tkwargs):
"""Apply the transform on a QNode. It populates the transform program of a QNode"""
if qnode.transform_program.is_empty():
qnode = copy.deepcopy(qnode)

if self.expand_transform:
qnode.add_transform(TransformContainer(self._expand_transform))
qnode.add_transform(
TransformContainer(
self._transform, targs, tkwargs, self._classical_cotransform, self._is_informative
)
)
return qnode


class TransformContainer:
"""Class to store a quantum transform with its args, kwargs and classical co-transforms. Use
Expand Down
44 changes: 44 additions & 0 deletions tests/transforms/test_experimental/test_transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,47 @@ def test_the_transform_container_attributes(self):
assert not container.kwargs
assert container.classical_cotransform is None
assert not container.is_informative

@pytest.mark.parametrize("valid_transform", valid_transforms)
def test_custom_qnode_transform(self, valid_transform):
"""Test that the custom qnode transform is correctly executed"""

dispatched_transform = transform(valid_transform)

history = []

@dispatched_transform.custom_qnode_transform
def _custom_qnode_transform(self, qnode, targs, tkwargs):
history.append((targs, tkwargs))
return self.default_qnode_transform(qnode, targs, tkwargs)

@partial(dispatched_transform, index=0)
@qml.qnode(dev)
def qnode1():
"""QNode circuit."""
qml.Hadamard(wires=0)
return qml.expval(qml.PauliZ(wires=0))

assert isinstance(qnode1, qml.QNode)
assert isinstance(qnode1.transform_program, qml.transforms.core.TransformProgram)
assert isinstance(
qnode1.transform_program.pop_front(), qml.transforms.core.TransformContainer
)

@qml.qnode(dev)
def qnode2():
"""QNode circuit."""
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliZ(wires=0))

qnode2 = dispatched_transform(qnode2, 1)

assert isinstance(qnode2, qml.QNode)
assert isinstance(qnode2.transform_program, qml.transforms.core.TransformProgram)
assert isinstance(
qnode2.transform_program.pop_front(), qml.transforms.core.TransformContainer
)

# check that the custom qnode transform was called
assert history == [([], {"index": 0}), ([1], {})]

0 comments on commit edefd31

Please sign in to comment.