From a63726b1053e32099aae35f6a641c14ad8688da2 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 20 Jun 2024 15:49:17 -0400 Subject: [PATCH] More sophisiticated measurement validation for default-qubit --- pennylane/devices/default_qubit.py | 84 +++++++++++++++++++----------- 1 file changed, 53 insertions(+), 31 deletions(-) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 865bedf6b85..19c7b9b576b 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -60,30 +60,6 @@ PostprocessingFn = Callable[[ResultBatch], Result_or_ResultBatch] -observables = { - "PauliX", - "PauliY", - "PauliZ", - "Hadamard", - "Hermitian", - "Identity", - "Projector", - "SparseHamiltonian", - "Hamiltonian", - "LinearCombination", - "Sum", - "SProd", - "Prod", - "Exp", - "Evolution", -} - - -def observable_stopping_condition(obs: qml.operation.Operator) -> bool: - """Specifies whether or not an observable is accepted by DefaultQubit.""" - return obs.name in observables - - def stopping_condition(op: qml.operation.Operator) -> bool: """Specify whether or not an Operator object is supported by the device.""" if op.name == "QFT" and len(op.wires) >= 6: @@ -103,16 +79,62 @@ def stopping_condition_shots(op: qml.operation.Operator) -> bool: return isinstance(op, (Conditional, MidMeasureMP)) or stopping_condition(op) +def observable_accepts_sampling(obs: qml.operation.Operator) -> bool: + """Verifies whether an observable supports sample measurement""" + + if isinstance(obs, qml.ops.CompositeOp): + return all(observable_accepts_sampling(o) for o in obs.operands) + + if isinstance(obs, qml.operation.Tensor): + return all(observable_accepts_sampling(o) for o in obs.obs) + + return obs.has_diagonalizing_gates() + + +def observable_accepts_analytic(obs: qml.operation.Operator, is_expval=False) -> bool: + """Verifies whether an observable supports analytic measurement""" + + if isinstance(obs, qml.ops.CompositeOp): + return all(observable_accepts_analytic(o, is_expval) for o in obs.operands) + + if obs.has_diagonalizing_gates(): + return True + + if is_expval: + return isinstance(obs, (qml.ops.SparseHamiltonian, qml.ops.Hermitian)) + + return True + + def accepted_sample_measurement(m: qml.measurements.MeasurementProcess) -> bool: - """Specifies whether or not a measurement is accepted when sampling.""" - return isinstance( + """Specifies whether a measurement is accepted when sampling.""" + + if not isinstance( m, ( qml.measurements.SampleMeasurement, qml.measurements.ClassicalShadowMP, qml.measurements.ShadowExpvalMP, ), - ) + ): + return False + + if m.obs is not None: + return observable_accepts_sampling(m.obs) + + return True + + +def accepted_analytic_measurement(m: qml.measurements.MeasurementProcess) -> bool: + """Specifies whether a measurement is accepted when analytic.""" + + if not isinstance(m, qml.measurements.StateMeasurement): + return False + + if m.obs is not None: + return observable_accepts_analytic(m.obs, isinstance(m, qml.measurements.ExpectationMP)) + + return True def null_postprocessing(results): @@ -514,10 +536,10 @@ def preprocess( name=self.name, ) transform_program.add_transform( - validate_measurements, sample_measurements=accepted_sample_measurement, name=self.name - ) - transform_program.add_transform( - validate_observables, stopping_condition=observable_stopping_condition, name=self.name + validate_measurements, + analytic_measurements=accepted_analytic_measurement, + sample_measurements=accepted_sample_measurement, + name=self.name, ) if config.mcm_config.mcm_method == "tree-traversal": transform_program.add_transform(qml.transforms.broadcast_expand)