From c50481c8839afbb25a901abdb88a934fd279a4ff Mon Sep 17 00:00:00 2001 From: paul0403 <79805239+paul0403@users.noreply.github.com> Date: Thu, 30 May 2024 10:53:12 -0400 Subject: [PATCH] patch `qml.matrix` to accept qnodes compiled by `catalyst.qjit` (#5753) **Context:** Currently `qml.matrix` breaks when the input is a catalyst-compiled QNode instead of a plain QNode, and the QNode need to be manually retrieved by the user: https://github.com/PennyLaneAI/catalyst/issues/765 It would be beneficial to do this "unwrap" in `qml.matrix` , like those in qml.draw: https://github.com/PennyLaneAI/pennylane/blob/59a1e0586e707d057a0c92d4239036afa5312b73/pennylane/drawer/draw.py#L213-L214 **Description of the Change:** 1. in `pennylane/ops/functions/matrix.py`, in `qml.matrix(op)`, if the input `op` is a catalyst.qjit compiled function, dispatches the behavior to be `qml.matrix(op.user_function)`. 2. Added a test in `tests/ops/functions/test_matrix.py` **Benefits:** a qjit compiled qnode can be passed into `qml.matrix` directly to query the matrix representation of the circuit **Related GitHub Issues:** https://github.com/PennyLaneAI/catalyst/issues/765 [sc-64247] --- doc/releases/changelog-dev.md | 4 ++++ pennylane/drawer/draw.py | 9 ++------- pennylane/ops/functions/matrix.py | 8 ++++++++ tests/ops/functions/test_matrix.py | 28 ++++++++++++++++++++++++++++ tests/pytest.ini | 1 + 5 files changed, 43 insertions(+), 7 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 8850f917fe4..20f1436e2e4 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -239,6 +239,9 @@ * Fixes a bug in `qml.math.dot` that raises an error when only one of the operands is a scalar. [(#5702)](https://github.com/PennyLaneAI/pennylane/pull/5702) +* `qml.matrix` is now compatible with qnodes compiled by catalyst.qjit. + [(#5753)](https://github.com/PennyLaneAI/pennylane/pull/5753) +

Contributors ✍️

This release contains contributions from (in alphabetical order): @@ -259,4 +262,5 @@ Vincent Michaud-Rioux, Lee James O'Riordan, Mudit Pandey, Kenya Sakka, +Haochen Paul Wang, David Wierichs. diff --git a/pennylane/drawer/draw.py b/pennylane/drawer/draw.py index 97ebf3084fb..3655f9b5cc8 100644 --- a/pennylane/drawer/draw.py +++ b/pennylane/drawer/draw.py @@ -18,7 +18,6 @@ """ import warnings from functools import wraps -from importlib.metadata import distribution import pennylane as qml @@ -27,12 +26,8 @@ def catalyst_qjit(qnode): - """The ``catalyst.while`` wrapper method""" - try: - distribution("pennylane_catalyst") - return qnode.__class__.__name__ == "QJIT" - except ImportError: - return False + """A method checking whether a qnode is compiled by catalyst.qjit""" + return qnode.__class__.__name__ == "QJIT" and hasattr(qnode, "user_function") def draw( diff --git a/pennylane/ops/functions/matrix.py b/pennylane/ops/functions/matrix.py index 035fb438a90..eafb8dc38a1 100644 --- a/pennylane/ops/functions/matrix.py +++ b/pennylane/ops/functions/matrix.py @@ -27,6 +27,11 @@ from pennylane.typing import TensorLike +def catalyst_qjit(qnode): + """A method checking whether a qnode is compiled by catalyst.qjit""" + return qnode.__class__.__name__ == "QJIT" and hasattr(qnode, "user_function") + + def matrix(op: Union[Operator, PauliWord, PauliSentence], wire_order=None) -> TensorLike: r"""The matrix representation of an operation or quantum circuit. @@ -177,6 +182,9 @@ def circuit(): wires specified, and this is the order in which wires appear in ``circuit()``. """ + if catalyst_qjit(op): + op = op.user_function + if not isinstance(op, Operator): if isinstance(op, (PauliWord, PauliSentence)): diff --git a/tests/ops/functions/test_matrix.py b/tests/ops/functions/test_matrix.py index 84c4e7bb9c0..e199032c2de 100644 --- a/tests/ops/functions/test_matrix.py +++ b/tests/ops/functions/test_matrix.py @@ -659,6 +659,34 @@ def circuit(theta): assert np.allclose(matrix, expected_matrix) + @pytest.mark.catalyst + @pytest.mark.external + def test_catalyst(self): + """Test with Catalyst interface""" + + import catalyst + + dev = qml.device("lightning.qubit", wires=1) + + # create a plain QNode + @qml.qnode(dev) + def f(): + qml.PauliX(0) + return qml.state() + + # create a qjit-compiled QNode by decorating a function + @catalyst.qjit + @qml.qnode(dev) + def g(): + qml.PauliX(0) + return qml.state() + + # create a qjit-compiled QNode by passing in the plain QNode directly + h = catalyst.qjit(f) + + assert np.allclose(f(), g(), h()) + assert np.allclose(qml.matrix(f)(), qml.matrix(g)(), qml.matrix(h)()) + @pytest.mark.jax def test_get_unitary_matrix_interface_jax(self): """Test with JAX interface""" diff --git a/tests/pytest.ini b/tests/pytest.ini index 4c2812d4a75..1897c3c65e4 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -18,6 +18,7 @@ markers = param-shift: marks tests for the parameter shift (deselect with '-m "not param-shift"') logging: marks tests for pennylane logging external: marks tests that require external packages such as matplotlib and PyZX + catalyst: marks tests for catalyst testing (select with '-m "catalyst"') filterwarnings = ignore::DeprecationWarning:autograd.numpy.numpy_wrapper ignore:Casting complex values to real::autograd.numpy.numpy_wrapper