Skip to content

Commit

Permalink
Do not cast state to complex128 (#5547)
Browse files Browse the repository at this point in the history
### Before submitting

Please complete the following checklist when submitting a PR:

- [x] All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to
the
      test directory!

- [x] All new functions and code must be clearly commented and
documented.
If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [x] Ensure that the test suite passes, by running `make test`.

- [x] Add a new entry to the `doc/releases/changelog-dev.md` file,
summarizing the
      change, and including a link back to the PR.

- [x] The PennyLane source code conforms to
      [PEP8 standards](https://www.python.org/dev/peps/pep-0008/).
We check all of our code against [Pylint](https://www.pylint.org/).
      To lint modified files, simply `pip install pylint`, and then
      run `pylint pennylane/path/to/file.py`.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:**
The new Lightning-Qubit (LQ) device API does not preserve the `dtype` of measurement
results:

``` py
import pennylane as qml
import numpy as np
dev = qml.device("lightning.qubit", wires=2, c_dtype=np.complex64)

@qml.qnode(dev)
def circ():
    return qml.state()

>>> circ().dtype
complex128
```

The issue originates in the measurement process's `process_state` method, which
casts the state to the default `complex128`, discarding the specified `dtype`.

**Description of the Change:**
Modify `StateMP` and `DensityMatrixMP` to avoid explicitly casting to
`complex128`. Instead, adding `0.0j` relies on each framework's own casting
rules to promote the state to a complex type. This approach does not work in
TensorFlow, for which the current casting behaviour is preserved.

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**
[sc-60855]
  • Loading branch information
vincentmr committed Apr 23, 2024
1 parent 6d1fd42 commit f834328
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 18 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@

<h3>Breaking changes 💔</h3>

* State measurements preserve `dtype`.
[(#5547)](https://github.com/PennyLaneAI/pennylane/pull/5547)

* Use `SampleMP`s in the `dynamic_one_shot` transform to get back the values of the mid-circuit measurements.
[(#5486)](https://github.com/PennyLaneAI/pennylane/pull/5486)

Expand Down
10 changes: 7 additions & 3 deletions pennylane/measurements/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,10 @@ def shape(self, device, shots):

def process_state(self, state: Sequence[complex], wire_order: Wires):
# pylint:disable=redefined-outer-name
is_tf_interface = qml.math.get_deep_interface(state) == "tensorflow"
wires = self.wires
if not wires or wire_order == wires:
return qml.math.cast(state, "complex128")
return qml.math.cast(state, "complex128") if is_tf_interface else state + 0.0j

if set(wires) != set(wire_order):
raise WireError(
Expand All @@ -178,7 +179,7 @@ def process_state(self, state: Sequence[complex], wire_order: Wires):
state = qml.math.reshape(state, shape)
state = qml.math.transpose(state, desired_axes)
state = qml.math.reshape(state, flat_shape)
return qml.math.cast(state, "complex128")
return qml.math.cast(state, "complex128") if is_tf_interface else state + 0.0j


class DensityMatrixMP(StateMP):
Expand Down Expand Up @@ -211,4 +212,7 @@ def process_state(self, state: Sequence[complex], wire_order: Wires):
# pylint:disable=redefined-outer-name
wire_map = dict(zip(wire_order, range(len(wire_order))))
mapped_wires = [wire_map[w] for w in self.wires]
return qml.math.reduce_statevector(state, indices=mapped_wires)
kwargs = {"indices": mapped_wires, "c_dtype": "complex128"}
if not qml.math.is_abstract(state) and qml.math.any(qml.math.iscomplex(state)):
kwargs["c_dtype"] = state.dtype
return qml.math.reduce_statevector(state, **kwargs)
23 changes: 9 additions & 14 deletions tests/devices/test_lightning_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class TestDtypePreserved:
@pytest.mark.parametrize(
"c_dtype",
[
pytest.param(np.complex64, marks=pytest.mark.xfail(reason="dtype not preserved")),
np.complex64,
np.complex128,
],
)
Expand All @@ -108,18 +108,10 @@ class TestDtypePreserved:
qml.state(),
qml.density_matrix(wires=[1]),
qml.density_matrix(wires=[2, 0]),
pytest.param(
qml.expval(qml.PauliY(0)), marks=pytest.mark.xfail(reason="incorrect type")
),
pytest.param(qml.var(qml.PauliY(0)), marks=pytest.mark.xfail(reason="incorrect type")),
pytest.param(
qml.probs(wires=[1]),
marks=pytest.mark.skip(reason="measurement passes with complex64 but xfail strict"),
),
pytest.param(
qml.probs(wires=[0, 2]),
marks=pytest.mark.skip(reason="measurement passes with complex64 but xfail strict"),
),
qml.expval(qml.PauliY(0)),
qml.var(qml.PauliY(0)),
qml.probs(wires=[1]),
qml.probs(wires=[0, 2]),
],
)
def test_dtype(self, c_dtype, measurement):
Expand All @@ -139,4 +131,7 @@ def circuit(x):
expected_dtype = c_dtype
else:
expected_dtype = np.float64 if c_dtype == np.complex128 else np.float32
assert res.dtype == expected_dtype
if isinstance(res, np.ndarray):
assert res.dtype == expected_dtype
else:
assert isinstance(res, float)
2 changes: 1 addition & 1 deletion tests/gradients/core/test_pulse_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ def test_nontrainable_batched_tape(self):
x = [0.4, 0.2]
params = [jnp.array(0.14)]
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)
op = qml.evolve(ham_single_q_const)(params, 0.1)
op = qml.evolve(ham_single_q_const)(params, 0.7)
tape = qml.tape.QuantumScript(
[qml.RX(x, 0), op], [qml.expval(qml.PauliZ(0))], trainable_params=[1]
)
Expand Down

0 comments on commit f834328

Please sign in to comment.