diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 73aa52d3d6e..6ba157a4efa 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -95,6 +95,9 @@
Bug fixes 🐛
+* `qml.Projector` is now compatible with jax-jit.
+ [(#5595)](https://github.com/PennyLaneAI/pennylane/pull/5595)
+
* Finite shot circuits with a `qml.probs` measurement, both with a `wires` or `op` argument, can now be compiled with `jax.jit`.
[(#5619)](https://github.com/PennyLaneAI/pennylane/pull/5619)
diff --git a/pennylane/math/single_dispatch.py b/pennylane/math/single_dispatch.py
index f1ed70589a3..acb18879e03 100644
--- a/pennylane/math/single_dispatch.py
+++ b/pennylane/math/single_dispatch.py
@@ -243,6 +243,7 @@ def _take_autograd(tensor, indices, axis=None):
ar.autoray._SUBMODULE_ALIASES["tensorflow", "isclose"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "atleast_1d"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "all"] = "tensorflow.experimental.numpy"
+ar.autoray._SUBMODULE_ALIASES["tensorflow", "vstack"] = "tensorflow.experimental.numpy"
tf_fft_functions = [
"fft",
diff --git a/pennylane/ops/op_math/composite.py b/pennylane/ops/op_math/composite.py
index 400535c4bd7..c2a7573bfbd 100644
--- a/pennylane/ops/op_math/composite.py
+++ b/pennylane/ops/op_math/composite.py
@@ -175,8 +175,9 @@ def eigvals(self):
list(self.wires),
)
)
-
- return self._math_op(math.asarray(eigvals, like=math.get_deep_interface(eigvals)), axis=0)
+ framework = math.get_deep_interface(eigvals)
+ eigvals = [math.asarray(ei, like=framework) for ei in eigvals]
+ return self._math_op(math.vstack(eigvals), axis=0)
@abc.abstractmethod
def matrix(self, wire_order=None):
diff --git a/pennylane/ops/qubit/observables.py b/pennylane/ops/qubit/observables.py
index 5d4bf5d19b6..b66e2dbc16b 100644
--- a/pennylane/ops/qubit/observables.py
+++ b/pennylane/ops/qubit/observables.py
@@ -435,14 +435,23 @@ class BasisStateProjector(Projector, Operation):
r"""Observable corresponding to the state projector :math:`P=\ket{\phi}\bra{\phi}`, where
:math:`\phi` denotes a basis state."""
+ grad_method = None
+
# The call signature should be the same as Projector.__new__ for the positional
# arguments, but with free key word arguments.
def __init__(self, state, wires, id=None):
wires = Wires(wires)
- state = list(qml.math.toarray(state).astype(int))
- if not set(state).issubset({0, 1}):
- raise ValueError(f"Basis state must only consist of 0s and 1s; got {state}")
+ if qml.math.get_interface(state) == "jax":
+ dtype = qml.math.dtype(state)
+ if not (np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, bool)):
+ raise ValueError("Basis state must consist of integers or booleans.")
+ else:
+ # state is index into data, rather than data, so cast it to built-ins when
+ # no need for tracing
+ state = tuple(qml.math.toarray(state).astype(int))
+ if not set(state).issubset({0, 1}):
+ raise ValueError(f"Basis state must only consist of 0s and 1s; got {state}")
super().__init__(state, wires=wires, id=id)
@@ -471,7 +480,7 @@ def label(self, decimals=None, base_label=None, cache=None):
if base_label is not None:
return base_label
- basis_string = "".join(str(int(i)) for i in self.parameters[0])
+ basis_string = "".join(str(int(i)) for i in self.data[0])
return f"|{basis_string}⟩⟨{basis_string}|"
@staticmethod
@@ -497,7 +506,15 @@ def compute_matrix(basis_state): # pylint: disable=arguments-differ
[0. 0. 0. 0.]
[0. 0. 0. 0.]]
"""
- m = np.zeros((2 ** len(basis_state), 2 ** len(basis_state)))
+ shape = (2 ** len(basis_state), 2 ** len(basis_state))
+ if qml.math.get_interface(basis_state) == "jax":
+ idx = 0
+ for i, m in enumerate(basis_state):
+ idx = idx + (m << (len(basis_state) - i - 1))
+ mat = qml.math.zeros(shape, like=basis_state)
+ return mat.at[idx, idx].set(1.0)
+
+ m = np.zeros(shape)
idx = int("".join(str(i) for i in basis_state), 2)
m[idx, idx] = 1
return m
@@ -528,6 +545,12 @@ def compute_eigvals(basis_state): # pylint: disable=arguments-differ
>>> BasisStateProjector.compute_eigvals([0, 1])
[0. 1. 0. 0.]
"""
+ if qml.math.get_interface(basis_state) == "jax":
+ idx = 0
+ for i, m in enumerate(basis_state):
+ idx = idx + (m << (len(basis_state) - i - 1))
+ eigvals = qml.math.zeros(2 ** len(basis_state), like=basis_state)
+ return eigvals.at[idx].set(1.0)
w = np.zeros(2 ** len(basis_state))
idx = int("".join(str(i) for i in basis_state), 2)
w[idx] = 1
@@ -566,12 +589,12 @@ class StateVectorProjector(Projector):
r"""Observable corresponding to the state projector :math:`P=\ket{\phi}\bra{\phi}`, where
:math:`\phi` denotes a state."""
+ grad_method = None
+
# The call signature should be the same as Projector.__new__ for the positional
# arguments, but with free key word arguments.
def __init__(self, state, wires, id=None):
wires = Wires(wires)
- state = list(qml.math.toarray(state))
-
super().__init__(state, wires=wires, id=id)
def __new__(cls, *_, **__): # pylint: disable=arguments-differ
@@ -680,9 +703,11 @@ def compute_eigvals(state_vector): # pylint: disable=arguments-differ,arguments
>>> StateVectorProjector.compute_eigvals([0, 0, 1, 0])
array([1, 0, 0, 0])
"""
- w = qml.math.zeros_like(state_vector)
+ precision = qml.math.get_dtype_name(state_vector)[-2:]
+ dtype = f"float{precision}" if precision in {"32", "64"} else "float64"
+ w = np.zeros(qml.math.shape(state_vector), dtype=dtype)
w[0] = 1
- return w
+ return qml.math.convert_like(w, state_vector)
@staticmethod
def compute_diagonalizing_gates(
@@ -715,10 +740,27 @@ def compute_diagonalizing_gates(
# Adapting the approach discussed in the link below to work with arbitrary complex-valued state vectors.
# Alternatively, we could take the adjoint of the Mottonen decomposition for the state vector.
# https://quantumcomputing.stackexchange.com/questions/10239/how-can-i-fill-a-unitary-knowing-only-its-first-column
- phase = qml.math.exp(-1j * qml.math.angle(state_vector[0]))
+
+ if qml.math.get_interface(state_vector) == "tensorflow":
+ dtype_name = qml.math.get_dtype_name(state_vector)
+ if dtype_name == "int32":
+ state_vector = qml.math.cast(state_vector, np.complex64)
+ elif dtype_name == "int64":
+ state_vector = qml.math.cast(state_vector, np.complex128)
+
+ angle = qml.math.angle(state_vector[0])
+ if qml.math.get_interface(angle) == "tensorflow":
+ if qml.math.get_dtype_name(angle) == "float32":
+ angle = qml.math.cast(angle, np.complex64)
+ else:
+ angle = qml.math.cast(angle, np.complex128)
+
+ phase = qml.math.exp(-1.0j * angle)
psi = phase * state_vector
denominator = qml.math.sqrt(2 + 2 * psi[0])
- psi = qml.math.set_index(psi, 0, psi[0] + 1) # psi[0] += 1, but JAX-JIT compatible
+ summed_array = np.zeros(qml.math.shape(psi), dtype=qml.math.get_dtype_name(psi))
+ summed_array[0] = 1.0
+ psi = psi + summed_array
psi /= denominator
u = 2 * qml.math.outer(psi, qml.math.conj(psi)) - qml.math.eye(len(psi))
return [QubitUnitary(u, wires=wires)]
diff --git a/pennylane/utils.py b/pennylane/utils.py
index ac22de48d76..40cd7c4f1b2 100644
--- a/pennylane/utils.py
+++ b/pennylane/utils.py
@@ -205,4 +205,4 @@ def expand_vector(vector, original_wires, expanded_wires):
expanded_tensor, tuple(original_indices), tuple(wire_indices)
)
- return qml.math.reshape(expanded_tensor, qudit_order**M)
+ return qml.math.reshape(expanded_tensor, (qudit_order**M,))
diff --git a/tests/interfaces/default_qubit_2_integration/test_tensorflow_qnode_default_qubit_2.py b/tests/interfaces/default_qubit_2_integration/test_tensorflow_qnode_default_qubit_2.py
index 772aa0831c9..e0f032b99cb 100644
--- a/tests/interfaces/default_qubit_2_integration/test_tensorflow_qnode_default_qubit_2.py
+++ b/tests/interfaces/default_qubit_2_integration/test_tensorflow_qnode_default_qubit_2.py
@@ -14,7 +14,7 @@
"""Integration tests for using the TensorFlow interface with a QNode"""
import numpy as np
-# pylint: disable=too-many-arguments,too-few-public-methods,comparison-with-callable
+# pylint: disable=too-many-arguments,too-few-public-methods,comparison-with-callable, use-implicit-booleaness-not-comparison
import pytest
import pennylane as qml
@@ -955,8 +955,9 @@ def cost_fn(x, y):
assert np.allclose(grad, expected, atol=tol, rtol=0)
@pytest.mark.parametrize("state", [[1], [0, 1]]) # Basis state and state vector
+ @pytest.mark.parametrize("dtype", ("int32", "int64"))
def test_projector(
- self, state, dev, diff_method, grad_on_execution, device_vjp, tol, interface
+ self, state, dev, diff_method, grad_on_execution, device_vjp, tol, interface, dtype
):
"""Test that the variance of a projector is correctly returned"""
kwargs = {
@@ -974,7 +975,7 @@ def test_projector(
kwargs["num_directions"] = 20
tol = TOL_FOR_SPSA
- P = tf.constant(state)
+ P = tf.constant(state, dtype=dtype)
x, y = 0.765, -0.654
weights = tf.Variable([x, y], dtype=tf.float64)
diff --git a/tests/ops/qubit/test_observables.py b/tests/ops/qubit/test_observables.py
index e165bdd4570..c5efbfbaadf 100644
--- a/tests/ops/qubit/test_observables.py
+++ b/tests/ops/qubit/test_observables.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for qubit observables."""
-# pylint: disable=protected-access
+# pylint: disable=protected-access, use-implicit-booleaness-not-comparison
import functools
import pickle
@@ -540,6 +540,44 @@ def test_serialization(self):
assert qml.equal(new_proj, proj)
assert new_proj.id == proj.id # Ensure they are identical
+ @pytest.mark.jax
+ def test_jit_measurement(self):
+ """Test that the measurement of a projector can be jitted."""
+
+ import jax
+
+ @jax.jit
+ @qml.qnode(qml.device("default.qubit"))
+ def circuit(state):
+ qml.X(1)
+ return qml.expval(qml.Projector(state, wires=(0, 1)))
+
+ state00 = jax.numpy.array([0, 0])
+ out00 = circuit(state00)
+ assert qml.math.allclose(out00, 0)
+ state01 = jax.numpy.array([0, 1])
+ out01 = circuit(state01)
+ assert qml.math.allclose(out01, 1)
+ state10 = jax.numpy.array([True, False])
+ out10 = circuit(state10)
+ assert qml.math.allclose(out10, 0)
+
+ with pytest.raises(ValueError, match=r"Basis state must consist of integers or booleans."):
+ circuit(jax.numpy.array([0.5, 0.6]))
+
+ @pytest.mark.jax
+ def test_jit_matrix(self):
+ """Test that computing the matrix of a projector is jittable."""
+
+ import jax
+
+ basis_state = jax.numpy.array([0, 1])
+ f = jax.jit(BasisStateProjector.compute_matrix)
+ out = f(basis_state)
+
+ expected = np.array([[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
+ assert qml.math.allclose(out, expected)
+
class TestBasisStateProjector:
"""Tests for the basis state projector observable."""
@@ -716,6 +754,24 @@ def test_label_matrices_list_in_cache(self, projector):
assert len(cache["matrices"]) == 2
assert np.allclose(cache["matrices"][1], projector.parameters[0])
+ @pytest.mark.jax
+ def test_jit_execution(self):
+ """Test that executing a StateVectorProjector can be jitted."""
+
+ import jax
+
+ @jax.jit
+ @qml.qnode(qml.device("default.qubit"))
+ def circuit(state):
+ return qml.expval(qml.Projector(state, wires=(0, 1)))
+
+ basis_state = jax.numpy.array([1, 1, 1, 1.0]) / 2
+ out = circuit(basis_state)
+ assert qml.math.allclose(out, 0.25)
+
+ basis_state2 = jax.numpy.array([0, 0, 0, 0])
+ assert qml.math.allclose(circuit(basis_state2), 0)
+
label_data = [
(qml.Hermitian(np.eye(2), wires=1), "𝓗"),