diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 1289ec0865f..aeeeedf42e3 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -360,6 +360,9 @@
Breaking changes 💔
+* State measurements preserve `dtype`.
+ [(#5547)](https://github.com/PennyLaneAI/pennylane/pull/5547)
+
* Use `SampleMP`s in the `dynamic_one_shot` transform to get back the values of the mid-circuit measurements.
[(#5486)](https://github.com/PennyLaneAI/pennylane/pull/5486)
diff --git a/pennylane/measurements/state.py b/pennylane/measurements/state.py
index c31e3b94c85..f0d158cf608 100644
--- a/pennylane/measurements/state.py
+++ b/pennylane/measurements/state.py
@@ -156,9 +156,10 @@ def shape(self, device, shots):
def process_state(self, state: Sequence[complex], wire_order: Wires):
# pylint:disable=redefined-outer-name
+ is_tf_interface = qml.math.get_deep_interface(state) == "tensorflow"
wires = self.wires
if not wires or wire_order == wires:
- return qml.math.cast(state, "complex128")
+ return qml.math.cast(state, "complex128") if is_tf_interface else state + 0.0j
if set(wires) != set(wire_order):
raise WireError(
@@ -178,7 +179,7 @@ def process_state(self, state: Sequence[complex], wire_order: Wires):
state = qml.math.reshape(state, shape)
state = qml.math.transpose(state, desired_axes)
state = qml.math.reshape(state, flat_shape)
- return qml.math.cast(state, "complex128")
+ return qml.math.cast(state, "complex128") if is_tf_interface else state + 0.0j
class DensityMatrixMP(StateMP):
@@ -211,4 +212,7 @@ def process_state(self, state: Sequence[complex], wire_order: Wires):
# pylint:disable=redefined-outer-name
wire_map = dict(zip(wire_order, range(len(wire_order))))
mapped_wires = [wire_map[w] for w in self.wires]
- return qml.math.reduce_statevector(state, indices=mapped_wires)
+ kwargs = {"indices": mapped_wires, "c_dtype": "complex128"}
+ if not qml.math.is_abstract(state) and qml.math.any(qml.math.iscomplex(state)):
+ kwargs["c_dtype"] = state.dtype
+ return qml.math.reduce_statevector(state, **kwargs)
diff --git a/tests/devices/test_lightning_qubit.py b/tests/devices/test_lightning_qubit.py
index 9493382d4f3..fe97c9c8b07 100644
--- a/tests/devices/test_lightning_qubit.py
+++ b/tests/devices/test_lightning_qubit.py
@@ -98,7 +98,7 @@ class TestDtypePreserved:
@pytest.mark.parametrize(
"c_dtype",
[
- pytest.param(np.complex64, marks=pytest.mark.xfail(reason="dtype not preserved")),
+ np.complex64,
np.complex128,
],
)
@@ -108,18 +108,10 @@ class TestDtypePreserved:
qml.state(),
qml.density_matrix(wires=[1]),
qml.density_matrix(wires=[2, 0]),
- pytest.param(
- qml.expval(qml.PauliY(0)), marks=pytest.mark.xfail(reason="incorrect type")
- ),
- pytest.param(qml.var(qml.PauliY(0)), marks=pytest.mark.xfail(reason="incorrect type")),
- pytest.param(
- qml.probs(wires=[1]),
- marks=pytest.mark.skip(reason="measurement passes with complex64 but xfail strict"),
- ),
- pytest.param(
- qml.probs(wires=[0, 2]),
- marks=pytest.mark.skip(reason="measurement passes with complex64 but xfail strict"),
- ),
+ qml.expval(qml.PauliY(0)),
+ qml.var(qml.PauliY(0)),
+ qml.probs(wires=[1]),
+ qml.probs(wires=[0, 2]),
],
)
def test_dtype(self, c_dtype, measurement):
@@ -139,4 +131,7 @@ def circuit(x):
expected_dtype = c_dtype
else:
expected_dtype = np.float64 if c_dtype == np.complex128 else np.float32
- assert res.dtype == expected_dtype
+ if isinstance(res, np.ndarray):
+ assert res.dtype == expected_dtype
+ else:
+ assert isinstance(res, float)
diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py
index b95a36d67f9..04d6b7e3525 100644
--- a/tests/gradients/core/test_pulse_gradient.py
+++ b/tests/gradients/core/test_pulse_gradient.py
@@ -770,7 +770,7 @@ def test_nontrainable_batched_tape(self):
x = [0.4, 0.2]
params = [jnp.array(0.14)]
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)
- op = qml.evolve(ham_single_q_const)(params, 0.1)
+ op = qml.evolve(ham_single_q_const)(params, 0.7)
tape = qml.tape.QuantumScript(
[qml.RX(x, 0), op], [qml.expval(qml.PauliZ(0))], trainable_params=[1]
)