Skip to content

Commit

Permalink
More sophisiticated measurement validation for default-qubit
Browse files Browse the repository at this point in the history
  • Loading branch information
astralcai committed Jun 20, 2024
1 parent 14a0b63 commit a63726b
Showing 1 changed file with 53 additions and 31 deletions.
84 changes: 53 additions & 31 deletions pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,30 +60,6 @@
PostprocessingFn = Callable[[ResultBatch], Result_or_ResultBatch]


observables = {
"PauliX",
"PauliY",
"PauliZ",
"Hadamard",
"Hermitian",
"Identity",
"Projector",
"SparseHamiltonian",
"Hamiltonian",
"LinearCombination",
"Sum",
"SProd",
"Prod",
"Exp",
"Evolution",
}


def observable_stopping_condition(obs: qml.operation.Operator) -> bool:
"""Specifies whether or not an observable is accepted by DefaultQubit."""
return obs.name in observables


def stopping_condition(op: qml.operation.Operator) -> bool:
"""Specify whether or not an Operator object is supported by the device."""
if op.name == "QFT" and len(op.wires) >= 6:
Expand All @@ -103,16 +79,62 @@ def stopping_condition_shots(op: qml.operation.Operator) -> bool:
return isinstance(op, (Conditional, MidMeasureMP)) or stopping_condition(op)


def observable_accepts_sampling(obs: qml.operation.Operator) -> bool:
"""Verifies whether an observable supports sample measurement"""

if isinstance(obs, qml.ops.CompositeOp):
return all(observable_accepts_sampling(o) for o in obs.operands)

if isinstance(obs, qml.operation.Tensor):
return all(observable_accepts_sampling(o) for o in obs.obs)

return obs.has_diagonalizing_gates()


def observable_accepts_analytic(obs: qml.operation.Operator, is_expval=False) -> bool:
"""Verifies whether an observable supports analytic measurement"""

if isinstance(obs, qml.ops.CompositeOp):
return all(observable_accepts_analytic(o, is_expval) for o in obs.operands)

if obs.has_diagonalizing_gates():
return True

if is_expval:
return isinstance(obs, (qml.ops.SparseHamiltonian, qml.ops.Hermitian))

return True


def accepted_sample_measurement(m: qml.measurements.MeasurementProcess) -> bool:
"""Specifies whether or not a measurement is accepted when sampling."""
return isinstance(
"""Specifies whether a measurement is accepted when sampling."""

if not isinstance(
m,
(
qml.measurements.SampleMeasurement,
qml.measurements.ClassicalShadowMP,
qml.measurements.ShadowExpvalMP,
),
)
):
return False

if m.obs is not None:
return observable_accepts_sampling(m.obs)

return True


def accepted_analytic_measurement(m: qml.measurements.MeasurementProcess) -> bool:
"""Specifies whether a measurement is accepted when analytic."""

if not isinstance(m, qml.measurements.StateMeasurement):
return False

if m.obs is not None:
return observable_accepts_analytic(m.obs, isinstance(m, qml.measurements.ExpectationMP))

return True


def null_postprocessing(results):
Expand Down Expand Up @@ -514,10 +536,10 @@ def preprocess(
name=self.name,
)
transform_program.add_transform(
validate_measurements, sample_measurements=accepted_sample_measurement, name=self.name
)
transform_program.add_transform(
validate_observables, stopping_condition=observable_stopping_condition, name=self.name
validate_measurements,
analytic_measurements=accepted_analytic_measurement,
sample_measurements=accepted_sample_measurement,
name=self.name,
)
if config.mcm_config.mcm_method == "tree-traversal":
transform_program.add_transform(qml.transforms.broadcast_expand)
Expand Down

0 comments on commit a63726b

Please sign in to comment.