diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 1289ec0865f..aeeeedf42e3 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -360,6 +360,9 @@

Breaking changes 💔

+* State measurements preserve `dtype`. + [(#5547)](https://github.com/PennyLaneAI/pennylane/pull/5547) + * Use `SampleMP`s in the `dynamic_one_shot` transform to get back the values of the mid-circuit measurements. [(#5486)](https://github.com/PennyLaneAI/pennylane/pull/5486) diff --git a/pennylane/measurements/state.py b/pennylane/measurements/state.py index c31e3b94c85..f0d158cf608 100644 --- a/pennylane/measurements/state.py +++ b/pennylane/measurements/state.py @@ -156,9 +156,10 @@ def shape(self, device, shots): def process_state(self, state: Sequence[complex], wire_order: Wires): # pylint:disable=redefined-outer-name + is_tf_interface = qml.math.get_deep_interface(state) == "tensorflow" wires = self.wires if not wires or wire_order == wires: - return qml.math.cast(state, "complex128") + return qml.math.cast(state, "complex128") if is_tf_interface else state + 0.0j if set(wires) != set(wire_order): raise WireError( @@ -178,7 +179,7 @@ def process_state(self, state: Sequence[complex], wire_order: Wires): state = qml.math.reshape(state, shape) state = qml.math.transpose(state, desired_axes) state = qml.math.reshape(state, flat_shape) - return qml.math.cast(state, "complex128") + return qml.math.cast(state, "complex128") if is_tf_interface else state + 0.0j class DensityMatrixMP(StateMP): @@ -211,4 +212,7 @@ def process_state(self, state: Sequence[complex], wire_order: Wires): # pylint:disable=redefined-outer-name wire_map = dict(zip(wire_order, range(len(wire_order)))) mapped_wires = [wire_map[w] for w in self.wires] - return qml.math.reduce_statevector(state, indices=mapped_wires) + kwargs = {"indices": mapped_wires, "c_dtype": "complex128"} + if not qml.math.is_abstract(state) and qml.math.any(qml.math.iscomplex(state)): + kwargs["c_dtype"] = state.dtype + return qml.math.reduce_statevector(state, **kwargs) diff --git a/tests/devices/test_lightning_qubit.py b/tests/devices/test_lightning_qubit.py index 9493382d4f3..fe97c9c8b07 100644 --- a/tests/devices/test_lightning_qubit.py +++ b/tests/devices/test_lightning_qubit.py @@ -98,7 +98,7 @@ class TestDtypePreserved: @pytest.mark.parametrize( "c_dtype", [ - pytest.param(np.complex64, marks=pytest.mark.xfail(reason="dtype not preserved")), + np.complex64, np.complex128, ], ) @@ -108,18 +108,10 @@ class TestDtypePreserved: qml.state(), qml.density_matrix(wires=[1]), qml.density_matrix(wires=[2, 0]), - pytest.param( - qml.expval(qml.PauliY(0)), marks=pytest.mark.xfail(reason="incorrect type") - ), - pytest.param(qml.var(qml.PauliY(0)), marks=pytest.mark.xfail(reason="incorrect type")), - pytest.param( - qml.probs(wires=[1]), - marks=pytest.mark.skip(reason="measurement passes with complex64 but xfail strict"), - ), - pytest.param( - qml.probs(wires=[0, 2]), - marks=pytest.mark.skip(reason="measurement passes with complex64 but xfail strict"), - ), + qml.expval(qml.PauliY(0)), + qml.var(qml.PauliY(0)), + qml.probs(wires=[1]), + qml.probs(wires=[0, 2]), ], ) def test_dtype(self, c_dtype, measurement): @@ -139,4 +131,7 @@ def circuit(x): expected_dtype = c_dtype else: expected_dtype = np.float64 if c_dtype == np.complex128 else np.float32 - assert res.dtype == expected_dtype + if isinstance(res, np.ndarray): + assert res.dtype == expected_dtype + else: + assert isinstance(res, float) diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py index b95a36d67f9..04d6b7e3525 100644 --- a/tests/gradients/core/test_pulse_gradient.py +++ b/tests/gradients/core/test_pulse_gradient.py @@ -770,7 +770,7 @@ def test_nontrainable_batched_tape(self): x = [0.4, 0.2] params = [jnp.array(0.14)] ham_single_q_const = qml.pulse.constant * qml.PauliY(0) - op = qml.evolve(ham_single_q_const)(params, 0.1) + op = qml.evolve(ham_single_q_const)(params, 0.7) tape = qml.tape.QuantumScript( [qml.RX(x, 0), op], [qml.expval(qml.PauliZ(0))], trainable_params=[1] )