diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 73aa52d3d6e..6ba157a4efa 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -95,6 +95,9 @@

Bug fixes 🐛

+* `qml.Projector` is now compatible with jax-jit. + [(#5595)](https://github.com/PennyLaneAI/pennylane/pull/5595) + * Finite shot circuits with a `qml.probs` measurement, both with a `wires` or `op` argument, can now be compiled with `jax.jit`. [(#5619)](https://github.com/PennyLaneAI/pennylane/pull/5619) diff --git a/pennylane/math/single_dispatch.py b/pennylane/math/single_dispatch.py index f1ed70589a3..acb18879e03 100644 --- a/pennylane/math/single_dispatch.py +++ b/pennylane/math/single_dispatch.py @@ -243,6 +243,7 @@ def _take_autograd(tensor, indices, axis=None): ar.autoray._SUBMODULE_ALIASES["tensorflow", "isclose"] = "tensorflow.experimental.numpy" ar.autoray._SUBMODULE_ALIASES["tensorflow", "atleast_1d"] = "tensorflow.experimental.numpy" ar.autoray._SUBMODULE_ALIASES["tensorflow", "all"] = "tensorflow.experimental.numpy" +ar.autoray._SUBMODULE_ALIASES["tensorflow", "vstack"] = "tensorflow.experimental.numpy" tf_fft_functions = [ "fft", diff --git a/pennylane/ops/op_math/composite.py b/pennylane/ops/op_math/composite.py index 400535c4bd7..c2a7573bfbd 100644 --- a/pennylane/ops/op_math/composite.py +++ b/pennylane/ops/op_math/composite.py @@ -175,8 +175,9 @@ def eigvals(self): list(self.wires), ) ) - - return self._math_op(math.asarray(eigvals, like=math.get_deep_interface(eigvals)), axis=0) + framework = math.get_deep_interface(eigvals) + eigvals = [math.asarray(ei, like=framework) for ei in eigvals] + return self._math_op(math.vstack(eigvals), axis=0) @abc.abstractmethod def matrix(self, wire_order=None): diff --git a/pennylane/ops/qubit/observables.py b/pennylane/ops/qubit/observables.py index 5d4bf5d19b6..b66e2dbc16b 100644 --- a/pennylane/ops/qubit/observables.py +++ b/pennylane/ops/qubit/observables.py @@ -435,14 +435,23 @@ class BasisStateProjector(Projector, Operation): r"""Observable corresponding to the state projector :math:`P=\ket{\phi}\bra{\phi}`, where :math:`\phi` denotes a basis state.""" + grad_method = None + # The call signature should be the same as Projector.__new__ for the positional # arguments, but with free key word arguments. def __init__(self, state, wires, id=None): wires = Wires(wires) - state = list(qml.math.toarray(state).astype(int)) - if not set(state).issubset({0, 1}): - raise ValueError(f"Basis state must only consist of 0s and 1s; got {state}") + if qml.math.get_interface(state) == "jax": + dtype = qml.math.dtype(state) + if not (np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, bool)): + raise ValueError("Basis state must consist of integers or booleans.") + else: + # state is index into data, rather than data, so cast it to built-ins when + # no need for tracing + state = tuple(qml.math.toarray(state).astype(int)) + if not set(state).issubset({0, 1}): + raise ValueError(f"Basis state must only consist of 0s and 1s; got {state}") super().__init__(state, wires=wires, id=id) @@ -471,7 +480,7 @@ def label(self, decimals=None, base_label=None, cache=None): if base_label is not None: return base_label - basis_string = "".join(str(int(i)) for i in self.parameters[0]) + basis_string = "".join(str(int(i)) for i in self.data[0]) return f"|{basis_string}⟩⟨{basis_string}|" @staticmethod @@ -497,7 +506,15 @@ def compute_matrix(basis_state): # pylint: disable=arguments-differ [0. 0. 0. 0.] [0. 0. 0. 0.]] """ - m = np.zeros((2 ** len(basis_state), 2 ** len(basis_state))) + shape = (2 ** len(basis_state), 2 ** len(basis_state)) + if qml.math.get_interface(basis_state) == "jax": + idx = 0 + for i, m in enumerate(basis_state): + idx = idx + (m << (len(basis_state) - i - 1)) + mat = qml.math.zeros(shape, like=basis_state) + return mat.at[idx, idx].set(1.0) + + m = np.zeros(shape) idx = int("".join(str(i) for i in basis_state), 2) m[idx, idx] = 1 return m @@ -528,6 +545,12 @@ def compute_eigvals(basis_state): # pylint: disable=arguments-differ >>> BasisStateProjector.compute_eigvals([0, 1]) [0. 1. 0. 0.] """ + if qml.math.get_interface(basis_state) == "jax": + idx = 0 + for i, m in enumerate(basis_state): + idx = idx + (m << (len(basis_state) - i - 1)) + eigvals = qml.math.zeros(2 ** len(basis_state), like=basis_state) + return eigvals.at[idx].set(1.0) w = np.zeros(2 ** len(basis_state)) idx = int("".join(str(i) for i in basis_state), 2) w[idx] = 1 @@ -566,12 +589,12 @@ class StateVectorProjector(Projector): r"""Observable corresponding to the state projector :math:`P=\ket{\phi}\bra{\phi}`, where :math:`\phi` denotes a state.""" + grad_method = None + # The call signature should be the same as Projector.__new__ for the positional # arguments, but with free key word arguments. def __init__(self, state, wires, id=None): wires = Wires(wires) - state = list(qml.math.toarray(state)) - super().__init__(state, wires=wires, id=id) def __new__(cls, *_, **__): # pylint: disable=arguments-differ @@ -680,9 +703,11 @@ def compute_eigvals(state_vector): # pylint: disable=arguments-differ,arguments >>> StateVectorProjector.compute_eigvals([0, 0, 1, 0]) array([1, 0, 0, 0]) """ - w = qml.math.zeros_like(state_vector) + precision = qml.math.get_dtype_name(state_vector)[-2:] + dtype = f"float{precision}" if precision in {"32", "64"} else "float64" + w = np.zeros(qml.math.shape(state_vector), dtype=dtype) w[0] = 1 - return w + return qml.math.convert_like(w, state_vector) @staticmethod def compute_diagonalizing_gates( @@ -715,10 +740,27 @@ def compute_diagonalizing_gates( # Adapting the approach discussed in the link below to work with arbitrary complex-valued state vectors. # Alternatively, we could take the adjoint of the Mottonen decomposition for the state vector. # https://quantumcomputing.stackexchange.com/questions/10239/how-can-i-fill-a-unitary-knowing-only-its-first-column - phase = qml.math.exp(-1j * qml.math.angle(state_vector[0])) + + if qml.math.get_interface(state_vector) == "tensorflow": + dtype_name = qml.math.get_dtype_name(state_vector) + if dtype_name == "int32": + state_vector = qml.math.cast(state_vector, np.complex64) + elif dtype_name == "int64": + state_vector = qml.math.cast(state_vector, np.complex128) + + angle = qml.math.angle(state_vector[0]) + if qml.math.get_interface(angle) == "tensorflow": + if qml.math.get_dtype_name(angle) == "float32": + angle = qml.math.cast(angle, np.complex64) + else: + angle = qml.math.cast(angle, np.complex128) + + phase = qml.math.exp(-1.0j * angle) psi = phase * state_vector denominator = qml.math.sqrt(2 + 2 * psi[0]) - psi = qml.math.set_index(psi, 0, psi[0] + 1) # psi[0] += 1, but JAX-JIT compatible + summed_array = np.zeros(qml.math.shape(psi), dtype=qml.math.get_dtype_name(psi)) + summed_array[0] = 1.0 + psi = psi + summed_array psi /= denominator u = 2 * qml.math.outer(psi, qml.math.conj(psi)) - qml.math.eye(len(psi)) return [QubitUnitary(u, wires=wires)] diff --git a/pennylane/utils.py b/pennylane/utils.py index ac22de48d76..40cd7c4f1b2 100644 --- a/pennylane/utils.py +++ b/pennylane/utils.py @@ -205,4 +205,4 @@ def expand_vector(vector, original_wires, expanded_wires): expanded_tensor, tuple(original_indices), tuple(wire_indices) ) - return qml.math.reshape(expanded_tensor, qudit_order**M) + return qml.math.reshape(expanded_tensor, (qudit_order**M,)) diff --git a/tests/interfaces/default_qubit_2_integration/test_tensorflow_qnode_default_qubit_2.py b/tests/interfaces/default_qubit_2_integration/test_tensorflow_qnode_default_qubit_2.py index 772aa0831c9..e0f032b99cb 100644 --- a/tests/interfaces/default_qubit_2_integration/test_tensorflow_qnode_default_qubit_2.py +++ b/tests/interfaces/default_qubit_2_integration/test_tensorflow_qnode_default_qubit_2.py @@ -14,7 +14,7 @@ """Integration tests for using the TensorFlow interface with a QNode""" import numpy as np -# pylint: disable=too-many-arguments,too-few-public-methods,comparison-with-callable +# pylint: disable=too-many-arguments,too-few-public-methods,comparison-with-callable, use-implicit-booleaness-not-comparison import pytest import pennylane as qml @@ -955,8 +955,9 @@ def cost_fn(x, y): assert np.allclose(grad, expected, atol=tol, rtol=0) @pytest.mark.parametrize("state", [[1], [0, 1]]) # Basis state and state vector + @pytest.mark.parametrize("dtype", ("int32", "int64")) def test_projector( - self, state, dev, diff_method, grad_on_execution, device_vjp, tol, interface + self, state, dev, diff_method, grad_on_execution, device_vjp, tol, interface, dtype ): """Test that the variance of a projector is correctly returned""" kwargs = { @@ -974,7 +975,7 @@ def test_projector( kwargs["num_directions"] = 20 tol = TOL_FOR_SPSA - P = tf.constant(state) + P = tf.constant(state, dtype=dtype) x, y = 0.765, -0.654 weights = tf.Variable([x, y], dtype=tf.float64) diff --git a/tests/ops/qubit/test_observables.py b/tests/ops/qubit/test_observables.py index e165bdd4570..c5efbfbaadf 100644 --- a/tests/ops/qubit/test_observables.py +++ b/tests/ops/qubit/test_observables.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Unit tests for qubit observables.""" -# pylint: disable=protected-access +# pylint: disable=protected-access, use-implicit-booleaness-not-comparison import functools import pickle @@ -540,6 +540,44 @@ def test_serialization(self): assert qml.equal(new_proj, proj) assert new_proj.id == proj.id # Ensure they are identical + @pytest.mark.jax + def test_jit_measurement(self): + """Test that the measurement of a projector can be jitted.""" + + import jax + + @jax.jit + @qml.qnode(qml.device("default.qubit")) + def circuit(state): + qml.X(1) + return qml.expval(qml.Projector(state, wires=(0, 1))) + + state00 = jax.numpy.array([0, 0]) + out00 = circuit(state00) + assert qml.math.allclose(out00, 0) + state01 = jax.numpy.array([0, 1]) + out01 = circuit(state01) + assert qml.math.allclose(out01, 1) + state10 = jax.numpy.array([True, False]) + out10 = circuit(state10) + assert qml.math.allclose(out10, 0) + + with pytest.raises(ValueError, match=r"Basis state must consist of integers or booleans."): + circuit(jax.numpy.array([0.5, 0.6])) + + @pytest.mark.jax + def test_jit_matrix(self): + """Test that computing the matrix of a projector is jittable.""" + + import jax + + basis_state = jax.numpy.array([0, 1]) + f = jax.jit(BasisStateProjector.compute_matrix) + out = f(basis_state) + + expected = np.array([[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]) + assert qml.math.allclose(out, expected) + class TestBasisStateProjector: """Tests for the basis state projector observable.""" @@ -716,6 +754,24 @@ def test_label_matrices_list_in_cache(self, projector): assert len(cache["matrices"]) == 2 assert np.allclose(cache["matrices"][1], projector.parameters[0]) + @pytest.mark.jax + def test_jit_execution(self): + """Test that executing a StateVectorProjector can be jitted.""" + + import jax + + @jax.jit + @qml.qnode(qml.device("default.qubit")) + def circuit(state): + return qml.expval(qml.Projector(state, wires=(0, 1))) + + basis_state = jax.numpy.array([1, 1, 1, 1.0]) / 2 + out = circuit(basis_state) + assert qml.math.allclose(out, 0.25) + + basis_state2 = jax.numpy.array([0, 0, 0, 0]) + assert qml.math.allclose(circuit(basis_state2), 0) + label_data = [ (qml.Hermitian(np.eye(2), wires=1), "𝓗"),