From c50481c8839afbb25a901abdb88a934fd279a4ff Mon Sep 17 00:00:00 2001
From: paul0403 <79805239+paul0403@users.noreply.github.com>
Date: Thu, 30 May 2024 10:53:12 -0400
Subject: [PATCH] patch `qml.matrix` to accept qnodes compiled by
`catalyst.qjit` (#5753)
**Context:** Currently `qml.matrix`
breaks when the input is a catalyst-compiled QNode instead of a plain
QNode, and the QNode need to be manually retrieved by the user:
https://github.com/PennyLaneAI/catalyst/issues/765
It would be beneficial to do this "unwrap" in `qml.matrix` , like those
in qml.draw:
https://github.com/PennyLaneAI/pennylane/blob/59a1e0586e707d057a0c92d4239036afa5312b73/pennylane/drawer/draw.py#L213-L214
**Description of the Change:**
1. in `pennylane/ops/functions/matrix.py`, in `qml.matrix(op)`, if the
input `op` is a catalyst.qjit compiled function, dispatches the behavior
to be `qml.matrix(op.user_function)`.
2. Added a test in `tests/ops/functions/test_matrix.py`
**Benefits:** a qjit compiled qnode can be passed into `qml.matrix`
directly to query the matrix representation of the circuit
**Related GitHub Issues:**
https://github.com/PennyLaneAI/catalyst/issues/765
[sc-64247]
---
doc/releases/changelog-dev.md | 4 ++++
pennylane/drawer/draw.py | 9 ++-------
pennylane/ops/functions/matrix.py | 8 ++++++++
tests/ops/functions/test_matrix.py | 28 ++++++++++++++++++++++++++++
tests/pytest.ini | 1 +
5 files changed, 43 insertions(+), 7 deletions(-)
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 8850f917fe4..20f1436e2e4 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -239,6 +239,9 @@
* Fixes a bug in `qml.math.dot` that raises an error when only one of the operands is a scalar.
[(#5702)](https://github.com/PennyLaneAI/pennylane/pull/5702)
+* `qml.matrix` is now compatible with qnodes compiled by catalyst.qjit.
+ [(#5753)](https://github.com/PennyLaneAI/pennylane/pull/5753)
+
Contributors ✍️
This release contains contributions from (in alphabetical order):
@@ -259,4 +262,5 @@ Vincent Michaud-Rioux,
Lee James O'Riordan,
Mudit Pandey,
Kenya Sakka,
+Haochen Paul Wang,
David Wierichs.
diff --git a/pennylane/drawer/draw.py b/pennylane/drawer/draw.py
index 97ebf3084fb..3655f9b5cc8 100644
--- a/pennylane/drawer/draw.py
+++ b/pennylane/drawer/draw.py
@@ -18,7 +18,6 @@
"""
import warnings
from functools import wraps
-from importlib.metadata import distribution
import pennylane as qml
@@ -27,12 +26,8 @@
def catalyst_qjit(qnode):
- """The ``catalyst.while`` wrapper method"""
- try:
- distribution("pennylane_catalyst")
- return qnode.__class__.__name__ == "QJIT"
- except ImportError:
- return False
+ """A method checking whether a qnode is compiled by catalyst.qjit"""
+ return qnode.__class__.__name__ == "QJIT" and hasattr(qnode, "user_function")
def draw(
diff --git a/pennylane/ops/functions/matrix.py b/pennylane/ops/functions/matrix.py
index 035fb438a90..eafb8dc38a1 100644
--- a/pennylane/ops/functions/matrix.py
+++ b/pennylane/ops/functions/matrix.py
@@ -27,6 +27,11 @@
from pennylane.typing import TensorLike
+def catalyst_qjit(qnode):
+ """A method checking whether a qnode is compiled by catalyst.qjit"""
+ return qnode.__class__.__name__ == "QJIT" and hasattr(qnode, "user_function")
+
+
def matrix(op: Union[Operator, PauliWord, PauliSentence], wire_order=None) -> TensorLike:
r"""The matrix representation of an operation or quantum circuit.
@@ -177,6 +182,9 @@ def circuit():
wires specified, and this is the order in which wires appear in ``circuit()``.
"""
+ if catalyst_qjit(op):
+ op = op.user_function
+
if not isinstance(op, Operator):
if isinstance(op, (PauliWord, PauliSentence)):
diff --git a/tests/ops/functions/test_matrix.py b/tests/ops/functions/test_matrix.py
index 84c4e7bb9c0..e199032c2de 100644
--- a/tests/ops/functions/test_matrix.py
+++ b/tests/ops/functions/test_matrix.py
@@ -659,6 +659,34 @@ def circuit(theta):
assert np.allclose(matrix, expected_matrix)
+ @pytest.mark.catalyst
+ @pytest.mark.external
+ def test_catalyst(self):
+ """Test with Catalyst interface"""
+
+ import catalyst
+
+ dev = qml.device("lightning.qubit", wires=1)
+
+ # create a plain QNode
+ @qml.qnode(dev)
+ def f():
+ qml.PauliX(0)
+ return qml.state()
+
+ # create a qjit-compiled QNode by decorating a function
+ @catalyst.qjit
+ @qml.qnode(dev)
+ def g():
+ qml.PauliX(0)
+ return qml.state()
+
+ # create a qjit-compiled QNode by passing in the plain QNode directly
+ h = catalyst.qjit(f)
+
+ assert np.allclose(f(), g(), h())
+ assert np.allclose(qml.matrix(f)(), qml.matrix(g)(), qml.matrix(h)())
+
@pytest.mark.jax
def test_get_unitary_matrix_interface_jax(self):
"""Test with JAX interface"""
diff --git a/tests/pytest.ini b/tests/pytest.ini
index 4c2812d4a75..1897c3c65e4 100644
--- a/tests/pytest.ini
+++ b/tests/pytest.ini
@@ -18,6 +18,7 @@ markers =
param-shift: marks tests for the parameter shift (deselect with '-m "not param-shift"')
logging: marks tests for pennylane logging
external: marks tests that require external packages such as matplotlib and PyZX
+ catalyst: marks tests for catalyst testing (select with '-m "catalyst"')
filterwarnings =
ignore::DeprecationWarning:autograd.numpy.numpy_wrapper
ignore:Casting complex values to real::autograd.numpy.numpy_wrapper