Skip to content

Commit

Permalink
Improve support for snapshot on default.mixed (#5552)
Browse files Browse the repository at this point in the history
**Context:** Improves support for performing snapshots on
`default.mixed`.

**Description of the Change:** Adds the ability to perform arbitrary
state-based measurements.

**Benefits:** Device will support `qml.Snapshot` fully.

**Possible Drawbacks:**

**Related GitHub Issues:**

---------

Co-authored-by: Astral Cai <astral.cai@xanadu.ai>
  • Loading branch information
obliviateandsurrender and astralcai committed Apr 25, 2024
1 parent 1f107a5 commit 96b6241
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 11 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,9 @@
[(#5514)](https://github.com/PennyLaneAI/pennylane/pull/5514)
[(#5530)](https://github.com/PennyLaneAI/pennylane/pull/5530)

* `default.mixed` now supports arbitrary state-based measurements with `qml.Snapshot`.
[(#5552)](https://github.com/PennyLaneAI/pennylane/pull/5552)

* Replaced `cache_execute` with an alternate implementation based on `@transform`.
[(#5318)](https://github.com/PennyLaneAI/pennylane/pull/5318)

Expand Down
85 changes: 80 additions & 5 deletions pennylane/devices/default_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,18 @@
StatePrep,
Snapshot,
)
from pennylane.measurements import CountsMP, MutualInfoMP, SampleMP, StateMP, VnEntropyMP, PurityMP
from pennylane.measurements import (
CountsMP,
MutualInfoMP,
SampleMP,
StateMP,
VnEntropyMP,
PurityMP,
DensityMatrixMP,
ExpectationMP,
VarianceMP,
ProbabilityMP,
)
from pennylane.operation import Channel
from pennylane.ops.qubit.attributes import diagonal_in_z_basis
from pennylane.wires import Wires
Expand Down Expand Up @@ -582,18 +593,82 @@ def _apply_density_matrix(self, state, device_wires):
self._state = qnp.asarray(rho, dtype=self.C_DTYPE)
self._pre_rotated_state = self._state

def _snapshot_measurements(self, density_matrix, measurement):
"""Perform state-based snapshot measurement"""
meas_wires = measurement.wires

pre_rotated_state = self._state
if isinstance(measurement, (ProbabilityMP, ExpectationMP, VarianceMP)):
for diag_gate in measurement.diagonalizing_gates():
self._apply_operation(diag_gate)

if isinstance(measurement, (StateMP, DensityMatrixMP)):
map_wires = self.map_wires(meas_wires)
snap_result = qml.math.reduce_dm(
density_matrix, indices=map_wires, c_dtype=self.C_DTYPE
)

elif isinstance(measurement, PurityMP):
map_wires = self.map_wires(meas_wires)
snap_result = qml.math.purity(density_matrix, indices=map_wires, c_dtype=self.C_DTYPE)

elif isinstance(measurement, ProbabilityMP):
snap_result = self.analytic_probability(wires=meas_wires)

elif isinstance(measurement, ExpectationMP):
eigvals = self._asarray(measurement.obs.eigvals(), dtype=self.R_DTYPE)
probs = self.analytic_probability(wires=meas_wires)
snap_result = self._dot(probs, eigvals)

elif isinstance(measurement, VarianceMP):
eigvals = self._asarray(measurement.obs.eigvals(), dtype=self.R_DTYPE)
probs = self.analytic_probability(wires=meas_wires)
snap_result = self._dot(probs, (eigvals**2)) - self._dot(probs, eigvals) ** 2

elif isinstance(measurement, VnEntropyMP):
base = measurement.log_base
map_wires = self.map_wires(meas_wires)
snap_result = qml.math.vn_entropy(
density_matrix, indices=map_wires, c_dtype=self.C_DTYPE, base=base
)

elif isinstance(measurement, MutualInfoMP):
base = measurement.log_base
wires0, wires1 = list(map(self.map_wires, measurement.raw_wires))
snap_result = qml.math.mutual_info(
density_matrix,
indices0=wires0,
indices1=wires1,
c_dtype=self.C_DTYPE,
base=base,
)

else:
raise DeviceError(
f"Snapshots of {type(measurement)} are not yet supported on default.mixed"
)

self._state = pre_rotated_state
self._pre_rotated_state = self._state

return snap_result

def _apply_snapshot(self, operation):
"""Applies the snapshot operation"""
measurement = operation.hyperparameters["measurement"]
if measurement:
raise DeviceError("Snapshots of measurements are not yet supported on default.mixed")

if self._debugger and self._debugger.active:
dim = 2**self.num_wires
density_matrix = qnp.reshape(self._state, (dim, dim))
snap_result = density_matrix

if measurement:
snap_result = self._snapshot_measurements(density_matrix, measurement)

if operation.tag:
self._debugger.snapshots[operation.tag] = density_matrix
self._debugger.snapshots[operation.tag] = snap_result
else:
self._debugger.snapshots[len(self._debugger.snapshots)] = density_matrix
self._debugger.snapshots[len(self._debugger.snapshots)] = snap_result

def _apply_operation(self, operation):
"""Applies operations to the internal device state.
Expand Down
59 changes: 53 additions & 6 deletions tests/devices/test_default_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,15 +898,62 @@ def test_identity_skipped(self, mocker):
spy_diagonal_unitary.assert_not_called()
spy_apply_channel.assert_not_called()

@pytest.mark.parametrize(
"measurement",
[
qml.expval(op=qml.Z(1)),
qml.expval(op=qml.Y(0) @ qml.X(1)),
qml.var(op=qml.X(0)),
qml.var(op=qml.X(0) @ qml.Z(1)),
qml.density_matrix(wires=[1]),
qml.density_matrix(wires=[0, 1]),
qml.probs(op=qml.Y(0)),
qml.probs(op=qml.X(0) @ qml.Y(1)),
qml.vn_entropy(wires=[0]),
qml.mutual_info(wires0=[1], wires1=[0]),
qml.purity(wires=[1]),
],
)
def test_snapshot_supported(self, measurement):
"""Tests that applying snapshot of measurements is done correctly"""

def circuit():
"""Snapshot circuit"""
qml.Hadamard(wires=0)
qml.Hadamard(wires=1)
qml.Snapshot(measurement=qml.expval(qml.Z(0) @ qml.Z(1)))
qml.RX(0.123, wires=[0])
qml.RY(0.123, wires=[0])
qml.CNOT(wires=[0, 1])
qml.Snapshot(measurement=measurement)
qml.RZ(0.467, wires=[0])
qml.RX(0.235, wires=[0])
qml.CZ(wires=[1, 0])
qml.Snapshot("meas2", measurement=measurement)
return qml.probs(op=qml.Y(1) @ qml.Z(0))

dev_qubit = qml.device("default.qubit", wires=2)
dev_mixed = qml.device("default.mixed", wires=2)

qnode_qubit = qml.QNode(circuit, device=dev_qubit)
qnode_mixed = qml.QNode(circuit, device=dev_mixed)

snaps_qubit = qml.snapshots(qnode_qubit)()
snaps_mixed = qml.snapshots(qnode_mixed)()

for key1, key2 in zip(snaps_qubit, snaps_mixed):
assert key1 == key2
assert qml.math.allclose(snaps_qubit[key1], snaps_mixed[key2])

def test_snapshot_not_supported(self):
"""Tests that an error is raised when applying snapshot of measurements"""
"""Tests that an error is raised when applying snapshot of sample-based measurements"""

dev = qml.device("default.mixed", wires=1)
with pytest.raises(DeviceError, match="Snapshots of measurements are not yet supported"):
dev._apply_operation(qml.Snapshot(measurement=qml.expval(qml.PauliZ(0))))

# assert that a snapshot still works without a measurement
_ = dev._apply_operation(qml.Snapshot())
measurement = qml.sample(op=qml.Z(0))
with pytest.raises(
DeviceError, match=f"Snapshots of {type(measurement)} are not yet supported"
):
dev._snapshot_measurements(dev.state, measurement)


class TestApply:
Expand Down

0 comments on commit 96b6241

Please sign in to comment.