Skip to content

Commit

Permalink
patch qml.matrix to accept qnodes compiled by catalyst.qjit (#5753)
Browse files Browse the repository at this point in the history
**Context:** Currently `qml.matrix`
 breaks when the input is a catalyst-compiled QNode instead of a plain
QNode, and the QNode need to be manually retrieved by the user: 
PennyLaneAI/catalyst#765

It would be beneficial to do this "unwrap" in `qml.matrix` , like those
in qml.draw:

https://github.com/PennyLaneAI/pennylane/blob/59a1e0586e707d057a0c92d4239036afa5312b73/pennylane/drawer/draw.py#L213-L214

**Description of the Change:** 
1. in `pennylane/ops/functions/matrix.py`, in `qml.matrix(op)`, if the
input `op` is a catalyst.qjit compiled function, dispatches the behavior
to be `qml.matrix(op.user_function)`.
2. Added a test in `tests/ops/functions/test_matrix.py`

**Benefits:** a qjit compiled qnode can be passed into `qml.matrix`
directly to query the matrix representation of the circuit


**Related GitHub Issues:**
PennyLaneAI/catalyst#765

[sc-64247]
  • Loading branch information
paul0403 committed May 30, 2024
1 parent 7a4a92a commit c50481c
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 7 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@
* Fixes a bug in `qml.math.dot` that raises an error when only one of the operands is a scalar.
[(#5702)](https://github.com/PennyLaneAI/pennylane/pull/5702)

* `qml.matrix` is now compatible with qnodes compiled by catalyst.qjit.
[(#5753)](https://github.com/PennyLaneAI/pennylane/pull/5753)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):
Expand All @@ -259,4 +262,5 @@ Vincent Michaud-Rioux,
Lee James O'Riordan,
Mudit Pandey,
Kenya Sakka,
Haochen Paul Wang,
David Wierichs.
9 changes: 2 additions & 7 deletions pennylane/drawer/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""
import warnings
from functools import wraps
from importlib.metadata import distribution

import pennylane as qml

Expand All @@ -27,12 +26,8 @@


def catalyst_qjit(qnode):
"""The ``catalyst.while`` wrapper method"""
try:
distribution("pennylane_catalyst")
return qnode.__class__.__name__ == "QJIT"
except ImportError:
return False
"""A method checking whether a qnode is compiled by catalyst.qjit"""
return qnode.__class__.__name__ == "QJIT" and hasattr(qnode, "user_function")


def draw(
Expand Down
8 changes: 8 additions & 0 deletions pennylane/ops/functions/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
from pennylane.typing import TensorLike


def catalyst_qjit(qnode):
"""A method checking whether a qnode is compiled by catalyst.qjit"""
return qnode.__class__.__name__ == "QJIT" and hasattr(qnode, "user_function")


def matrix(op: Union[Operator, PauliWord, PauliSentence], wire_order=None) -> TensorLike:
r"""The matrix representation of an operation or quantum circuit.
Expand Down Expand Up @@ -177,6 +182,9 @@ def circuit():
wires specified, and this is the order in which wires appear in ``circuit()``.
"""
if catalyst_qjit(op):
op = op.user_function

if not isinstance(op, Operator):

if isinstance(op, (PauliWord, PauliSentence)):
Expand Down
28 changes: 28 additions & 0 deletions tests/ops/functions/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,34 @@ def circuit(theta):

assert np.allclose(matrix, expected_matrix)

@pytest.mark.catalyst
@pytest.mark.external
def test_catalyst(self):
"""Test with Catalyst interface"""

import catalyst

dev = qml.device("lightning.qubit", wires=1)

# create a plain QNode
@qml.qnode(dev)
def f():
qml.PauliX(0)
return qml.state()

# create a qjit-compiled QNode by decorating a function
@catalyst.qjit
@qml.qnode(dev)
def g():
qml.PauliX(0)
return qml.state()

# create a qjit-compiled QNode by passing in the plain QNode directly
h = catalyst.qjit(f)

assert np.allclose(f(), g(), h())
assert np.allclose(qml.matrix(f)(), qml.matrix(g)(), qml.matrix(h)())

@pytest.mark.jax
def test_get_unitary_matrix_interface_jax(self):
"""Test with JAX interface"""
Expand Down
1 change: 1 addition & 0 deletions tests/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ markers =
param-shift: marks tests for the parameter shift (deselect with '-m "not param-shift"')
logging: marks tests for pennylane logging
external: marks tests that require external packages such as matplotlib and PyZX
catalyst: marks tests for catalyst testing (select with '-m "catalyst"')
filterwarnings =
ignore::DeprecationWarning:autograd.numpy.numpy_wrapper
ignore:Casting complex values to real::autograd.numpy.numpy_wrapper
Expand Down

0 comments on commit c50481c

Please sign in to comment.