From 4f5ce6a4d5890552da1628f8c44d72a3ab9a02a1 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 10 May 2024 16:20:59 -0400 Subject: [PATCH 01/48] Added functionality for mcm config --- pennylane/devices/default_qubit.py | 16 ++++- pennylane/devices/execution_config.py | 9 +++ pennylane/devices/preprocess.py | 17 ++++- pennylane/devices/qubit/simulate.py | 25 ++++++-- pennylane/transforms/defer_measurements.py | 1 + pennylane/transforms/dynamic_one_shot.py | 75 +++++++++++++--------- pennylane/workflow/execution.py | 25 +++++++- pennylane/workflow/qnode.py | 6 ++ 8 files changed, 135 insertions(+), 39 deletions(-) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index d7957424b00..6039f51ccf8 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -497,7 +497,9 @@ def preprocess( transform_program = TransformProgram() transform_program.add_transform(validate_device_wires, self.wires, name=self.name) - transform_program.add_transform(mid_circuit_measurements, device=self) + transform_program.add_transform( + mid_circuit_measurements, device=self, mcm_config=config.mcm_config + ) transform_program.add_transform( decompose, stopping_condition=stopping_condition, @@ -600,6 +602,9 @@ def execute( "interface": interface, "state_cache": self._state_cache, "prng_key": _key, + "discard_invalid_shots": execution_config.mcm_config[ + "discard_invalid_shots" + ], }, ) for c, _key in zip(circuits, prng_keys) @@ -607,7 +612,14 @@ def execute( vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits] seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits)) - simulate_kwargs = [{"rng": _rng, "prng_key": _key} for _rng, _key in zip(seeds, prng_keys)] + simulate_kwargs = [ + { + "rng": _rng, + "prng_key": _key, + "discard_invalid_shots": execution_config.mcm_config["discard_invalid_shots"], + } + for _rng, _key in zip(seeds, prng_keys) + ] with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor: exec_map = executor.map(_simulate_wrapper, vanilla_circuits, simulate_kwargs) diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index e7ac20179a7..90ca617d6df 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -67,6 +67,9 @@ class ExecutionConfig: derivative_order: int = 1 """The derivative order to compute while evaluating a gradient""" + mcm_config: Optional[dict] = None + """Configuration options for handling mid-circuit measurements""" + def __post_init__(self): """ Validate the configured execution options. @@ -89,5 +92,11 @@ def __post_init__(self): if self.gradient_keyword_arguments is None: self.gradient_keyword_arguments = {} + if self.mcm_config is None: + self.mcm_config = {} + for option in ("discard_invalid_shots", "method"): + if option not in self.mcm_config: + self.mcm_config[option] = None + DefaultExecutionConfig = ExecutionConfig() diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index 882ed8445a1..d3bc7c983fd 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -145,7 +145,7 @@ def validate_device_wires( @transform def mid_circuit_measurements( - tape: qml.tape.QuantumTape, device + tape: qml.tape.QuantumTape, device, mcm_config ) -> (Sequence[qml.tape.QuantumTape], Callable): """Provide the transform to handle mid-circuit measurements. @@ -153,8 +153,21 @@ def mid_circuit_measurements( and use the ``qml.defer_measurements`` transform otherwise. """ + if (mcm_method := mcm_config.get("method", None)) is not None: + if mcm_method == "one-shot": + return qml.dynamic_one_shot( + tape, discard_invalid_shots=mcm_config.get("discard_invalid_shots", None) + ) + if mcm_method == "deferred": + return qml.defer_measurements(tape, device=device) + warnings.warn( + "Invalid mid-circuit measurements method. Automatically detecting optimal method.", + UserWarning, + ) if tape.shots: - return qml.dynamic_one_shot(tape) + return qml.dynamic_one_shot( + tape, discard_invalid_shots=mcm_config.get("discard_invalid_shots", None) + ) return qml.defer_measurements(tape, device=device) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 2095aff3d6d..c420208d050 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -63,7 +63,7 @@ def __init__(self, shots=None): self._frozen = True -def _postselection_postprocess(state, is_state_batched, shots, rng=None, prng_key=None): +def _postselection_postprocess(state, is_state_batched, shots, **execution_kwargs): """Update state after projector is applied.""" if is_state_batched: raise ValueError( @@ -73,11 +73,16 @@ def _postselection_postprocess(state, is_state_batched, shots, rng=None, prng_ke "postselection is used." ) + rng = execution_kwargs.get("rng", None) + prng_key = execution_kwargs.get("prng_key", None) + discard_invalid_shots = execution_kwargs.get("discard_invalid_shots", None) + # The floor function is being used here so that a norm very close to zero becomes exactly # equal to zero so that the state can become invalid. This way, execution can continue, and # bad postselection gives results that are invalid rather than results that look valid but # are incorrect. norm = qml.math.norm(state) + discard_invalid_shots = True if discard_invalid_shots is None else discard_invalid_shots if not qml.math.is_abstract(state) and qml.math.allclose(norm, 0.0): norm = 0.0 @@ -95,7 +100,7 @@ def _postselection_postprocess(state, is_state_batched, shots, rng=None, prng_ke postselected_shots = ( [int(binomial_fn(s, float(norm**2))) for s in shots] - if not qml.math.is_abstract(norm) + if discard_invalid_shots and not qml.math.is_abstract(norm) else shots ) @@ -132,6 +137,7 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): prng_key = execution_kwargs.get("prng_key", None) interface = execution_kwargs.get("interface", None) mid_measurements = execution_kwargs.get("mid_measurements", None) + discard_invalid_shots = execution_kwargs.get("discard_invalid_shots", None) circuit = circuit.map_to_standard_wires() prep = None @@ -160,7 +166,12 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): if isinstance(op, qml.Projector): prng_key, key = jax_random_split(prng_key) state, circuit._shots = _postselection_postprocess( - state, is_state_batched, circuit.shots, rng=rng, prng_key=key + state, + is_state_batched, + circuit.shots, + rng=rng, + prng_key=key, + discard_invalid_shots=discard_invalid_shots, ) # new state is batched if i) the old state is batched, or ii) the new op adds a batch dim @@ -276,6 +287,7 @@ def simulate( rng = execution_kwargs.get("rng", None) prng_key = execution_kwargs.get("prng_key", None) interface = execution_kwargs.get("interface", None) + discard_invalid_shots = execution_kwargs.get("discard_invalid_shots", None) has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations) if circuit.shots and has_mcm: @@ -285,7 +297,12 @@ def simulate( ops_key, meas_key = jax_random_split(prng_key) state, is_state_batched = get_final_state( - circuit, debugger=debugger, rng=rng, prng_key=ops_key, interface=interface + circuit, + debugger=debugger, + rng=rng, + prng_key=ops_key, + interface=interface, + discard_invalid_shots=discard_invalid_shots, ) if state_cache is not None: state_cache[circuit.hash] = state diff --git a/pennylane/transforms/defer_measurements.py b/pennylane/transforms/defer_measurements.py index d1c6f05a5f4..679b4ea2f11 100644 --- a/pennylane/transforms/defer_measurements.py +++ b/pennylane/transforms/defer_measurements.py @@ -275,6 +275,7 @@ def node(x): _check_tape_validity(tape) device = kwargs.get("device", None) + print("in defer_measurements") new_operations = [] diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 86b57ca27a2..165cc3a362c 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -33,6 +33,7 @@ SampleMP, VarianceMP, ) +from pennylane.typing import TensorLike from .core import transform @@ -54,6 +55,8 @@ def dynamic_one_shot( Args: tape (QNode or QuantumTape or Callable): a quantum circuit to add a batch dimension to + discard_invalid_shots (bool): Whether or not to discard shots that don't match the + postselection criteria. ``True`` by default. Returns: qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]: @@ -98,6 +101,8 @@ def func(x, y): "measurements." ) _ = kwargs.get("device", None) + discard_invalid_shots = kwargs.get("discard_invalid_shots", True) + print("in dynamic_one_shot") if not tape.shots: raise qml.QuantumFunctionError("dynamic_one_shot is only supported with finite shots.") @@ -148,7 +153,9 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None): ) del results[0:s] return tuple(final_results) - return parse_native_mid_circuit_measurements(tape, aux_tapes, results) + return parse_native_mid_circuit_measurements( + tape, aux_tapes, results, discard_invalid_shots + ) return output_tapes, processing_fn @@ -206,14 +213,19 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): def parse_native_mid_circuit_measurements( - circuit: qml.tape.QuantumScript, aux_tapes: qml.tape.QuantumScript, results + circuit: qml.tape.QuantumScript, + aux_tapes: qml.tape.QuantumScript, + results: TensorLike, + discard_invalid_shots: bool, ): """Combines, gathers and normalizes the results of native mid-circuit measurement runs. Args: - circuit (QuantumTape): A one-shot (auxiliary) QuantumScript - all_shot_meas (Sequence[Any]): List of accumulated measurement results - mcm_shot_meas (Sequence[dict]): List of dictionaries containing the mid-circuit measurement results of each shot + circuit (QuantumTape): Initial ``QuantumScript`` + aux_tapes (List[QuantumTape]): List of auxilary ``QuantumScript`` objects + results (TensorLike): Array of measurement results + discard_invalid_shots (bool): Whether or not to discard shots that don't match the + postselection criteria. Returns: tuple(TensorLike): The results of the simulation @@ -262,13 +274,13 @@ def measurement_with_no_shots(measurement): if interface != "jax" and m.mv and not has_valid: meas = measurement_with_no_shots(m) elif m.mv: - meas = gather_mcm(m, mcm_samples, is_valid) + meas = gather_mcm(m, mcm_samples, is_valid, discard_invalid_shots) elif interface != "jax" and not has_valid: meas = measurement_with_no_shots(m) m_count += 1 else: result = qml.math.array([res[m_count] for res in results], like=interface) - meas = gather_non_mcm(m, result, is_valid) + meas = gather_non_mcm(m, result, is_valid, discard_invalid_shots) m_count += 1 if isinstance(m, SampleMP): meas = qml.math.squeeze(meas) @@ -277,49 +289,54 @@ def measurement_with_no_shots(measurement): return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0] -def gather_non_mcm(circuit_measurement, measurement, is_valid): +def gather_non_mcm(measurement, samples, is_valid, discard_invalid_shots): """Combines, gathers and normalizes several measurements with trivial measurement values. Args: - circuit_measurement (MeasurementProcess): measurement - measurement (TensorLike): measurement results - samples (List[dict]): Mid-circuit measurement samples + measurement (MeasurementProcess): measurement + samples (TensorLike): measurement samples + is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at + each index specifies whether or not the respective sample is valid. + discard_invalid_shots (bool): Whether or not to discard shots that don't match the + postselection criteria. Returns: TensorLike: The combined measurement outcome """ - if isinstance(circuit_measurement, CountsMP): + if isinstance(measurement, CountsMP): tmp = Counter() - for i, d in enumerate(measurement): + for i, d in enumerate(samples): tmp.update(dict((k, v * is_valid[i]) for k, v in d.items())) tmp = Counter({k: v for k, v in tmp.items() if v > 0}) return dict(sorted(tmp.items())) - if isinstance(circuit_measurement, ExpectationMP): - return qml.math.sum(measurement * is_valid) / qml.math.sum(is_valid) - if isinstance(circuit_measurement, ProbabilityMP): - return qml.math.sum(measurement * is_valid.reshape((-1, 1)), axis=0) / qml.math.sum( - is_valid - ) - if isinstance(circuit_measurement, SampleMP): + if isinstance(measurement, ExpectationMP): + return qml.math.sum(samples * is_valid) / qml.math.sum(is_valid) + if isinstance(measurement, ProbabilityMP): + return qml.math.sum(samples * is_valid.reshape((-1, 1)), axis=0) / qml.math.sum(is_valid) + if isinstance(measurement, SampleMP): is_interface_jax = qml.math.get_deep_interface(is_valid) == "jax" - if is_interface_jax and measurement.ndim == 2: + if (is_interface_jax or not discard_invalid_shots) and samples.ndim == 2: is_valid = is_valid.reshape((-1, 1)) return ( - qml.math.where(is_valid, measurement, fill_in_value) - if is_interface_jax - else measurement[is_valid] + qml.math.where(is_valid, samples, fill_in_value) + if is_interface_jax or not discard_invalid_shots + else samples[is_valid] ) # VarianceMP - expval = qml.math.sum(measurement * is_valid) / qml.math.sum(is_valid) - return qml.math.sum((measurement - expval) ** 2 * is_valid) / qml.math.sum(is_valid) + expval = qml.math.sum(samples * is_valid) / qml.math.sum(is_valid) + return qml.math.sum((samples - expval) ** 2 * is_valid) / qml.math.sum(is_valid) -def gather_mcm(measurement, samples, is_valid): +def gather_mcm(measurement, samples, is_valid, discard_invalid_shots): """Combines, gathers and normalizes several measurements with non-trivial measurement values. Args: measurement (MeasurementProcess): measurement samples (List[dict]): Mid-circuit measurement samples + is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at + each index specifies whether or not the respective sample is valid. + discard_invalid_shots (bool): Whether or not to discard shots that don't match the + postselection criteria. Returns: TensorLike: The combined measurement outcome @@ -341,7 +358,7 @@ def gather_mcm(measurement, samples, is_valid): return counts / qml.math.sum(counts) if isinstance(measurement, CountsMP): mcm_samples = [{"".join(str(v) for v in tuple(s)): 1} for s in mcm_samples] - return gather_non_mcm(measurement, mcm_samples, is_valid) + return gather_non_mcm(measurement, mcm_samples, is_valid, discard_invalid_shots) if isinstance(measurement, ProbabilityMP): mcm_samples = qml.math.array(mv.concretize(samples), like=interface).ravel() counts = [qml.math.sum((mcm_samples == v) * is_valid) for v in list(mv.branches.values())] @@ -350,4 +367,4 @@ def gather_mcm(measurement, samples, is_valid): mcm_samples = qml.math.array([mv.concretize(samples)], like=interface).ravel() if isinstance(measurement, CountsMP): mcm_samples = [{s: 1} for s in mcm_samples] - return gather_non_mcm(measurement, mcm_samples, is_valid) + return gather_non_mcm(measurement, mcm_samples, is_valid, discard_invalid_shots) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 1d2e1f3e2ee..2ee110fe566 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -376,6 +376,7 @@ def execute( max_expansion=10, device_batch_transform=True, device_vjp=False, + mcm_config=None, ) -> ResultBatch: """New function to execute a batch of tapes on a device in an autodifferentiable-compatible manner. More cases will be added, during the project. The current version is supporting forward execution for NumPy and does not support shot vectors. @@ -422,6 +423,7 @@ def execute( constituent terms if not supported on the device. device_vjp=False (Optional[bool]): whether or not to use the device provided jacobian product if it is available. + mcm_config (dict): Dictionary containing configuration options for handling mid-circuit measurements. Returns: list[tensor_like[float]]: A nested list of tape results. Each element in @@ -543,9 +545,25 @@ def cost_fn(params, x): gradient_kwargs = gradient_kwargs or {} config = config or _get_execution_config( - gradient_fn, grad_on_execution, interface, device, device_vjp + gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config ) + if "jax" in interface and config.mcm_config["discard_invalid_shots"]: + warnings.warn( + "Cannot discard invalid shots with postselection when using the 'jax' interface. " + "Ignoring requested mid-circuit measurement configuration.", + UserWarning, + ) + config.mcm_config["discard_invalid_shots"] = None + + if any(not tape.shots for tape in tapes) and mcm_config["method"] == "one-shot": + warnings.warn( + "Cannot use the 'one-shot' method for mid-circuit measurements with " + "analytic mode. Using deferred measurements.", + UserWarning, + ) + config.mcm_config["method"] = None + if transform_program is None: if isinstance(device, qml.devices.Device): transform_program = device.preprocess(config)[0] @@ -798,7 +816,9 @@ def device_gradient_fn(inner_tapes, **gradient_kwargs): return post_processing(results) -def _get_execution_config(gradient_fn, grad_on_execution, interface, device, device_vjp): +def _get_execution_config( + gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config +): """Helper function to get the execution config.""" if gradient_fn is None: _gradient_method = None @@ -811,6 +831,7 @@ def _get_execution_config(gradient_fn, grad_on_execution, interface, device, dev gradient_method=_gradient_method, grad_on_execution=None if grad_on_execution == "best" else grad_on_execution, use_device_jacobian_product=device_vjp, + mcm_config=mcm_config, ) if isinstance(device, qml.devices.Device): _, config = device.preprocess(config) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 513abd88e83..287459fa65b 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -73,6 +73,8 @@ def _make_execution_config( grad_on_execution = False elif grad_on_execution == "best": grad_on_execution = None + mcm_config = getattr(circuit, "execute_kwargs", {}).get("mcm_config", None) + return qml.devices.ExecutionConfig( interface=getattr(circuit, "interface", None), gradient_method=_gradient_method, @@ -80,6 +82,7 @@ def _make_execution_config( use_device_jacobian_product=getattr(circuit, "execute_kwargs", {"device_vjp": False})[ "device_vjp" ], + mcm_config=mcm_config, ) @@ -216,6 +219,7 @@ class QNode: (classical) computational overhead during the backwards pass. device_vjp (bool): Whether or not to use the device-provided Vector Jacobian Product (VJP). A value of ``None`` indicates to use it if the device provides it, but use the full jacobian otherwise. + mcm_config (dict): Dictionary containing configuration options for handling mid-circuit measurements. Keyword Args: **kwargs: Any additional keyword arguments provided are passed to the differentiation @@ -439,6 +443,7 @@ def __init__( cachesize=10000, max_diff=1, device_vjp=False, + mcm_config=None, **gradient_kwargs, ): if logger.isEnabledFor(logging.DEBUG): @@ -513,6 +518,7 @@ def __init__( "max_diff": max_diff, "max_expansion": max_expansion, "device_vjp": device_vjp, + "mcm_config": mcm_config, } if self.expansion_strategy == "device": From 3e65b57e8fb721f582216abc9d4fd46784da9829 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 16 May 2024 10:26:15 -0400 Subject: [PATCH 02/48] Added qjit error --- pennylane/transforms/dynamic_one_shot.py | 12 ++++++++---- pennylane/workflow/execution.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 165cc3a362c..4bfbeb89350 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -101,8 +101,11 @@ def func(x, y): "measurements." ) _ = kwargs.get("device", None) - discard_invalid_shots = kwargs.get("discard_invalid_shots", True) - print("in dynamic_one_shot") + + discard_invalid_shots = kwargs.get("discard_invalid_shots", None) + if qml.compiler.active() and discard_invalid_shots: + raise ValueError("Can't discard invalid shots while using qml.qjit") + discard_invalid_shots = True if discard_invalid_shots is None else discard_invalid_shots if not tape.shots: raise qml.QuantumFunctionError("dynamic_one_shot is only supported with finite shots.") @@ -315,11 +318,12 @@ def gather_non_mcm(measurement, samples, is_valid, discard_invalid_shots): return qml.math.sum(samples * is_valid.reshape((-1, 1)), axis=0) / qml.math.sum(is_valid) if isinstance(measurement, SampleMP): is_interface_jax = qml.math.get_deep_interface(is_valid) == "jax" - if (is_interface_jax or not discard_invalid_shots) and samples.ndim == 2: + discard_invalid_shots = discard_invalid_shots and not is_interface_jax + if not discard_invalid_shots and samples.ndim == 2: is_valid = is_valid.reshape((-1, 1)) return ( qml.math.where(is_valid, samples, fill_in_value) - if is_interface_jax or not discard_invalid_shots + if not discard_invalid_shots else samples[is_valid] ) # VarianceMP diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 2ee110fe566..b92c79d61a9 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -556,7 +556,7 @@ def cost_fn(params, x): ) config.mcm_config["discard_invalid_shots"] = None - if any(not tape.shots for tape in tapes) and mcm_config["method"] == "one-shot": + if any(not tape.shots for tape in tapes) and config.mcm_config["method"] == "one-shot": warnings.warn( "Cannot use the 'one-shot' method for mid-circuit measurements with " "analytic mode. Using deferred measurements.", From bf888833b9f85e4d3a29a6f8b6d13836b0b45524 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 20 May 2024 20:05:31 -0400 Subject: [PATCH 03/48] Update qnode signature --- pennylane/devices/default_qubit.py | 6 ++-- pennylane/devices/execution_config.py | 2 +- pennylane/devices/preprocess.py | 8 ++--- pennylane/devices/qubit/simulate.py | 18 ++++++----- pennylane/transforms/dynamic_one_shot.py | 40 +++++++++++------------- pennylane/workflow/execution.py | 8 ++--- pennylane/workflow/qnode.py | 13 ++++++-- 7 files changed, 50 insertions(+), 45 deletions(-) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 6039f51ccf8..4439a8040de 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -602,9 +602,7 @@ def execute( "interface": interface, "state_cache": self._state_cache, "prng_key": _key, - "discard_invalid_shots": execution_config.mcm_config[ - "discard_invalid_shots" - ], + "postselect_shots": execution_config.mcm_config["postselect_shots"], }, ) for c, _key in zip(circuits, prng_keys) @@ -616,7 +614,7 @@ def execute( { "rng": _rng, "prng_key": _key, - "discard_invalid_shots": execution_config.mcm_config["discard_invalid_shots"], + "postselect_shots": execution_config.mcm_config["postselect_shots"], } for _rng, _key in zip(seeds, prng_keys) ] diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index 90ca617d6df..849ff428299 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -94,7 +94,7 @@ def __post_init__(self): if self.mcm_config is None: self.mcm_config = {} - for option in ("discard_invalid_shots", "method"): + for option in ("postselect_shots", "mcm_method"): if option not in self.mcm_config: self.mcm_config[option] = None diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index d3bc7c983fd..9d018139872 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -153,10 +153,10 @@ def mid_circuit_measurements( and use the ``qml.defer_measurements`` transform otherwise. """ - if (mcm_method := mcm_config.get("method", None)) is not None: + if (mcm_method := mcm_config.get("mcm_method", None)) is not None: if mcm_method == "one-shot": return qml.dynamic_one_shot( - tape, discard_invalid_shots=mcm_config.get("discard_invalid_shots", None) + tape, postselect_shots=mcm_config.get("postselect_shots", None) ) if mcm_method == "deferred": return qml.defer_measurements(tape, device=device) @@ -165,9 +165,7 @@ def mid_circuit_measurements( UserWarning, ) if tape.shots: - return qml.dynamic_one_shot( - tape, discard_invalid_shots=mcm_config.get("discard_invalid_shots", None) - ) + return qml.dynamic_one_shot(tape, postselect_shots=mcm_config.get("postselect_shots", None)) return qml.defer_measurements(tape, device=device) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index c420208d050..1c31e6ef6da 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -75,14 +75,14 @@ def _postselection_postprocess(state, is_state_batched, shots, **execution_kwarg rng = execution_kwargs.get("rng", None) prng_key = execution_kwargs.get("prng_key", None) - discard_invalid_shots = execution_kwargs.get("discard_invalid_shots", None) + postselect_shots = execution_kwargs.get("postselect_shots", None) # The floor function is being used here so that a norm very close to zero becomes exactly # equal to zero so that the state can become invalid. This way, execution can continue, and # bad postselection gives results that are invalid rather than results that look valid but # are incorrect. norm = qml.math.norm(state) - discard_invalid_shots = True if discard_invalid_shots is None else discard_invalid_shots + postselect_shots = True if postselect_shots is None else postselect_shots if not qml.math.is_abstract(state) and qml.math.allclose(norm, 0.0): norm = 0.0 @@ -100,7 +100,7 @@ def _postselection_postprocess(state, is_state_batched, shots, **execution_kwarg postselected_shots = ( [int(binomial_fn(s, float(norm**2))) for s in shots] - if discard_invalid_shots and not qml.math.is_abstract(norm) + if postselect_shots and not qml.math.is_abstract(norm) else shots ) @@ -127,6 +127,8 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. If None, a ``numpy.random.default_rng`` will be for sampling. + postselect_shots (bool): Whether or not to discard invalid shots when postselecting + mid-circuit measurements. Returns: Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and @@ -137,7 +139,7 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): prng_key = execution_kwargs.get("prng_key", None) interface = execution_kwargs.get("interface", None) mid_measurements = execution_kwargs.get("mid_measurements", None) - discard_invalid_shots = execution_kwargs.get("discard_invalid_shots", None) + postselect_shots = execution_kwargs.get("postselect_shots", None) circuit = circuit.map_to_standard_wires() prep = None @@ -171,7 +173,7 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): circuit.shots, rng=rng, prng_key=key, - discard_invalid_shots=discard_invalid_shots, + postselect_shots=postselect_shots, ) # new state is batched if i) the old state is batched, or ii) the new op adds a batch dim @@ -270,6 +272,8 @@ def simulate( the key to the JAX pseudo random number generator. If None, a random key will be generated. Only for simulation using JAX. interface (str): The machine learning interface to create the initial state with + postselect_shots (bool): Whether or not to discard invalid shots when postselecting + mid-circuit measurements. Returns: tuple(TensorLike): The results of the simulation @@ -287,7 +291,7 @@ def simulate( rng = execution_kwargs.get("rng", None) prng_key = execution_kwargs.get("prng_key", None) interface = execution_kwargs.get("interface", None) - discard_invalid_shots = execution_kwargs.get("discard_invalid_shots", None) + postselect_shots = execution_kwargs.get("postselect_shots", None) has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations) if circuit.shots and has_mcm: @@ -302,7 +306,7 @@ def simulate( rng=rng, prng_key=ops_key, interface=interface, - discard_invalid_shots=discard_invalid_shots, + postselect_shots=postselect_shots, ) if state_cache is not None: state_cache[circuit.hash] = state diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 4bfbeb89350..80766ddb582 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -55,8 +55,8 @@ def dynamic_one_shot( Args: tape (QNode or QuantumTape or Callable): a quantum circuit to add a batch dimension to - discard_invalid_shots (bool): Whether or not to discard shots that don't match the - postselection criteria. ``True`` by default. + postselect_shots (bool): Whether or not to discard shots that don't match the + postselection criteria. Returns: qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]: @@ -102,10 +102,10 @@ def func(x, y): ) _ = kwargs.get("device", None) - discard_invalid_shots = kwargs.get("discard_invalid_shots", None) - if qml.compiler.active() and discard_invalid_shots: + postselect_shots = kwargs.get("postselect_shots", None) + if qml.compiler.active() and postselect_shots: raise ValueError("Can't discard invalid shots while using qml.qjit") - discard_invalid_shots = True if discard_invalid_shots is None else discard_invalid_shots + postselect_shots = True if postselect_shots is None else postselect_shots if not tape.shots: raise qml.QuantumFunctionError("dynamic_one_shot is only supported with finite shots.") @@ -156,9 +156,7 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None): ) del results[0:s] return tuple(final_results) - return parse_native_mid_circuit_measurements( - tape, aux_tapes, results, discard_invalid_shots - ) + return parse_native_mid_circuit_measurements(tape, aux_tapes, results, postselect_shots) return output_tapes, processing_fn @@ -219,7 +217,7 @@ def parse_native_mid_circuit_measurements( circuit: qml.tape.QuantumScript, aux_tapes: qml.tape.QuantumScript, results: TensorLike, - discard_invalid_shots: bool, + postselect_shots: bool, ): """Combines, gathers and normalizes the results of native mid-circuit measurement runs. @@ -227,7 +225,7 @@ def parse_native_mid_circuit_measurements( circuit (QuantumTape): Initial ``QuantumScript`` aux_tapes (List[QuantumTape]): List of auxilary ``QuantumScript`` objects results (TensorLike): Array of measurement results - discard_invalid_shots (bool): Whether or not to discard shots that don't match the + postselect_shots (bool): Whether or not to discard shots that don't match the postselection criteria. Returns: @@ -277,13 +275,13 @@ def measurement_with_no_shots(measurement): if interface != "jax" and m.mv and not has_valid: meas = measurement_with_no_shots(m) elif m.mv: - meas = gather_mcm(m, mcm_samples, is_valid, discard_invalid_shots) + meas = gather_mcm(m, mcm_samples, is_valid, postselect_shots) elif interface != "jax" and not has_valid: meas = measurement_with_no_shots(m) m_count += 1 else: result = qml.math.array([res[m_count] for res in results], like=interface) - meas = gather_non_mcm(m, result, is_valid, discard_invalid_shots) + meas = gather_non_mcm(m, result, is_valid, postselect_shots) m_count += 1 if isinstance(m, SampleMP): meas = qml.math.squeeze(meas) @@ -292,7 +290,7 @@ def measurement_with_no_shots(measurement): return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0] -def gather_non_mcm(measurement, samples, is_valid, discard_invalid_shots): +def gather_non_mcm(measurement, samples, is_valid, postselect_shots): """Combines, gathers and normalizes several measurements with trivial measurement values. Args: @@ -300,7 +298,7 @@ def gather_non_mcm(measurement, samples, is_valid, discard_invalid_shots): samples (TensorLike): measurement samples is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at each index specifies whether or not the respective sample is valid. - discard_invalid_shots (bool): Whether or not to discard shots that don't match the + postselect_shots (bool): Whether or not to discard shots that don't match the postselection criteria. Returns: @@ -318,12 +316,12 @@ def gather_non_mcm(measurement, samples, is_valid, discard_invalid_shots): return qml.math.sum(samples * is_valid.reshape((-1, 1)), axis=0) / qml.math.sum(is_valid) if isinstance(measurement, SampleMP): is_interface_jax = qml.math.get_deep_interface(is_valid) == "jax" - discard_invalid_shots = discard_invalid_shots and not is_interface_jax - if not discard_invalid_shots and samples.ndim == 2: + postselect_shots = postselect_shots and not is_interface_jax + if not postselect_shots and samples.ndim == 2: is_valid = is_valid.reshape((-1, 1)) return ( qml.math.where(is_valid, samples, fill_in_value) - if not discard_invalid_shots + if not postselect_shots else samples[is_valid] ) # VarianceMP @@ -331,7 +329,7 @@ def gather_non_mcm(measurement, samples, is_valid, discard_invalid_shots): return qml.math.sum((samples - expval) ** 2 * is_valid) / qml.math.sum(is_valid) -def gather_mcm(measurement, samples, is_valid, discard_invalid_shots): +def gather_mcm(measurement, samples, is_valid, postselect_shots): """Combines, gathers and normalizes several measurements with non-trivial measurement values. Args: @@ -339,7 +337,7 @@ def gather_mcm(measurement, samples, is_valid, discard_invalid_shots): samples (List[dict]): Mid-circuit measurement samples is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at each index specifies whether or not the respective sample is valid. - discard_invalid_shots (bool): Whether or not to discard shots that don't match the + postselect_shots (bool): Whether or not to discard shots that don't match the postselection criteria. Returns: @@ -362,7 +360,7 @@ def gather_mcm(measurement, samples, is_valid, discard_invalid_shots): return counts / qml.math.sum(counts) if isinstance(measurement, CountsMP): mcm_samples = [{"".join(str(v) for v in tuple(s)): 1} for s in mcm_samples] - return gather_non_mcm(measurement, mcm_samples, is_valid, discard_invalid_shots) + return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_shots) if isinstance(measurement, ProbabilityMP): mcm_samples = qml.math.array(mv.concretize(samples), like=interface).ravel() counts = [qml.math.sum((mcm_samples == v) * is_valid) for v in list(mv.branches.values())] @@ -371,4 +369,4 @@ def gather_mcm(measurement, samples, is_valid, discard_invalid_shots): mcm_samples = qml.math.array([mv.concretize(samples)], like=interface).ravel() if isinstance(measurement, CountsMP): mcm_samples = [{s: 1} for s in mcm_samples] - return gather_non_mcm(measurement, mcm_samples, is_valid, discard_invalid_shots) + return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_shots) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index b92c79d61a9..fddb959fa38 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -548,21 +548,21 @@ def cost_fn(params, x): gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config ) - if "jax" in interface and config.mcm_config["discard_invalid_shots"]: + if "jax" in interface and config.mcm_config["postselect_shots"]: warnings.warn( "Cannot discard invalid shots with postselection when using the 'jax' interface. " "Ignoring requested mid-circuit measurement configuration.", UserWarning, ) - config.mcm_config["discard_invalid_shots"] = None + config.mcm_config["postselect_shots"] = None - if any(not tape.shots for tape in tapes) and config.mcm_config["method"] == "one-shot": + if any(not tape.shots for tape in tapes) and config.mcm_config["mcm_method"] == "one-shot": warnings.warn( "Cannot use the 'one-shot' method for mid-circuit measurements with " "analytic mode. Using deferred measurements.", UserWarning, ) - config.mcm_config["method"] = None + config.mcm_config["mcm_method"] = None if transform_program is None: if isinstance(device, qml.devices.Device): diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 287459fa65b..0909a60f58f 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -219,7 +219,13 @@ class QNode: (classical) computational overhead during the backwards pass. device_vjp (bool): Whether or not to use the device-provided Vector Jacobian Product (VJP). A value of ``None`` indicates to use it if the device provides it, but use the full jacobian otherwise. - mcm_config (dict): Dictionary containing configuration options for handling mid-circuit measurements. + postselect_shots (bool): Whether or not to discard invalid shots when postselecting mid-circuit measurements. + If ``True``, invalid shots will be discarded and only results for valid shots will be returned. If + ``False``, results corresponding to the original number of shots will be returned. + mcm_method (str): Strategy to use when executing circuits with mid-circuit measurements. Use ``"deferred"`` + to execute using the deferred measurements principle (applied using the + :func:`~pennylane.defer_measurements` transform), or ``"one-shot"`` if using finite shots to execute the + circuit for each shot separately. Keyword Args: **kwargs: Any additional keyword arguments provided are passed to the differentiation @@ -443,7 +449,8 @@ def __init__( cachesize=10000, max_diff=1, device_vjp=False, - mcm_config=None, + postselect_shots=None, + mcm_method=None, **gradient_kwargs, ): if logger.isEnabledFor(logging.DEBUG): @@ -518,7 +525,7 @@ def __init__( "max_diff": max_diff, "max_expansion": max_expansion, "device_vjp": device_vjp, - "mcm_config": mcm_config, + "mcm_config": {"postselect_shots": postselect_shots, "mcm_method": mcm_method}, } if self.expansion_strategy == "device": From 7c37074137e861d3abfcfdfe712fc711c5da018c Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 20 May 2024 21:43:29 -0400 Subject: [PATCH 04/48] Minor changes to tidy up --- pennylane/devices/preprocess.py | 4 +++- pennylane/transforms/defer_measurements.py | 7 ++++++- pennylane/transforms/dynamic_one_shot.py | 1 + pennylane/workflow/execution.py | 4 ++-- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index 9d018139872..f4c70dcbef7 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -145,7 +145,7 @@ def validate_device_wires( @transform def mid_circuit_measurements( - tape: qml.tape.QuantumTape, device, mcm_config + tape: qml.tape.QuantumTape, device, mcm_config=None ) -> (Sequence[qml.tape.QuantumTape], Callable): """Provide the transform to handle mid-circuit measurements. @@ -153,6 +153,8 @@ def mid_circuit_measurements( and use the ``qml.defer_measurements`` transform otherwise. """ + mcm_config = mcm_config or {} + if (mcm_method := mcm_config.get("mcm_method", None)) is not None: if mcm_method == "one-shot": return qml.dynamic_one_shot( diff --git a/pennylane/transforms/defer_measurements.py b/pennylane/transforms/defer_measurements.py index 679b4ea2f11..10ce94c8e5d 100644 --- a/pennylane/transforms/defer_measurements.py +++ b/pennylane/transforms/defer_measurements.py @@ -103,7 +103,7 @@ def null_postprocessing(results): @transform def defer_measurements( - tape: QuantumTape, reduce_postselected: bool = True, **kwargs + tape: QuantumTape, reduce_postselected: bool = True, postselect_shots: bool = None, **kwargs ) -> (Sequence[QuantumTape], Callable): """Quantum function transform that substitutes operations conditioned on measurement outcomes to controlled operations. @@ -158,6 +158,8 @@ def defer_measurements( tape (QNode or QuantumTape or Callable): a quantum circuit. reduce_postselected (bool): Whether or not to use postselection information to reduce the number of operations and control wires in the output tape. Active by default. + postselect_shots (bool): Whether or not to discard shots that don't match the + postselection criteria. Returns: qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]: The @@ -272,6 +274,9 @@ def node(x): if not any(isinstance(o, MidMeasureMP) for o in tape.operations): return (tape,), null_postprocessing + if qml.compiler.active() and postselect_shots: + raise ValueError("Can't discard invalid shots while using qml.qjit") + _check_tape_validity(tape) device = kwargs.get("device", None) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 80766ddb582..8dfde9f25bc 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -101,6 +101,7 @@ def func(x, y): "measurements." ) _ = kwargs.get("device", None) + print("in dynamic_one_shot") postselect_shots = kwargs.get("postselect_shots", None) if qml.compiler.active() and postselect_shots: diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index fddb959fa38..fdd5a32f0fd 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -554,7 +554,7 @@ def cost_fn(params, x): "Ignoring requested mid-circuit measurement configuration.", UserWarning, ) - config.mcm_config["postselect_shots"] = None + config.mcm_config["postselect_shots"] = False if any(not tape.shots for tape in tapes) and config.mcm_config["mcm_method"] == "one-shot": warnings.warn( @@ -562,7 +562,7 @@ def cost_fn(params, x): "analytic mode. Using deferred measurements.", UserWarning, ) - config.mcm_config["mcm_method"] = None + config.mcm_config["mcm_method"] = "deferred" if transform_program is None: if isinstance(device, qml.devices.Device): From b0db9207d53d016dbd44814a362f83d65d4c42a6 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 21 May 2024 09:54:13 -0400 Subject: [PATCH 05/48] [skip ci] update changelog --- doc/releases/changelog-dev.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 5ac7447f955..c05c57aafed 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -4,13 +4,25 @@

New features since last release

+* `qml.QNode` and `qml.qnode` now accept two new keyword arguments: `postselect_shots` and `mcm_method`. + These keyword arguments can be used to configure how the device should behave when running circuits with + mid-circuit measurements. + [(#5679)](https://github.com/PennyLaneAI/pennylane/pull/5679) + + * `postselect_shots=True` will indicate to devices to discard invalid shots when postselecting + mid-circuit measurements. Use `postselect_shots=False` to return invalid shots, which will be replaced + by dummy values. + * `mcm_method` will indicate which strategy to use for running circuits with mid-circuit measurements. + Use `mcm_method="deferred"` to use the deferred measurements principle, or `mcm_method="one-shot"` + to execute once for each shot. +

Improvements 🛠

Mid-circuit measurements and dynamic circuits

* The `dynamic_one_shot` transform can be compiled with `jax.jit`. [(#5557)](https://github.com/PennyLaneAI/pennylane/pull/5557) - + * When using `defer_measurements` with postselecting mid-circuit measurements, operations that will never be active due to the postselected state are skipped in the transformed quantum circuit. In addition, postselected controls are skipped, as they are evaluated From b20ee6598e654c678e598248d3fdc0159427bd84 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 21 May 2024 12:41:42 -0400 Subject: [PATCH 06/48] Update qnode processing for old devices --- pennylane/devices/preprocess.py | 13 +++++------ pennylane/workflow/qnode.py | 39 +++++++++++++++++++-------------- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index f4c70dcbef7..53a127655d8 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -160,15 +160,12 @@ def mid_circuit_measurements( return qml.dynamic_one_shot( tape, postselect_shots=mcm_config.get("postselect_shots", None) ) - if mcm_method == "deferred": - return qml.defer_measurements(tape, device=device) - warnings.warn( - "Invalid mid-circuit measurements method. Automatically detecting optimal method.", - UserWarning, - ) + return qml.defer_measurements(tape, device=device) + if tape.shots: - return qml.dynamic_one_shot(tape, postselect_shots=mcm_config.get("postselect_shots", None)) - return qml.defer_measurements(tape, device=device) + postselect_shots = mcm_config.get("postselect_shots", None) + return qml.dynamic_one_shot(tape, postselect_shots=postselect_shots) + return qml.defer_measurements(tape, device=device, postselect_shots=postselect_shots) @transform diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 0909a60f58f..1c3ef764cc1 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -985,20 +985,6 @@ def construct(self, args, kwargs): # pylint: disable=too-many-branches # check here only if enough wires raise qml.QuantumFunctionError(f"Operator {obj.name} must act on all wires") - # Apply the deferred measurement principle if the device doesn't - # support mid-circuit measurements natively. - # Only apply transform with old device API as postselection with - # broadcasting will split tapes. - expand_mid_measure = ( - any(isinstance(op, MidMeasureMP) for op in self.tape.operations) - and not isinstance(self.device, qml.devices.Device) - and not self.device.capabilities().get("supports_mid_measure", False) - ) - if expand_mid_measure or self.expansion_strategy == "device": - # Assume that tapes are not split if old device is used since postselection is not supported. - tapes, _ = qml.defer_measurements(self._tape, device=self.device) - self._tape = tapes[0] - if self.expansion_strategy == "device": if isinstance(self.device, qml.devices.Device): tape, _ = self.device.preprocess()[0]([self.tape]) @@ -1034,6 +1020,15 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml ) self._tape_cached = using_custom_cache and self.tape.hash in cache + mcm_method = self.execute_kwargs["mcm_config"].get("mcm_method", None) + if mcm_method not in ("deferred", "one-shot", None): + warnings.warn( + f"Invalid mid-circuit measurements method '{mcm_method}'. Automatically " + "detecting optimal method.", + UserWarning, + ) + mcm_method = self.execute_kwargs["mcm_config"]["mcm_method"] = None + # Add the device program to the QNode program if isinstance(self.device, qml.devices.Device): config = _make_execution_config(self, self.gradient_fn) @@ -1042,14 +1037,26 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml else: config = None full_transform_program = qml.transforms.core.TransformProgram(self.transform_program) + has_mcm_support = ( any(isinstance(op, MidMeasureMP) for op in self._tape) and hasattr(self.device, "capabilities") and self.device.capabilities().get("supports_mid_measure", False) ) if has_mcm_support: - full_transform_program.add_transform(qml.dynamic_one_shot) - override_shots = 1 + full_transform_program.add_transform( + qml.devices.preprocess.mid_circuit_measurements, + device=self.device, + mcm_config=self.execute_kwargs["mcm_config"], + ) + if override_shots and mcm_method in ("one-shot", None): + override_shots = 1 + elif hasattr(self.device, "capabilities"): + full_transform_program.add_transform( + qml.defer_measurements, + postselect_shots=self.execute_kwargs["mcm_config"]["postselect_shots"], + device=self.device, + ) # Add the gradient expand to the program if necessary if getattr(self.gradient_fn, "expand_transform", False): From 173d03fb91083bb84d137d0889f14a38e1a30671 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 21 May 2024 12:48:45 -0400 Subject: [PATCH 07/48] [skip ci] Skip CI From bb0db18d6ee280de738c952001bd7bd0f7a57e6c Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 22 May 2024 13:20:43 -0400 Subject: [PATCH 08/48] Reverting debugging changes --- pennylane/devices/preprocess.py | 9 ++++----- pennylane/transforms/defer_measurements.py | 3 +-- pennylane/transforms/dynamic_one_shot.py | 3 +-- pennylane/workflow/qnode.py | 3 ++- .../experimental/test_execution_config.py | 1 + tests/transforms/test_dynamic_one_shot.py | 20 +++++++++++++++++++ 6 files changed, 29 insertions(+), 10 deletions(-) diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index 53a127655d8..1d923159bce 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -154,16 +154,15 @@ def mid_circuit_measurements( """ mcm_config = mcm_config or {} + postselect_shots = mcm_config.get("postselect_shots", None) + mcm_method = mcm_config.get("mcm_method", None) - if (mcm_method := mcm_config.get("mcm_method", None)) is not None: + if mcm_method is not None: if mcm_method == "one-shot": - return qml.dynamic_one_shot( - tape, postselect_shots=mcm_config.get("postselect_shots", None) - ) + return qml.dynamic_one_shot(tape, postselect_shots=postselect_shots) return qml.defer_measurements(tape, device=device) if tape.shots: - postselect_shots = mcm_config.get("postselect_shots", None) return qml.dynamic_one_shot(tape, postselect_shots=postselect_shots) return qml.defer_measurements(tape, device=device, postselect_shots=postselect_shots) diff --git a/pennylane/transforms/defer_measurements.py b/pennylane/transforms/defer_measurements.py index 10ce94c8e5d..63a7f9c3769 100644 --- a/pennylane/transforms/defer_measurements.py +++ b/pennylane/transforms/defer_measurements.py @@ -275,12 +275,11 @@ def node(x): return (tape,), null_postprocessing if qml.compiler.active() and postselect_shots: - raise ValueError("Can't discard invalid shots while using qml.qjit") + raise ValueError("Cannot discard invalid shots while using qml.qjit") _check_tape_validity(tape) device = kwargs.get("device", None) - print("in defer_measurements") new_operations = [] diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 26aaa521d7a..526217b1774 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -100,11 +100,10 @@ def func(x, y): "measurements." ) _ = kwargs.get("device", None) - print("in dynamic_one_shot") postselect_shots = kwargs.get("postselect_shots", None) if qml.compiler.active() and postselect_shots: - raise ValueError("Can't discard invalid shots while using qml.qjit") + raise ValueError("Cannot discard invalid shots while using qml.qjit") postselect_shots = True if postselect_shots is None else postselect_shots if not tape.shots: diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 1c3ef764cc1..08d7795db9d 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -1049,7 +1049,8 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml device=self.device, mcm_config=self.execute_kwargs["mcm_config"], ) - if override_shots and mcm_method in ("one-shot", None): + shots = self.device.shots if override_shots is False else override_shots + if shots and mcm_method in ("one-shot", None): override_shots = 1 elif hasattr(self.device, "capabilities"): full_transform_program.add_transform( diff --git a/tests/devices/experimental/test_execution_config.py b/tests/devices/experimental/test_execution_config.py index ce751d9e2f2..01142a7a653 100644 --- a/tests/devices/experimental/test_execution_config.py +++ b/tests/devices/experimental/test_execution_config.py @@ -30,6 +30,7 @@ def test_default_values(): assert config.gradient_keyword_arguments == {} assert config.grad_on_execution is None assert config.use_device_gradient is None + assert config.mcm_config == {"postselect_shots": None, "mcm_method": None} def test_invalid_interface(): diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py index 3ca186171fe..17d31385b9b 100644 --- a/tests/transforms/test_dynamic_one_shot.py +++ b/tests/transforms/test_dynamic_one_shot.py @@ -16,6 +16,8 @@ """ # pylint: disable=too-few-public-methods, too-many-arguments +from functools import partial + import numpy as np import pytest @@ -61,6 +63,24 @@ def _(): return qml.probs(wires=[0]) +def test_qjit_postselection_error(): + """Test that an error is raised if using qml.qjit and requesting `postselect_shots=True`""" + catalyst = pytest.importorskip("catalyst") + + dev = qml.device("lightning.qubit", wires=3, shots=10) + + @qml.qjit + @partial(qml.dynamic_one_shot, postselect_shots=True) + @qml.qnode(dev) + def func(x): + qml.RX(x, 0) + _ = catalyst.measure(0, postselect=0) + return qml.sample(wires=[0, 1]) + + with pytest.raises(ValueError, match="Cannot discard invalid shots while using qml.qjit"): + _ = func(1.8) + + def test_unsupported_measurements(): """Test that using unsupported measurements raises an error.""" tape = qml.tape.QuantumScript([MidMeasureMP(0)], [qml.state()]) From 029ac2af90b0790eda7df30497f4450396c46c1c Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 22 May 2024 15:15:54 -0400 Subject: [PATCH 09/48] Fix interface check in qml.execute --- pennylane/workflow/execution.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 37ddb7d2094..039b535b804 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -549,7 +549,7 @@ def cost_fn(params, x): gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config ) - if "jax" in interface and config.mcm_config["postselect_shots"]: + if interface in {"jax", "jax-jit"} and config.mcm_config.get("postselect_shots", None): warnings.warn( "Cannot discard invalid shots with postselection when using the 'jax' interface. " "Ignoring requested mid-circuit measurement configuration.", @@ -557,7 +557,10 @@ def cost_fn(params, x): ) config.mcm_config["postselect_shots"] = False - if any(not tape.shots for tape in tapes) and config.mcm_config["mcm_method"] == "one-shot": + if ( + any(not tape.shots for tape in tapes) + and config.mcm_config.get("mcm_method", None) == "one-shot" + ): warnings.warn( "Cannot use the 'one-shot' method for mid-circuit measurements with " "analytic mode. Using deferred measurements.", From cd21934ffd059a0328865bf8d13c0bb85233484e Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 22 May 2024 15:35:15 -0400 Subject: [PATCH 10/48] Added preprocess tests --- pennylane/workflow/execution.py | 7 ++---- tests/devices/test_preprocess.py | 37 ++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 039b535b804..fb5c5eed7c7 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -549,7 +549,7 @@ def cost_fn(params, x): gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config ) - if interface in {"jax", "jax-jit"} and config.mcm_config.get("postselect_shots", None): + if interface in {"jax", "jax-jit"} and config.mcm_config["postselect_shots"]: warnings.warn( "Cannot discard invalid shots with postselection when using the 'jax' interface. " "Ignoring requested mid-circuit measurement configuration.", @@ -557,10 +557,7 @@ def cost_fn(params, x): ) config.mcm_config["postselect_shots"] = False - if ( - any(not tape.shots for tape in tapes) - and config.mcm_config.get("mcm_method", None) == "one-shot" - ): + if any(not tape.shots for tape in tapes) and config.mcm_config["mcm_method"] == "one-shot": warnings.warn( "Cannot use the 'one-shot' method for mid-circuit measurements with " "analytic mode. Using deferred measurements.", diff --git a/tests/devices/test_preprocess.py b/tests/devices/test_preprocess.py index 38c3d48371e..74f8ac386ae 100644 --- a/tests/devices/test_preprocess.py +++ b/tests/devices/test_preprocess.py @@ -21,6 +21,7 @@ from pennylane.devices.preprocess import ( _operator_decomposition_gen, decompose, + mid_circuit_measurements, no_sampling, validate_adjoint_trainable_params, validate_device_wires, @@ -440,6 +441,42 @@ def test_decompose_initial_state_prep_if_requested(self, prep_op): assert new_tape[0] != prep_op +class TestMidCircuitMeasurements: + """Unit tests for the mid_circuit_measurements preprocessing transform""" + + @pytest.mark.parametrize( + "mcm_method, shots, expected_transform", + [ + ("deferred", 10, "defer_measurements"), + ("deferred", None, "defer_measurements"), + (None, None, "defer_measurements"), + (None, 10, "dynamic_one_shot"), + ("one-shot", 10, "dynamic_one_shot"), + ], + ) + def test_mcm_method(self, mcm_method, shots, expected_transform, mocker): + """Test that the transform adheres to the specified transform""" + dev = qml.device("default.qubit") + mcm_config = {"postselect_shots": None, "mcm_method": mcm_method} + tape = QuantumScript([qml.measurements.MidMeasureMP(0)], [], shots=shots) + spy = mocker.spy(qml, expected_transform) + + _, _ = mid_circuit_measurements(tape, dev, mcm_config) + spy.assert_called_once() + + def test_error_incompatible_mcm_method(self): + """Test that an error is raised if requesting the one-shot transform without shots""" + dev = qml.device("default.qubit") + shots = None + mcm_config = {"postselect_shots": None, "mcm_method": "one-shot"} + tape = QuantumScript([qml.measurements.MidMeasureMP(0)], [], shots=shots) + + with pytest.raises( + qml.QuantumFunctionError, match="dynamic_one_shot is only supported with finite shots." + ): + _, _ = mid_circuit_measurements(tape, dev, mcm_config) + + def test_validate_multiprocessing_workers_None(): """Test that validation does not fail when max_workers is None""" qs = QuantumScript( From fdbaf42df98a105b7a91f3007b65fe9134b01403 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 22 May 2024 17:19:06 -0400 Subject: [PATCH 11/48] Added defer_measurements tests --- pennylane/transforms/defer_measurements.py | 6 +-- pennylane/transforms/dynamic_one_shot.py | 2 - tests/transforms/test_defer_measurements.py | 44 +++++++++++++++++++++ tests/transforms/test_dynamic_one_shot.py | 23 +++++++++++ 4 files changed, 69 insertions(+), 6 deletions(-) diff --git a/pennylane/transforms/defer_measurements.py b/pennylane/transforms/defer_measurements.py index 63a7f9c3769..03c03230f69 100644 --- a/pennylane/transforms/defer_measurements.py +++ b/pennylane/transforms/defer_measurements.py @@ -103,7 +103,7 @@ def null_postprocessing(results): @transform def defer_measurements( - tape: QuantumTape, reduce_postselected: bool = True, postselect_shots: bool = None, **kwargs + tape: QuantumTape, reduce_postselected: bool = True, **kwargs ) -> (Sequence[QuantumTape], Callable): """Quantum function transform that substitutes operations conditioned on measurement outcomes to controlled operations. @@ -158,8 +158,6 @@ def defer_measurements( tape (QNode or QuantumTape or Callable): a quantum circuit. reduce_postselected (bool): Whether or not to use postselection information to reduce the number of operations and control wires in the output tape. Active by default. - postselect_shots (bool): Whether or not to discard shots that don't match the - postselection criteria. Returns: qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]: The @@ -274,7 +272,7 @@ def node(x): if not any(isinstance(o, MidMeasureMP) for o in tape.operations): return (tape,), null_postprocessing - if qml.compiler.active() and postselect_shots: + if qml.compiler.active() and kwargs.get("postselect_shots", None): raise ValueError("Cannot discard invalid shots while using qml.qjit") _check_tape_validity(tape) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 526217b1774..77ac81cac62 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -55,8 +55,6 @@ def dynamic_one_shot( Args: tape (QNode or QuantumTape or Callable): a quantum circuit to add a batch dimension to - postselect_shots (bool): Whether or not to discard shots that don't match the - postselection criteria. Returns: qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]: diff --git a/tests/transforms/test_defer_measurements.py b/tests/transforms/test_defer_measurements.py index cdc1fd5f138..333bb63756e 100644 --- a/tests/transforms/test_defer_measurements.py +++ b/tests/transforms/test_defer_measurements.py @@ -105,6 +105,50 @@ def circ(): _ = circ() +def test_qjit_postselection_error(monkeypatch): + """Test that an error is raised if qjit is active with `postselect=True`""" + # TODO: Update test once defer_measurements can be used with qjit + # catalyst = pytest.importorskip("catalyst") + dev = qml.device("lightning.qubit", wires=3, shots=10) + + # @qml.qjit + @partial(qml.defer_measurements, postselect_shots=True) + @qml.qnode(dev) + def func(x): + qml.RX(x, 0) + _ = qml.measure(0, postselect=0) + # _ = catalyst.measure(0, postselect=0) + return qml.sample(wires=[0, 1]) + + # Mocking qml.compiler.active() to always return True + with monkeypatch.context() as m: + m.setattr(qml.compiler, "active", lambda: True) + with pytest.raises(ValueError, match="Cannot discard invalid shots while using qml.qjit"): + _ = func(1.8) + + +@pytest.mark.parametrize("postselect_shots", [True, False]) +def test_postselect_shots(postselect_shots, mocker): + """Test that invalid shots are discarded if requested""" + shots = 100 + dev = qml.device("default.qubit", shots=shots) + spy = mocker.spy(qml, "defer_measurements") + + @qml.qnode(dev, postselect_shots=postselect_shots, mcm_method="deferred") + def f(x): + qml.RX(x, 0) + _ = qml.measure(0, postselect=1) + return qml.sample(wires=[0, 1]) + + res = f(np.pi / 2) + spy.assert_called_once() + + if postselect_shots: + assert len(res) < shots + else: + assert len(res) == shots + + @pytest.mark.parametrize( "mp, err_msg", [ diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py index 17d31385b9b..2855cc3251a 100644 --- a/tests/transforms/test_dynamic_one_shot.py +++ b/tests/transforms/test_dynamic_one_shot.py @@ -81,6 +81,29 @@ def func(x): _ = func(1.8) +@pytest.mark.parametrize("postselect_shots", [True, False]) +def test_postselect_shots(postselect_shots, mocker): + """Test that invalid shots are discarded if requested""" + shots = 100 + dev = qml.device("default.qubit", shots=shots) + spy = mocker.spy(qml, "dynamic_one_shot") + + @qml.qnode(dev, postselect_shots=postselect_shots) + def f(x): + qml.RX(x, 0) + _ = qml.measure(0, postselect=1) + return qml.sample(wires=[0, 1]) + + res = f(np.pi / 2) + spy.assert_called_once() + + if postselect_shots: + assert len(res) < shots + else: + assert len(res) == shots + assert np.any(res == np.iinfo(np.int32).min) + + def test_unsupported_measurements(): """Test that using unsupported measurements raises an error.""" tape = qml.tape.QuantumScript([MidMeasureMP(0)], [qml.state()]) From 0a9f52e2bdde10d4fca5d8e2c99c0d947e763af2 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 22 May 2024 19:34:55 -0400 Subject: [PATCH 12/48] Added qnode tests --- tests/devices/test_preprocess.py | 12 ++-- tests/interfaces/test_jacobian_products.py | 6 +- tests/test_qnode.py | 66 ++++++++++++++++++--- tests/test_qnode_legacy.py | 23 +------ tests/transforms/test_defer_measurements.py | 4 +- tests/transforms/test_dynamic_one_shot.py | 2 +- tests/transforms/test_tape_expand.py | 10 +++- 7 files changed, 82 insertions(+), 41 deletions(-) diff --git a/tests/devices/test_preprocess.py b/tests/devices/test_preprocess.py index 74f8ac386ae..a78daa1db18 100644 --- a/tests/devices/test_preprocess.py +++ b/tests/devices/test_preprocess.py @@ -447,11 +447,11 @@ class TestMidCircuitMeasurements: @pytest.mark.parametrize( "mcm_method, shots, expected_transform", [ - ("deferred", 10, "defer_measurements"), - ("deferred", None, "defer_measurements"), - (None, None, "defer_measurements"), - (None, 10, "dynamic_one_shot"), - ("one-shot", 10, "dynamic_one_shot"), + ("deferred", 10, qml.defer_measurements), + ("deferred", None, qml.defer_measurements), + (None, None, qml.defer_measurements), + (None, 10, qml.dynamic_one_shot), + ("one-shot", 10, qml.dynamic_one_shot), ], ) def test_mcm_method(self, mcm_method, shots, expected_transform, mocker): @@ -459,7 +459,7 @@ def test_mcm_method(self, mcm_method, shots, expected_transform, mocker): dev = qml.device("default.qubit") mcm_config = {"postselect_shots": None, "mcm_method": mcm_method} tape = QuantumScript([qml.measurements.MidMeasureMP(0)], [], shots=shots) - spy = mocker.spy(qml, expected_transform) + spy = mocker.spy(expected_transform, "_transform") _, _ = mid_circuit_measurements(tape, dev, mcm_config) spy.assert_called_once() diff --git a/tests/interfaces/test_jacobian_products.py b/tests/interfaces/test_jacobian_products.py index 5796aa0cda9..b2d88037918 100644 --- a/tests/interfaces/test_jacobian_products.py +++ b/tests/interfaces/test_jacobian_products.py @@ -151,7 +151,8 @@ def test_device_jacobians_repr(self): r" ExecutionConfig(grad_on_execution=None, use_device_gradient=None," r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}," - r" device_options={}, interface=None, derivative_order=1)>" + r" device_options={}, interface=None, derivative_order=1," + r" mcm_config={'postselect_shots': None, 'mcm_method': None})>" ) assert repr(jpc) == expected @@ -169,7 +170,8 @@ def test_device_jacobian_products_repr(self): r" ExecutionConfig(grad_on_execution=None, use_device_gradient=None," r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}, device_options={}," - r" interface=None, derivative_order=1)>" + r" interface=None, derivative_order=1," + r" mcm_config={'postselect_shots': None, 'mcm_method': None})>" ) assert repr(jpc) == expected diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 2682861e6e2..69366bced7a 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -878,11 +878,8 @@ def circuit(x, y): assert np.allclose(res, expected, atol=tol, rtol=0) @pytest.mark.parametrize( - "dev, call_count", - [ - (qml.device("default.qubit", wires=3), 2), - (qml.device("default.qubit.legacy", wires=3), 1), - ], + "dev", + [qml.device("default.qubit", wires=3), qml.device("default.qubit.legacy", wires=3)], ) @pytest.mark.parametrize("first_par", np.linspace(0.15, np.pi - 0.3, 3)) @pytest.mark.parametrize("sec_par", np.linspace(0.15, np.pi - 0.3, 3)) @@ -898,7 +895,7 @@ def circuit(x, y): ], ) def test_defer_meas_if_mcm_unsupported( - self, dev, call_count, first_par, sec_par, return_type, mv_return, mv_res, mocker + self, dev, first_par, sec_par, return_type, mv_return, mv_res, mocker ): # pylint: disable=too-many-arguments """Tests that the transform using the deferred measurement principle is applied if the device doesn't support mid-circuit measurements @@ -928,7 +925,7 @@ def conditional_ry_qnode(x, y): assert np.allclose(r1, r2[0]) assert np.allclose(r2[1], mv_res(first_par)) - assert spy.call_count == call_count # once for each preprocessing + assert spy.call_count == 2 @pytest.mark.parametrize("dev_name", ["default.qubit.legacy", "default.mixed"]) def test_dynamic_one_shot_if_mcm_unsupported(self, dev_name): @@ -1701,6 +1698,61 @@ def circuit(): assert qml.math.allclose(results, np.zeros((20, 2))) +@pytest.mark.parametrize("dev_name", ["default.qubit", "default.qubit.legacy"]) +class TestMCMConfiguration: + """Tests for MCM configuration arguments""" + + @pytest.mark.jax + @pytest.mark.parametrize("use_jit", [True, False]) + @pytest.mark.parametrize("interface", ["jax", "auto"]) + def test_jax_warning_with_postselect_shots(self, use_jit, dev_name, interface): + """Test that a warning is raised when postselect_shots=True with jax""" + import jax # pylint: disable=import-outside-toplevel + + shots = 100 + dev = qml.device(dev_name, wires=3, shots=shots) + + @qml.qnode(dev, postselect_shots=True, interface=interface) + def f(x): + qml.RX(x, 0) + _ = qml.measure(0, postselect=1) + return qml.sample(wires=[0, 1]) + + if use_jit: + f = jax.jit(f) + param = jax.numpy.array(np.pi / 4) + + with pytest.warns( + UserWarning, + match="Cannot discard invalid shots with postselection when using the 'jax' interface", + ): + res = f(param) + + assert len(res) == shots + + def test_one_shot_warning_without_shots(self, dev_name, mocker): + """Test that a warning is raised if mcm_method="one-shot" with no shots""" + dev = qml.device(dev_name, wires=3) + spy = mocker.spy(qml.defer_measurements, "_transform") + one_shot_spy = mocker.spy(qml.dynamic_one_shot, "_transform") + + @qml.qnode(dev, mcm_method="one-shot") + def f(x): + qml.RX(x, 0) + _ = qml.measure(0, postselect=1) + return qml.probs(wires=[0, 1]) + + param = np.pi / 4 + + with pytest.warns( + UserWarning, match="Cannot use the 'one-shot' method for mid-circuit measurements with" + ): + _ = f(param) + + assert spy.call_count != 0 + one_shot_spy.assert_not_called() + + class TestTapeExpansion: """Test that tape expansion within the QNode works correctly""" diff --git a/tests/test_qnode_legacy.py b/tests/test_qnode_legacy.py index 9f02791c48e..5dcb1b50f8f 100644 --- a/tests/test_qnode_legacy.py +++ b/tests/test_qnode_legacy.py @@ -1075,7 +1075,7 @@ def test_no_defer_measurements_if_supported(self, mocker): QNode construction if the device supports mid-circuit measurements.""" dev = qml.device("default.qubit.legacy", wires=3) mocker.patch.object(qml.Device, "_capabilities", {"supports_mid_measure": True}) - spy = mocker.spy(qml, "defer_measurements") + spy = mocker.spy(qml.defer_measurements, "_transform") @qml.qnode(dev) def circuit(): @@ -1138,23 +1138,6 @@ def conditional_ry_qnode(x, y): assert np.allclose(r2[1], mv_res(first_par)) assert spy.call_count == 3 if dev.name == "defaut.qubit" else 1 - def test_drawing_has_deferred_measurements(self): - """Test that `qml.draw` with qnodes uses defer_measurements - to draw circuits with mid-circuit measurements.""" - dev = qml.device("default.qubit.legacy", wires=2) - - @qml.qnode(dev) - def circuit(x): - qml.RX(x, wires=0) - m = qml.measure(0) - qml.cond(m, qml.PauliX)(wires=1) - return qml.expval(qml.PauliZ(wires=1)) - - res = qml.draw(circuit)("x") - expected = "0: ──RX(x)─╭●─┤ \n1: ────────╰X─┤ " - - assert res == expected - @pytest.mark.parametrize("basis_state", [[1, 0], [0, 1]]) def test_sampling_with_mcm(self, basis_state, mocker): """Tests that a QNode with qml.sample and mid-circuit measurements @@ -1179,11 +1162,11 @@ def conditional_ry_qnode(x): qml.cond(m_0, qml.RY)(x, wires=1) return qml.sample(qml.PauliZ(1)) - spy = mocker.spy(qml, "defer_measurements") + spy = mocker.spy(qml.defer_measurements, "_transform") r1 = cry_qnode(first_par) r2 = conditional_ry_qnode(first_par) assert np.allclose(r1, r2) - spy.assert_called_once() + spy.assert_called() @pytest.mark.tf @pytest.mark.parametrize("interface", ["tf", "auto"]) diff --git a/tests/transforms/test_defer_measurements.py b/tests/transforms/test_defer_measurements.py index 333bb63756e..ce1659c828f 100644 --- a/tests/transforms/test_defer_measurements.py +++ b/tests/transforms/test_defer_measurements.py @@ -132,7 +132,7 @@ def test_postselect_shots(postselect_shots, mocker): """Test that invalid shots are discarded if requested""" shots = 100 dev = qml.device("default.qubit", shots=shots) - spy = mocker.spy(qml, "defer_measurements") + spy = mocker.spy(qml.defer_measurements, "_transform") @qml.qnode(dev, postselect_shots=postselect_shots, mcm_method="deferred") def f(x): @@ -140,7 +140,7 @@ def f(x): _ = qml.measure(0, postselect=1) return qml.sample(wires=[0, 1]) - res = f(np.pi / 2) + res = f(np.pi / 4) spy.assert_called_once() if postselect_shots: diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py index 2855cc3251a..6bc828d055c 100644 --- a/tests/transforms/test_dynamic_one_shot.py +++ b/tests/transforms/test_dynamic_one_shot.py @@ -47,7 +47,7 @@ def test_parse_native_mid_circuit_measurements_unsupported_meas(measurement): circuit = qml.tape.QuantumScript([qml.RX(1.0, 0)], [measurement]) with pytest.raises(TypeError, match="Native mid-circuit measurement mode does not support"): - parse_native_mid_circuit_measurements(circuit, [circuit], [[]]) + parse_native_mid_circuit_measurements(circuit, [circuit], [[]], None) def test_postselection_error_with_wrong_device(): diff --git a/tests/transforms/test_tape_expand.py b/tests/transforms/test_tape_expand.py index 55f65a7cabc..715320b48e8 100644 --- a/tests/transforms/test_tape_expand.py +++ b/tests/transforms/test_tape_expand.py @@ -931,7 +931,6 @@ def circuit(): _ = circuit() decomp_ops = circuit.tape.operations - print(decomp_ops) assert len(decomp_ops) == 4 if shots is None else 5 assert decomp_ops[0].name == "RZ" @@ -940,5 +939,10 @@ def circuit(): assert decomp_ops[1].name == "RY" assert np.isclose(decomp_ops[1].parameters[0], np.pi / 2) - assert decomp_ops[2].name == "CNOT" - assert decomp_ops[3].name == "CNOT" + if shots: + assert decomp_ops[2].name == "MidMeasureMP" + assert decomp_ops[3].name == "CNOT" + assert decomp_ops[4].name == "MidMeasureMP" + else: + assert decomp_ops[2].name == "CNOT" + assert decomp_ops[3].name == "CNOT" From 79d078ed23921867254e1c822cec887441732524 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 22 May 2024 20:35:37 -0400 Subject: [PATCH 13/48] Fixed qnode test --- tests/test_qnode.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 69366bced7a..fcce3ac838c 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -1698,19 +1698,18 @@ def circuit(): assert qml.math.allclose(results, np.zeros((20, 2))) -@pytest.mark.parametrize("dev_name", ["default.qubit", "default.qubit.legacy"]) class TestMCMConfiguration: """Tests for MCM configuration arguments""" @pytest.mark.jax @pytest.mark.parametrize("use_jit", [True, False]) @pytest.mark.parametrize("interface", ["jax", "auto"]) - def test_jax_warning_with_postselect_shots(self, use_jit, dev_name, interface): + def test_jax_warning_with_postselect_shots(self, use_jit, interface): """Test that a warning is raised when postselect_shots=True with jax""" import jax # pylint: disable=import-outside-toplevel shots = 100 - dev = qml.device(dev_name, wires=3, shots=shots) + dev = qml.device("default.qubit", wires=3, shots=shots) @qml.qnode(dev, postselect_shots=True, interface=interface) def f(x): @@ -1730,6 +1729,7 @@ def f(x): assert len(res) == shots + @pytest.mark.parametrize("dev_name", ["default.qubit", "default.qubit.legacy"]) def test_one_shot_warning_without_shots(self, dev_name, mocker): """Test that a warning is raised if mcm_method="one-shot" with no shots""" dev = qml.device(dev_name, wires=3) @@ -1739,7 +1739,7 @@ def test_one_shot_warning_without_shots(self, dev_name, mocker): @qml.qnode(dev, mcm_method="one-shot") def f(x): qml.RX(x, 0) - _ = qml.measure(0, postselect=1) + _ = qml.measure(0) return qml.probs(wires=[0, 1]) param = np.pi / 4 From 4319f733951ace0565da4574601f5b46d467ae61 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 22 May 2024 21:04:55 -0400 Subject: [PATCH 14/48] Update simulate to map wires before simulation --- pennylane/devices/qubit/simulate.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index a0c4d60df54..2745ba060c2 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -140,7 +140,8 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): interface = execution_kwargs.get("interface", None) mid_measurements = execution_kwargs.get("mid_measurements", None) postselect_shots = execution_kwargs.get("postselect_shots", None) - circuit = circuit.map_to_standard_wires() + + # circuit = circuit.map_to_standard_wires() prep = None if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase): @@ -214,7 +215,7 @@ def measure_final_state(circuit, state, is_state_batched, **execution_kwargs) -> prng_key = execution_kwargs.get("prng_key", None) mid_measurements = execution_kwargs.get("mid_measurements", None) - circuit = circuit.map_to_standard_wires() + # circuit = circuit.map_to_standard_wires() # analytic case @@ -293,6 +294,8 @@ def simulate( interface = execution_kwargs.get("interface", None) postselect_shots = execution_kwargs.get("postselect_shots", None) + circuit = circuit.map_to_standard_wires() + has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations) if circuit.shots and has_mcm: results = [] From ac16f1762e7651b8e3a9a1ba11357fd4359cf8c4 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 23 May 2024 10:38:18 -0400 Subject: [PATCH 15/48] Update simulation functions to not map wires and make it a precondition --- pennylane/devices/default_qubit.py | 3 +++ pennylane/devices/qubit/simulate.py | 17 ++++++++++------- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 394389d5a86..7e14661b3d2 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -853,6 +853,7 @@ def _simulate_wrapper(circuit, kwargs): def _adjoint_jac_wrapper(c, debugger=None): + c = c.map_to_standard_wires() state, is_state_batched = get_final_state(c, debugger=debugger) jac = adjoint_jacobian(c, state=state) res = measure_final_state(c, state, is_state_batched) @@ -860,6 +861,7 @@ def _adjoint_jac_wrapper(c, debugger=None): def _adjoint_jvp_wrapper(c, t, debugger=None): + c = c.map_to_standard_wires() state, is_state_batched = get_final_state(c, debugger=debugger) jvp = adjoint_jvp(c, t, state=state) res = measure_final_state(c, state, is_state_batched) @@ -867,6 +869,7 @@ def _adjoint_jvp_wrapper(c, t, debugger=None): def _adjoint_vjp_wrapper(c, t, debugger=None): + c = c.map_to_standard_wires() state, is_state_batched = get_final_state(c, debugger=debugger) vjp = adjoint_vjp(c, t, state=state) res = measure_final_state(c, state, is_state_batched) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 2745ba060c2..8577378df9f 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -52,11 +52,15 @@ class _FlexShots(qml.measurements.Shots): """Shots class that allows zero shots.""" + _frozen = False + # pylint: disable=super-init-not-called def __init__(self, shots=None): if isinstance(shots, int): self.total_shots = shots self.shot_vector = (qml.measurements.ShotCopies(shots, 1),) + elif isinstance(shots, self.__class__): + return # self already _is_ shots as defined by __new__ else: self.__all_tuple_init__([s if isinstance(s, tuple) else (s, 1) for s in shots]) @@ -119,7 +123,8 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): This is an internal function that will be called by the successor to ``default.qubit``. Args: - circuit (.QuantumScript): The single circuit to simulate + circuit (.QuantumScript): The single circuit to simulate. This circuit is assumed to have + non-negative integer wire labels debugger (._Debugger): The debugger to use interface (str): The machine learning interface to create the initial state with mid_measurements (None, dict): Dictionary of mid-circuit measurements @@ -141,8 +146,6 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): mid_measurements = execution_kwargs.get("mid_measurements", None) postselect_shots = execution_kwargs.get("postselect_shots", None) - # circuit = circuit.map_to_standard_wires() - prep = None if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase): prep = circuit[0] @@ -168,7 +171,7 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): # Handle postselection on mid-circuit measurements if isinstance(op, qml.Projector): prng_key, key = jax_random_split(prng_key) - state, circuit._shots = _postselection_postprocess( + state, new_shots = _postselection_postprocess( state, is_state_batched, circuit.shots, @@ -176,6 +179,7 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): prng_key=key, postselect_shots=postselect_shots, ) + circuit._shots = circuit._shots = new_shots # new state is batched if i) the old state is batched, or ii) the new op adds a batch dim is_state_batched = is_state_batched or (op.batch_size is not None) @@ -196,7 +200,8 @@ def measure_final_state(circuit, state, is_state_batched, **execution_kwargs) -> This is an internal function that will be called by the successor to ``default.qubit``. Args: - circuit (.QuantumScript): The single circuit to simulate + circuit (.QuantumScript): The single circuit to simulate. This circuit is assumed to have + non-negative integer wire labels state (TensorLike): The state to perform measurement on is_state_batched (bool): Whether the state has a batch dimension or not. rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A @@ -215,8 +220,6 @@ def measure_final_state(circuit, state, is_state_batched, **execution_kwargs) -> prng_key = execution_kwargs.get("prng_key", None) mid_measurements = execution_kwargs.get("mid_measurements", None) - # circuit = circuit.map_to_standard_wires() - # analytic case if not circuit.shots: From 6e6cb0c702ec4d328bde3387ece49d06ebc2c837 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 23 May 2024 17:01:06 -0400 Subject: [PATCH 16/48] Added more tests for code coverage --- pennylane/devices/preprocess.py | 2 +- .../default_qubit/test_default_qubit.py | 13 ++------- tests/devices/qubit/test_simulate.py | 25 +++++++++++++++++ tests/test_qnode.py | 14 ++++++++++ tests/transforms/test_defer_measurements.py | 6 ++-- tests/transforms/test_dynamic_one_shot.py | 28 ++++++++++--------- 6 files changed, 60 insertions(+), 28 deletions(-) diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index 1d923159bce..35f0b4ec43d 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -160,7 +160,7 @@ def mid_circuit_measurements( if mcm_method is not None: if mcm_method == "one-shot": return qml.dynamic_one_shot(tape, postselect_shots=postselect_shots) - return qml.defer_measurements(tape, device=device) + return qml.defer_measurements(tape, device=device, postselect_shots=postselect_shots) if tape.shots: return qml.dynamic_one_shot(tape, postselect_shots=postselect_shots) diff --git a/tests/devices/default_qubit/test_default_qubit.py b/tests/devices/default_qubit/test_default_qubit.py index e1fd668f504..784bd2eacd6 100644 --- a/tests/devices/default_qubit/test_default_qubit.py +++ b/tests/devices/default_qubit/test_default_qubit.py @@ -18,7 +18,6 @@ import numpy as np import pytest -from flaky import flaky import pennylane as qml from pennylane.devices import DefaultQubit, ExecutionConfig @@ -1817,7 +1816,6 @@ def circ_expected(): assert qml.math.allclose(res, expected) assert qml.math.get_interface(res) == qml.math.get_interface(expected) - @flaky(max_runs=5) @pytest.mark.parametrize( "mp", [ @@ -1839,9 +1837,7 @@ def test_postselection_valid_finite_shots(self, param, mp, shots, interface, use if use_jit and (interface != "jax" or isinstance(shots, tuple)): pytest.skip("Cannot JIT in non-JAX interfaces, or with shot vectors.") - np.random.seed(42) - - dev = qml.device("default.qubit") + dev = qml.device("default.qubit", seed=1971) param = qml.math.asarray(param, like=interface) @qml.defer_measurements @@ -1879,7 +1875,7 @@ def circ_expected(): @pytest.mark.parametrize( "mp, expected_shape", - [(qml.sample(wires=[0]), (5,)), (qml.classical_shadow(wires=[0]), (2, 5, 1))], + [(qml.sample(wires=[0, 2]), (5, 2)), (qml.classical_shadow(wires=[0, 2]), (2, 5, 2))], ) @pytest.mark.parametrize("param", np.linspace(np.pi / 4, 3 * np.pi / 4, 3)) @pytest.mark.parametrize("shots", [10, (10, 10)]) @@ -1909,11 +1905,6 @@ def circ_postselect(theta): qml.measure(0, postselect=1) return qml.apply(mp) - if use_jit: - import jax - - circ_postselect = jax.jit(circ_postselect, static_argnames=["shots"]) - res = circ_postselect(param, shots=shots) if not isinstance(shots, tuple): diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index ac54393a028..861d938000a 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -18,6 +18,7 @@ import pennylane as qml from pennylane.devices.qubit import get_final_state, measure_final_state, simulate +from pennylane.devices.qubit.simulate import _FlexShots class TestCurrentlyUnsupportedCases: @@ -409,6 +410,30 @@ def test_nan_state(self, interface): assert qml.math.all(qml.math.isnan(res)) +class Test_FlexShots: + """Unit tests for _FlexShots""" + + @pytest.mark.parametrize( + "shots, expected_shot_vector", + [ + (0, (0,)), + ((10, 0, 5, 0), (10, 0, 5, 0)), + (((10, 3), (0, 5)), (10, 10, 10, 0, 0, 0, 0, 0)), + ], + ) + def test_init_with_zero_shots(self, shots, expected_shot_vector): + """Test that _FlexShots is initialized correctly with zero shots""" + flex_shots = _FlexShots(shots) + shot_vector = tuple(s for s in flex_shots) + assert shot_vector == expected_shot_vector + + def test_init_with_other_shots(self): + """Test that a new _FlexShots object is not created if the input is a Shots object.""" + shots = _FlexShots(10) + new_shots = _FlexShots(shots) + assert new_shots is shots + + class TestDebugger: """Tests that the debugger works for a simple circuit""" diff --git a/tests/test_qnode.py b/tests/test_qnode.py index fcce3ac838c..4cde6c1645c 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -1752,6 +1752,20 @@ def f(x): assert spy.call_count != 0 one_shot_spy.assert_not_called() + def test_invalid_mcm_method_warning(self): + """Test that a warning is raised if the requested mcm_method is invalid""" + shots = 100 + dev = qml.device("default.qubit", wires=3, shots=shots) + + @qml.qnode(dev, mcm_method="foo") + def f(x): + qml.RX(x, 0) + _ = qml.measure(0, postselect=1) + return qml.sample(wires=[0, 1]) + + with pytest.warns(UserWarning, match="Invalid mid-circuit measurements method 'foo'"): + _ = f(1.8) + class TestTapeExpansion: """Test that tape expansion within the QNode works correctly""" diff --git a/tests/transforms/test_defer_measurements.py b/tests/transforms/test_defer_measurements.py index ce1659c828f..19f3713ef7c 100644 --- a/tests/transforms/test_defer_measurements.py +++ b/tests/transforms/test_defer_measurements.py @@ -109,11 +109,11 @@ def test_qjit_postselection_error(monkeypatch): """Test that an error is raised if qjit is active with `postselect=True`""" # TODO: Update test once defer_measurements can be used with qjit # catalyst = pytest.importorskip("catalyst") - dev = qml.device("lightning.qubit", wires=3, shots=10) + # dev = qml.device("lightning.qubit", wires=3, shots=10) + dev = qml.device("default.qubit", wires=3, shots=10) # @qml.qjit - @partial(qml.defer_measurements, postselect_shots=True) - @qml.qnode(dev) + @qml.qnode(dev, postselect_shots=True, mcm_method="deferred") def func(x): qml.RX(x, 0) _ = qml.measure(0, postselect=0) diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py index 6bc828d055c..01825aaceac 100644 --- a/tests/transforms/test_dynamic_one_shot.py +++ b/tests/transforms/test_dynamic_one_shot.py @@ -16,8 +16,6 @@ """ # pylint: disable=too-few-public-methods, too-many-arguments -from functools import partial - import numpy as np import pytest @@ -63,22 +61,26 @@ def _(): return qml.probs(wires=[0]) -def test_qjit_postselection_error(): - """Test that an error is raised if using qml.qjit and requesting `postselect_shots=True`""" - catalyst = pytest.importorskip("catalyst") - - dev = qml.device("lightning.qubit", wires=3, shots=10) +def test_qjit_postselection_error(monkeypatch): + """Test that an error is raised if qjit is active with `postselect=True`""" + # TODO: Update test once defer_measurements can be used with qjit + # catalyst = pytest.importorskip("catalyst") + # dev = qml.device("lightning.qubit", wires=3, shots=10) + dev = qml.device("default.qubit", wires=3, shots=10) - @qml.qjit - @partial(qml.dynamic_one_shot, postselect_shots=True) - @qml.qnode(dev) + # @qml.qjit + @qml.qnode(dev, postselect_shots=True, mcm_method="one-shot") def func(x): qml.RX(x, 0) - _ = catalyst.measure(0, postselect=0) + _ = qml.measure(0, postselect=0) + # _ = catalyst.measure(0, postselect=0) return qml.sample(wires=[0, 1]) - with pytest.raises(ValueError, match="Cannot discard invalid shots while using qml.qjit"): - _ = func(1.8) + # Mocking qml.compiler.active() to always return True + with monkeypatch.context() as m: + m.setattr(qml.compiler, "active", lambda: True) + with pytest.raises(ValueError, match="Cannot discard invalid shots while using qml.qjit"): + _ = func(1.8) @pytest.mark.parametrize("postselect_shots", [True, False]) From 77614fc163adfcb5c21c772d62f16049050fe657 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 23 May 2024 17:25:47 -0400 Subject: [PATCH 17/48] Updated postselect_shots to postselect_mode --- doc/releases/changelog-dev.md | 6 ++--- pennylane/devices/default_qubit.py | 4 ++-- pennylane/devices/execution_config.py | 2 +- pennylane/devices/preprocess.py | 10 ++++---- pennylane/devices/qubit/simulate.py | 23 ++++++++++--------- pennylane/transforms/defer_measurements.py | 2 +- pennylane/transforms/dynamic_one_shot.py | 6 ++--- pennylane/workflow/execution.py | 4 ++-- pennylane/workflow/qnode.py | 12 +++++----- .../experimental/test_execution_config.py | 2 +- tests/devices/test_preprocess.py | 4 ++-- tests/interfaces/test_jacobian_products.py | 4 ++-- tests/test_qnode.py | 6 ++--- tests/transforms/test_defer_measurements.py | 10 ++++---- tests/transforms/test_dynamic_one_shot.py | 12 +++++----- 15 files changed, 54 insertions(+), 53 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index c3390bf38ec..fdd5e91d8f3 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -4,13 +4,13 @@

New features since last release

-* `qml.QNode` and `qml.qnode` now accept two new keyword arguments: `postselect_shots` and `mcm_method`. +* `qml.QNode` and `qml.qnode` now accept two new keyword arguments: `postselect_mode` and `mcm_method`. These keyword arguments can be used to configure how the device should behave when running circuits with mid-circuit measurements. [(#5679)](https://github.com/PennyLaneAI/pennylane/pull/5679) - * `postselect_shots=True` will indicate to devices to discard invalid shots when postselecting - mid-circuit measurements. Use `postselect_shots=False` to return invalid shots, which will be replaced + * `postselect_mode="hw-like"` will indicate to devices to discard invalid shots when postselecting + mid-circuit measurements. Use `postselect_mode="fill-shots"` to return invalid shots, which will be replaced by dummy values. * `mcm_method` will indicate which strategy to use for running circuits with mid-circuit measurements. Use `mcm_method="deferred"` to use the deferred measurements principle, or `mcm_method="one-shot"` diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 7e14661b3d2..31daebaa098 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -602,7 +602,7 @@ def execute( "interface": interface, "state_cache": self._state_cache, "prng_key": _key, - "postselect_shots": execution_config.mcm_config["postselect_shots"], + "postselect_mode": execution_config.mcm_config["postselect_mode"], }, ) for c, _key in zip(circuits, prng_keys) @@ -614,7 +614,7 @@ def execute( { "rng": _rng, "prng_key": _key, - "postselect_shots": execution_config.mcm_config["postselect_shots"], + "postselect_mode": execution_config.mcm_config["postselect_mode"], } for _rng, _key in zip(seeds, prng_keys) ] diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index 849ff428299..663fa12690d 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -94,7 +94,7 @@ def __post_init__(self): if self.mcm_config is None: self.mcm_config = {} - for option in ("postselect_shots", "mcm_method"): + for option in ("postselect_mode", "mcm_method"): if option not in self.mcm_config: self.mcm_config[option] = None diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index 35f0b4ec43d..7afdb303ea7 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -154,17 +154,17 @@ def mid_circuit_measurements( """ mcm_config = mcm_config or {} - postselect_shots = mcm_config.get("postselect_shots", None) + postselect_mode = mcm_config.get("postselect_mode", None) mcm_method = mcm_config.get("mcm_method", None) if mcm_method is not None: if mcm_method == "one-shot": - return qml.dynamic_one_shot(tape, postselect_shots=postselect_shots) - return qml.defer_measurements(tape, device=device, postselect_shots=postselect_shots) + return qml.dynamic_one_shot(tape, postselect_mode=postselect_mode) + return qml.defer_measurements(tape, device=device, postselect_mode=postselect_mode) if tape.shots: - return qml.dynamic_one_shot(tape, postselect_shots=postselect_shots) - return qml.defer_measurements(tape, device=device, postselect_shots=postselect_shots) + return qml.dynamic_one_shot(tape, postselect_mode=postselect_mode) + return qml.defer_measurements(tape, device=device, postselect_mode=postselect_mode) @transform diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 8577378df9f..956e3bddc3c 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -79,14 +79,13 @@ def _postselection_postprocess(state, is_state_batched, shots, **execution_kwarg rng = execution_kwargs.get("rng", None) prng_key = execution_kwargs.get("prng_key", None) - postselect_shots = execution_kwargs.get("postselect_shots", None) + postselect_mode = execution_kwargs.get("postselect_mode", None) # The floor function is being used here so that a norm very close to zero becomes exactly # equal to zero so that the state can become invalid. This way, execution can continue, and # bad postselection gives results that are invalid rather than results that look valid but # are incorrect. norm = qml.math.norm(state) - postselect_shots = True if postselect_shots is None else postselect_shots if not qml.math.is_abstract(state) and qml.math.allclose(norm, 0.0): norm = 0.0 @@ -104,7 +103,7 @@ def _postselection_postprocess(state, is_state_batched, shots, **execution_kwarg postselected_shots = ( [int(binomial_fn(s, float(norm**2))) for s in shots] - if postselect_shots and not qml.math.is_abstract(norm) + if postselect_mode in (None, "hw-like") and not qml.math.is_abstract(norm) else shots ) @@ -132,8 +131,9 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. If None, a ``numpy.random.default_rng`` will be for sampling. - postselect_shots (bool): Whether or not to discard invalid shots when postselecting - mid-circuit measurements. + postselect_mode (str): Configuration for handling shots with mid-circuit measurement + postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to + keep the same number of shots. Returns: Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and @@ -144,7 +144,7 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): prng_key = execution_kwargs.get("prng_key", None) interface = execution_kwargs.get("interface", None) mid_measurements = execution_kwargs.get("mid_measurements", None) - postselect_shots = execution_kwargs.get("postselect_shots", None) + postselect_mode = execution_kwargs.get("postselect_mode", None) prep = None if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase): @@ -177,7 +177,7 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): circuit.shots, rng=rng, prng_key=key, - postselect_shots=postselect_shots, + postselect_mode=postselect_mode, ) circuit._shots = circuit._shots = new_shots @@ -276,8 +276,9 @@ def simulate( the key to the JAX pseudo random number generator. If None, a random key will be generated. Only for simulation using JAX. interface (str): The machine learning interface to create the initial state with - postselect_shots (bool): Whether or not to discard invalid shots when postselecting - mid-circuit measurements. + postselect_mode (str): Configuration for handling shots with mid-circuit measurement + postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to + keep the same number of shots. Returns: tuple(TensorLike): The results of the simulation @@ -295,7 +296,7 @@ def simulate( rng = execution_kwargs.get("rng", None) prng_key = execution_kwargs.get("prng_key", None) interface = execution_kwargs.get("interface", None) - postselect_shots = execution_kwargs.get("postselect_shots", None) + postselect_mode = execution_kwargs.get("postselect_mode", None) circuit = circuit.map_to_standard_wires() @@ -336,7 +337,7 @@ def simulate_partial(k): rng=rng, prng_key=ops_key, interface=interface, - postselect_shots=postselect_shots, + postselect_mode=postselect_mode, ) if state_cache is not None: state_cache[circuit.hash] = state diff --git a/pennylane/transforms/defer_measurements.py b/pennylane/transforms/defer_measurements.py index 03c03230f69..46253b55388 100644 --- a/pennylane/transforms/defer_measurements.py +++ b/pennylane/transforms/defer_measurements.py @@ -272,7 +272,7 @@ def node(x): if not any(isinstance(o, MidMeasureMP) for o in tape.operations): return (tape,), null_postprocessing - if qml.compiler.active() and kwargs.get("postselect_shots", None): + if qml.compiler.active() and kwargs.get("postselect_mode", None) == "hw-like": raise ValueError("Cannot discard invalid shots while using qml.qjit") _check_tape_validity(tape) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 77ac81cac62..eaac75afc22 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -99,10 +99,10 @@ def func(x, y): ) _ = kwargs.get("device", None) - postselect_shots = kwargs.get("postselect_shots", None) - if qml.compiler.active() and postselect_shots: + postselect_mode = kwargs.get("postselect_mode", None) + if qml.compiler.active() and postselect_mode == "hw-like": raise ValueError("Cannot discard invalid shots while using qml.qjit") - postselect_shots = True if postselect_shots is None else postselect_shots + postselect_shots = postselect_mode in (None, "hw-like") if not tape.shots: raise qml.QuantumFunctionError("dynamic_one_shot is only supported with finite shots.") diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index fb5c5eed7c7..b2df628750a 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -549,13 +549,13 @@ def cost_fn(params, x): gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config ) - if interface in {"jax", "jax-jit"} and config.mcm_config["postselect_shots"]: + if interface in {"jax", "jax-jit"} and config.mcm_config["postselect_mode"] == "hw-like": warnings.warn( "Cannot discard invalid shots with postselection when using the 'jax' interface. " "Ignoring requested mid-circuit measurement configuration.", UserWarning, ) - config.mcm_config["postselect_shots"] = False + config.mcm_config["postselect_mode"] = "fill-shots" if any(not tape.shots for tape in tapes) and config.mcm_config["mcm_method"] == "one-shot": warnings.warn( diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 08d7795db9d..dc49e128f4a 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -219,9 +219,9 @@ class QNode: (classical) computational overhead during the backwards pass. device_vjp (bool): Whether or not to use the device-provided Vector Jacobian Product (VJP). A value of ``None`` indicates to use it if the device provides it, but use the full jacobian otherwise. - postselect_shots (bool): Whether or not to discard invalid shots when postselecting mid-circuit measurements. - If ``True``, invalid shots will be discarded and only results for valid shots will be returned. If - ``False``, results corresponding to the original number of shots will be returned. + postselect_mode (str): Configuration for handling shots with mid-circuit measurement + postselection. If ``"hw-like"``, invalid shots will be discarded and only results for valid shots will be returned. If + ``"fill-shots"``, results corresponding to the original number of shots will be returned. mcm_method (str): Strategy to use when executing circuits with mid-circuit measurements. Use ``"deferred"`` to execute using the deferred measurements principle (applied using the :func:`~pennylane.defer_measurements` transform), or ``"one-shot"`` if using finite shots to execute the @@ -449,7 +449,7 @@ def __init__( cachesize=10000, max_diff=1, device_vjp=False, - postselect_shots=None, + postselect_mode=None, mcm_method=None, **gradient_kwargs, ): @@ -525,7 +525,7 @@ def __init__( "max_diff": max_diff, "max_expansion": max_expansion, "device_vjp": device_vjp, - "mcm_config": {"postselect_shots": postselect_shots, "mcm_method": mcm_method}, + "mcm_config": {"postselect_mode": postselect_mode, "mcm_method": mcm_method}, } if self.expansion_strategy == "device": @@ -1055,7 +1055,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml elif hasattr(self.device, "capabilities"): full_transform_program.add_transform( qml.defer_measurements, - postselect_shots=self.execute_kwargs["mcm_config"]["postselect_shots"], + postselect_mode=self.execute_kwargs["mcm_config"]["postselect_mode"], device=self.device, ) diff --git a/tests/devices/experimental/test_execution_config.py b/tests/devices/experimental/test_execution_config.py index 01142a7a653..b49cd430d71 100644 --- a/tests/devices/experimental/test_execution_config.py +++ b/tests/devices/experimental/test_execution_config.py @@ -30,7 +30,7 @@ def test_default_values(): assert config.gradient_keyword_arguments == {} assert config.grad_on_execution is None assert config.use_device_gradient is None - assert config.mcm_config == {"postselect_shots": None, "mcm_method": None} + assert config.mcm_config == {"postselect_mode": None, "mcm_method": None} def test_invalid_interface(): diff --git a/tests/devices/test_preprocess.py b/tests/devices/test_preprocess.py index a78daa1db18..b0ceecb5839 100644 --- a/tests/devices/test_preprocess.py +++ b/tests/devices/test_preprocess.py @@ -457,7 +457,7 @@ class TestMidCircuitMeasurements: def test_mcm_method(self, mcm_method, shots, expected_transform, mocker): """Test that the transform adheres to the specified transform""" dev = qml.device("default.qubit") - mcm_config = {"postselect_shots": None, "mcm_method": mcm_method} + mcm_config = {"postselect_mode": None, "mcm_method": mcm_method} tape = QuantumScript([qml.measurements.MidMeasureMP(0)], [], shots=shots) spy = mocker.spy(expected_transform, "_transform") @@ -468,7 +468,7 @@ def test_error_incompatible_mcm_method(self): """Test that an error is raised if requesting the one-shot transform without shots""" dev = qml.device("default.qubit") shots = None - mcm_config = {"postselect_shots": None, "mcm_method": "one-shot"} + mcm_config = {"postselect_mode": None, "mcm_method": "one-shot"} tape = QuantumScript([qml.measurements.MidMeasureMP(0)], [], shots=shots) with pytest.raises( diff --git a/tests/interfaces/test_jacobian_products.py b/tests/interfaces/test_jacobian_products.py index b2d88037918..67f245f8dcb 100644 --- a/tests/interfaces/test_jacobian_products.py +++ b/tests/interfaces/test_jacobian_products.py @@ -152,7 +152,7 @@ def test_device_jacobians_repr(self): r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}," r" device_options={}, interface=None, derivative_order=1," - r" mcm_config={'postselect_shots': None, 'mcm_method': None})>" + r" mcm_config={'postselect_mode': None, 'mcm_method': None})>" ) assert repr(jpc) == expected @@ -171,7 +171,7 @@ def test_device_jacobian_products_repr(self): r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}, device_options={}," r" interface=None, derivative_order=1," - r" mcm_config={'postselect_shots': None, 'mcm_method': None})>" + r" mcm_config={'postselect_mode': None, 'mcm_method': None})>" ) assert repr(jpc) == expected diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 4cde6c1645c..658733b3381 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -1704,14 +1704,14 @@ class TestMCMConfiguration: @pytest.mark.jax @pytest.mark.parametrize("use_jit", [True, False]) @pytest.mark.parametrize("interface", ["jax", "auto"]) - def test_jax_warning_with_postselect_shots(self, use_jit, interface): - """Test that a warning is raised when postselect_shots=True with jax""" + def test_jax_warning_with_postselect_mode_hw_like(self, use_jit, interface): + """Test that a warning is raised when postselect_mode="hw-like" with jax""" import jax # pylint: disable=import-outside-toplevel shots = 100 dev = qml.device("default.qubit", wires=3, shots=shots) - @qml.qnode(dev, postselect_shots=True, interface=interface) + @qml.qnode(dev, postselect_mode="hw-like", interface=interface) def f(x): qml.RX(x, 0) _ = qml.measure(0, postselect=1) diff --git a/tests/transforms/test_defer_measurements.py b/tests/transforms/test_defer_measurements.py index 19f3713ef7c..7fd87008c5f 100644 --- a/tests/transforms/test_defer_measurements.py +++ b/tests/transforms/test_defer_measurements.py @@ -113,7 +113,7 @@ def test_qjit_postselection_error(monkeypatch): dev = qml.device("default.qubit", wires=3, shots=10) # @qml.qjit - @qml.qnode(dev, postselect_shots=True, mcm_method="deferred") + @qml.qnode(dev, postselect_mode="hw-like", mcm_method="deferred") def func(x): qml.RX(x, 0) _ = qml.measure(0, postselect=0) @@ -127,14 +127,14 @@ def func(x): _ = func(1.8) -@pytest.mark.parametrize("postselect_shots", [True, False]) -def test_postselect_shots(postselect_shots, mocker): +@pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) +def test_postselect_mode(postselect_mode, mocker): """Test that invalid shots are discarded if requested""" shots = 100 dev = qml.device("default.qubit", shots=shots) spy = mocker.spy(qml.defer_measurements, "_transform") - @qml.qnode(dev, postselect_shots=postselect_shots, mcm_method="deferred") + @qml.qnode(dev, postselect_mode=postselect_mode, mcm_method="deferred") def f(x): qml.RX(x, 0) _ = qml.measure(0, postselect=1) @@ -143,7 +143,7 @@ def f(x): res = f(np.pi / 4) spy.assert_called_once() - if postselect_shots: + if postselect_mode == "hw-like": assert len(res) < shots else: assert len(res) == shots diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py index 01825aaceac..b8726890033 100644 --- a/tests/transforms/test_dynamic_one_shot.py +++ b/tests/transforms/test_dynamic_one_shot.py @@ -62,14 +62,14 @@ def _(): def test_qjit_postselection_error(monkeypatch): - """Test that an error is raised if qjit is active with `postselect=True`""" + """Test that an error is raised if qjit is active with `postselect_mode="hw-like"`""" # TODO: Update test once defer_measurements can be used with qjit # catalyst = pytest.importorskip("catalyst") # dev = qml.device("lightning.qubit", wires=3, shots=10) dev = qml.device("default.qubit", wires=3, shots=10) # @qml.qjit - @qml.qnode(dev, postselect_shots=True, mcm_method="one-shot") + @qml.qnode(dev, postselect_mode="hw-like", mcm_method="one-shot") def func(x): qml.RX(x, 0) _ = qml.measure(0, postselect=0) @@ -83,14 +83,14 @@ def func(x): _ = func(1.8) -@pytest.mark.parametrize("postselect_shots", [True, False]) -def test_postselect_shots(postselect_shots, mocker): +@pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) +def test_postselect_mode(postselect_mode, mocker): """Test that invalid shots are discarded if requested""" shots = 100 dev = qml.device("default.qubit", shots=shots) spy = mocker.spy(qml, "dynamic_one_shot") - @qml.qnode(dev, postselect_shots=postselect_shots) + @qml.qnode(dev, postselect_mode=postselect_mode) def f(x): qml.RX(x, 0) _ = qml.measure(0, postselect=1) @@ -99,7 +99,7 @@ def f(x): res = f(np.pi / 2) spy.assert_called_once() - if postselect_shots: + if postselect_mode == "hw-like": assert len(res) < shots else: assert len(res) == shots From fcdfee6ac9cd68e5483ee6c846ed5c0b97031b58 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 24 May 2024 13:25:27 -0400 Subject: [PATCH 18/48] Added docs --- doc/introduction/measurements.rst | 55 +++++++++++++++++++++++++++++++ pennylane/workflow/qnode.py | 6 ++-- 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index 1dd17edd184..0270b073a7a 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -524,6 +524,61 @@ Collecting statistics for sequences of mid-circuit measurements is supported wit When collecting statistics for a list of mid-circuit measurements, values manipulated using arithmetic operators should not be used as this behaviour is not supported. +Configuring mid-circuit measurements +************************************ + +As seen above, there are multiple ways in which circuits with mid-circuit measurements can be executed with +PennyLane. For ease of use, we provide the following configuration options to users when initializing a +:class:`~pennylane.QNode`: + + * ``mcm_method``: To set the method used for applying mid-circuit measurements. Use ``mcm_method="deferred"`` + to use the deferred measurements principle or ``mcm_method="one-shot"`` to use the one-shot transform as + described above. + + .. note:: + + If the ``mcm_method`` argument is provided, the :func:`~pennylane.defer_measurements` or + :func:`~pennylane.dynamic_one_shot` transforms should not be applied directly to a :class:`~pennylane.QNode` + as it can lead to incorrect behaviour. + + * ``postselect_mode``: To configure how invalid shots are handled when postselecting mid-circuit measurements + with finite-shot circuits. Use ``postselect_mode="hw-like"`` to discard invalid shots. In this case, the number + of samples that are used for processing results will be less than or equal to the total number of shots. Use + ``postselect_mode="fill-shots"`` to keep invalid shots. + + .. note:: + + If ``postselect_mode="fill-shots"``, the specified ``mcm_method`` will impact the results due to the particular + features/limitations of the requested ``mcm_method``. + + * If using ``mcm_method="hw-like"``, invalid samples will be replaced with ``np.iinfo(np.int32).min``, + and these invalid values will not be used for processing final results. + * If using ``mcm_method="deferred"``, all shots will be projected to the postselected value, so all + shots will be considered valid. + + .. note:: + + When using the ``jax`` interface or while using :func:`~pennylane.qjit`, the results will reflect + ``postselect_mode="fill-shots"`` regardless of the specified value. + +.. code-block:: python3 + + import pennylane as qml + import numpy as np + + dev = qml.device("default.qubit", wires=3, shots=10) + + @qml.qnode(dev, mcm_method="one-shot", postselect_mode="fill-shots") + def circuit(x): + qml.RX(x, 0) + m0 = qml.measure(0, postselect=1) + qml.CNOT([0, 1]) + return qml.expval(qml.PauliZ(0)) + +>>> circuit(np.pi / 2) +array([-2147483648, 1, -2147483648, -2147483648, 1, + 1, 1, 1, 1, 1]) + Changing the number of shots ---------------------------- diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index dc49e128f4a..9f657b4c7a9 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -221,11 +221,13 @@ class QNode: A value of ``None`` indicates to use it if the device provides it, but use the full jacobian otherwise. postselect_mode (str): Configuration for handling shots with mid-circuit measurement postselection. If ``"hw-like"``, invalid shots will be discarded and only results for valid shots will be returned. If - ``"fill-shots"``, results corresponding to the original number of shots will be returned. + ``"fill-shots"``, results corresponding to the original number of shots will be returned. For usage details, + please refer to the :doc:`main measurements page . mcm_method (str): Strategy to use when executing circuits with mid-circuit measurements. Use ``"deferred"`` to execute using the deferred measurements principle (applied using the :func:`~pennylane.defer_measurements` transform), or ``"one-shot"`` if using finite shots to execute the - circuit for each shot separately. + circuit for each shot separately. For usage details, please refer to the + :doc:`main measurements page . Keyword Args: **kwargs: Any additional keyword arguments provided are passed to the differentiation From 60912d2f1c31e020c420cff642e4314e5a5a719f Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 24 May 2024 13:43:23 -0400 Subject: [PATCH 19/48] Fix docs --- doc/introduction/measurements.rst | 14 +++++++------- pennylane/workflow/qnode.py | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index 0270b073a7a..39f0be5986b 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -532,8 +532,8 @@ PennyLane. For ease of use, we provide the following configuration options to us :class:`~pennylane.QNode`: * ``mcm_method``: To set the method used for applying mid-circuit measurements. Use ``mcm_method="deferred"`` - to use the deferred measurements principle or ``mcm_method="one-shot"`` to use the one-shot transform as - described above. + to use the deferred measurements principle or ``mcm_method="one-shot"`` to use the one-shot transform as + described above. .. note:: @@ -542,9 +542,9 @@ PennyLane. For ease of use, we provide the following configuration options to us as it can lead to incorrect behaviour. * ``postselect_mode``: To configure how invalid shots are handled when postselecting mid-circuit measurements - with finite-shot circuits. Use ``postselect_mode="hw-like"`` to discard invalid shots. In this case, the number - of samples that are used for processing results will be less than or equal to the total number of shots. Use - ``postselect_mode="fill-shots"`` to keep invalid shots. + with finite-shot circuits. Use ``postselect_mode="hw-like"`` to discard invalid shots. In this case, the number + of samples that are used for processing results will be less than or equal to the total number of shots. Use + ``postselect_mode="fill-shots"`` to keep invalid shots. .. note:: @@ -552,9 +552,9 @@ PennyLane. For ease of use, we provide the following configuration options to us features/limitations of the requested ``mcm_method``. * If using ``mcm_method="hw-like"``, invalid samples will be replaced with ``np.iinfo(np.int32).min``, - and these invalid values will not be used for processing final results. + and these invalid values will not be used for processing final results. * If using ``mcm_method="deferred"``, all shots will be projected to the postselected value, so all - shots will be considered valid. + shots will be considered valid. .. note:: diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 9f657b4c7a9..37e95d8275f 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -222,12 +222,12 @@ class QNode: postselect_mode (str): Configuration for handling shots with mid-circuit measurement postselection. If ``"hw-like"``, invalid shots will be discarded and only results for valid shots will be returned. If ``"fill-shots"``, results corresponding to the original number of shots will be returned. For usage details, - please refer to the :doc:`main measurements page . + please refer to the :doc:`main measurements page `. mcm_method (str): Strategy to use when executing circuits with mid-circuit measurements. Use ``"deferred"`` to execute using the deferred measurements principle (applied using the :func:`~pennylane.defer_measurements` transform), or ``"one-shot"`` if using finite shots to execute the circuit for each shot separately. For usage details, please refer to the - :doc:`main measurements page . + :doc:`main measurements page `. Keyword Args: **kwargs: Any additional keyword arguments provided are passed to the differentiation From dab94bba4f387c2aa185304858cb866a6980f04d Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 24 May 2024 13:59:48 -0400 Subject: [PATCH 20/48] Fix indentation in docs --- doc/introduction/measurements.rst | 42 +++++++++++++++---------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index 39f0be5986b..2d769b25f63 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -531,35 +531,35 @@ As seen above, there are multiple ways in which circuits with mid-circuit measur PennyLane. For ease of use, we provide the following configuration options to users when initializing a :class:`~pennylane.QNode`: - * ``mcm_method``: To set the method used for applying mid-circuit measurements. Use ``mcm_method="deferred"`` - to use the deferred measurements principle or ``mcm_method="one-shot"`` to use the one-shot transform as - described above. +* ``mcm_method``: To set the method used for applying mid-circuit measurements. Use ``mcm_method="deferred"`` + to use the deferred measurements principle or ``mcm_method="one-shot"`` to use the one-shot transform as + described above. - .. note:: + .. note:: - If the ``mcm_method`` argument is provided, the :func:`~pennylane.defer_measurements` or - :func:`~pennylane.dynamic_one_shot` transforms should not be applied directly to a :class:`~pennylane.QNode` - as it can lead to incorrect behaviour. + If the ``mcm_method`` argument is provided, the :func:`~pennylane.defer_measurements` or + :func:`~pennylane.dynamic_one_shot` transforms should not be applied directly to a :class:`~pennylane.QNode` + as it can lead to incorrect behaviour. - * ``postselect_mode``: To configure how invalid shots are handled when postselecting mid-circuit measurements - with finite-shot circuits. Use ``postselect_mode="hw-like"`` to discard invalid shots. In this case, the number - of samples that are used for processing results will be less than or equal to the total number of shots. Use - ``postselect_mode="fill-shots"`` to keep invalid shots. +* ``postselect_mode``: To configure how invalid shots are handled when postselecting mid-circuit measurements + with finite-shot circuits. Use ``postselect_mode="hw-like"`` to discard invalid shots. In this case, the number + of samples that are used for processing results will be less than or equal to the total number of shots. Use + ``postselect_mode="fill-shots"`` to keep invalid shots. - .. note:: + .. note:: - If ``postselect_mode="fill-shots"``, the specified ``mcm_method`` will impact the results due to the particular - features/limitations of the requested ``mcm_method``. + If ``postselect_mode="fill-shots"``, the specified ``mcm_method`` will impact the results due to the particular + features/limitations of the requested ``mcm_method``. - * If using ``mcm_method="hw-like"``, invalid samples will be replaced with ``np.iinfo(np.int32).min``, - and these invalid values will not be used for processing final results. - * If using ``mcm_method="deferred"``, all shots will be projected to the postselected value, so all - shots will be considered valid. + * If using ``mcm_method="hw-like"``, invalid samples will be replaced with ``np.iinfo(np.int32).min``, + and these invalid values will not be used for processing final results. + * If using ``mcm_method="deferred"``, all shots will be projected to the postselected value, so all + shots will be considered valid. - .. note:: +.. note:: - When using the ``jax`` interface or while using :func:`~pennylane.qjit`, the results will reflect - ``postselect_mode="fill-shots"`` regardless of the specified value. + When using the ``jax`` interface or while using :func:`~pennylane.qjit`, the results will reflect + ``postselect_mode="fill-shots"`` regardless of the specified value. .. code-block:: python3 From a40e8452c68dd9376a7406d9b38ff8bdea553535 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 24 May 2024 15:01:17 -0400 Subject: [PATCH 21/48] Fix indentation again --- doc/introduction/measurements.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index 2d769b25f63..9268d31f39e 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -532,8 +532,8 @@ PennyLane. For ease of use, we provide the following configuration options to us :class:`~pennylane.QNode`: * ``mcm_method``: To set the method used for applying mid-circuit measurements. Use ``mcm_method="deferred"`` - to use the deferred measurements principle or ``mcm_method="one-shot"`` to use the one-shot transform as - described above. + to use the deferred measurements principle or ``mcm_method="one-shot"`` to use the one-shot transform as + described above. .. note:: @@ -542,9 +542,9 @@ PennyLane. For ease of use, we provide the following configuration options to us as it can lead to incorrect behaviour. * ``postselect_mode``: To configure how invalid shots are handled when postselecting mid-circuit measurements - with finite-shot circuits. Use ``postselect_mode="hw-like"`` to discard invalid shots. In this case, the number - of samples that are used for processing results will be less than or equal to the total number of shots. Use - ``postselect_mode="fill-shots"`` to keep invalid shots. + with finite-shot circuits. Use ``postselect_mode="hw-like"`` to discard invalid shots. In this case, the number + of samples that are used for processing results will be less than or equal to the total number of shots. Use + ``postselect_mode="fill-shots"`` to keep invalid shots. .. note:: @@ -556,10 +556,10 @@ PennyLane. For ease of use, we provide the following configuration options to us * If using ``mcm_method="deferred"``, all shots will be projected to the postselected value, so all shots will be considered valid. -.. note:: + .. note:: - When using the ``jax`` interface or while using :func:`~pennylane.qjit`, the results will reflect - ``postselect_mode="fill-shots"`` regardless of the specified value. + When using the ``jax`` interface or while using :func:`~pennylane.qjit`, the results will reflect + ``postselect_mode="fill-shots"`` regardless of the specified value. .. code-block:: python3 From 9b14d370e09823fc227a14bc4b298cae83aabab1 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 24 May 2024 16:27:28 -0400 Subject: [PATCH 22/48] Another indentation fix.. --- doc/introduction/measurements.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index 9268d31f39e..d03e413fa4e 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -535,7 +535,7 @@ PennyLane. For ease of use, we provide the following configuration options to us to use the deferred measurements principle or ``mcm_method="one-shot"`` to use the one-shot transform as described above. - .. note:: + .. warning:: If the ``mcm_method`` argument is provided, the :func:`~pennylane.defer_measurements` or :func:`~pennylane.dynamic_one_shot` transforms should not be applied directly to a :class:`~pennylane.QNode` @@ -552,9 +552,9 @@ PennyLane. For ease of use, we provide the following configuration options to us features/limitations of the requested ``mcm_method``. * If using ``mcm_method="hw-like"``, invalid samples will be replaced with ``np.iinfo(np.int32).min``, - and these invalid values will not be used for processing final results. + and these invalid values will not be used for processing final results. * If using ``mcm_method="deferred"``, all shots will be projected to the postselected value, so all - shots will be considered valid. + shots will be considered valid. .. note:: From 6abdd028fa484f5b53b41e2cd1fa2c80ed39c296 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 27 May 2024 16:30:12 -0400 Subject: [PATCH 23/48] Made mcm_config dataclass --- pennylane/devices/default_qubit.py | 4 +-- pennylane/devices/execution_config.py | 27 +++++++++++++---- pennylane/devices/preprocess.py | 11 ++++--- pennylane/workflow/execution.py | 8 ++--- pennylane/workflow/qnode.py | 3 -- .../experimental/test_execution_config.py | 30 +++++++++++++++++-- tests/interfaces/test_jacobian_products.py | 4 +-- 7 files changed, 64 insertions(+), 23 deletions(-) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 31daebaa098..6b219f98a69 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -602,7 +602,7 @@ def execute( "interface": interface, "state_cache": self._state_cache, "prng_key": _key, - "postselect_mode": execution_config.mcm_config["postselect_mode"], + "postselect_mode": execution_config.mcm_config.postselect_mode, }, ) for c, _key in zip(circuits, prng_keys) @@ -614,7 +614,7 @@ def execute( { "rng": _rng, "prng_key": _key, - "postselect_mode": execution_config.mcm_config["postselect_mode"], + "postselect_mode": execution_config.mcm_config.postselect_mode, } for _rng, _key in zip(seeds, prng_keys) ] diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index 663fa12690d..c6c8b6bb6f3 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -20,6 +20,22 @@ from pennylane.workflow import SUPPORTED_INTERFACES +@dataclass +class MCMConfig: + """A class to store mid-circuit measurement configurations.""" + + mcm_method: Optional[str] = None + """Which mid-circuit measurement strategy to use. Use ``deferred`` for the deferred + measurements principle and "one-shot" if using finite shots to execute the circuit + for each shot separately.""" + + postselect_mode: Optional[str] = None + """Configuration for handling shots with mid-circuit measurement postselection. If + ``"hw-like"``, invalid shots will be discarded and only results for valid shots will + be returned. If ``"fill-shots"``, results corresponding to the original number of + shots will be returned.""" + + # pylint: disable=too-many-instance-attributes @dataclass class ExecutionConfig: @@ -67,7 +83,7 @@ class ExecutionConfig: derivative_order: int = 1 """The derivative order to compute while evaluating a gradient""" - mcm_config: Optional[dict] = None + mcm_config: Optional[MCMConfig] = MCMConfig() """Configuration options for handling mid-circuit measurements""" def __post_init__(self): @@ -92,11 +108,10 @@ def __post_init__(self): if self.gradient_keyword_arguments is None: self.gradient_keyword_arguments = {} - if self.mcm_config is None: - self.mcm_config = {} - for option in ("postselect_mode", "mcm_method"): - if option not in self.mcm_config: - self.mcm_config[option] = None + if isinstance(self.mcm_config, dict): + self.mcm_config = MCMConfig(**self.mcm_config) + elif not isinstance(self.mcm_config, MCMConfig): + raise ValueError(f"Got invalid type {type(self.mcm_config)} for 'mcm_config'") DefaultExecutionConfig = ExecutionConfig() diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index 7afdb303ea7..12e51f9880e 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -28,6 +28,8 @@ from pennylane.typing import Result, ResultBatch from pennylane.wires import WireError +from .execution_config import MCMConfig + PostprocessingFn = Callable[[ResultBatch], Union[Result, ResultBatch]] @@ -145,7 +147,7 @@ def validate_device_wires( @transform def mid_circuit_measurements( - tape: qml.tape.QuantumTape, device, mcm_config=None + tape: qml.tape.QuantumTape, device, mcm_config=MCMConfig ) -> (Sequence[qml.tape.QuantumTape], Callable): """Provide the transform to handle mid-circuit measurements. @@ -153,9 +155,10 @@ def mid_circuit_measurements( and use the ``qml.defer_measurements`` transform otherwise. """ - mcm_config = mcm_config or {} - postselect_mode = mcm_config.get("postselect_mode", None) - mcm_method = mcm_config.get("mcm_method", None) + if isinstance(mcm_config, dict): + mcm_config = MCMConfig(**mcm_config) + postselect_mode = mcm_config.postselect_mode + mcm_method = mcm_config.mcm_method if mcm_method is not None: if mcm_method == "one-shot": diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index b2df628750a..4918f7abff1 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -549,21 +549,21 @@ def cost_fn(params, x): gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config ) - if interface in {"jax", "jax-jit"} and config.mcm_config["postselect_mode"] == "hw-like": + if interface in {"jax", "jax-jit"} and config.mcm_config.postselect_mode == "hw-like": warnings.warn( "Cannot discard invalid shots with postselection when using the 'jax' interface. " "Ignoring requested mid-circuit measurement configuration.", UserWarning, ) - config.mcm_config["postselect_mode"] = "fill-shots" + config.mcm_config.postselect_mode = "fill-shots" - if any(not tape.shots for tape in tapes) and config.mcm_config["mcm_method"] == "one-shot": + if any(not tape.shots for tape in tapes) and config.mcm_config.mcm_method == "one-shot": warnings.warn( "Cannot use the 'one-shot' method for mid-circuit measurements with " "analytic mode. Using deferred measurements.", UserWarning, ) - config.mcm_config["mcm_method"] = "deferred" + config.mcm_config.mcm_method = "deferred" if transform_program is None: if isinstance(device, qml.devices.Device): diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 37e95d8275f..523138131fc 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -1051,9 +1051,6 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml device=self.device, mcm_config=self.execute_kwargs["mcm_config"], ) - shots = self.device.shots if override_shots is False else override_shots - if shots and mcm_method in ("one-shot", None): - override_shots = 1 elif hasattr(self.device, "capabilities"): full_transform_program.add_transform( qml.defer_measurements, diff --git a/tests/devices/experimental/test_execution_config.py b/tests/devices/experimental/test_execution_config.py index b49cd430d71..9352b582734 100644 --- a/tests/devices/experimental/test_execution_config.py +++ b/tests/devices/experimental/test_execution_config.py @@ -17,7 +17,7 @@ import pytest -from pennylane.devices import ExecutionConfig +from pennylane.devices.execution_config import ExecutionConfig, MCMConfig def test_default_values(): @@ -30,7 +30,14 @@ def test_default_values(): assert config.gradient_keyword_arguments == {} assert config.grad_on_execution is None assert config.use_device_gradient is None - assert config.mcm_config == {"postselect_mode": None, "mcm_method": None} + assert config.mcm_config == MCMConfig() + + +def test_mcm_config_default_values(): + """Test that the default values of MCMConfig are correct""" + mcm_config = MCMConfig() + assert mcm_config.postselect_mode is None + assert mcm_config.mcm_method is None def test_invalid_interface(): @@ -50,3 +57,22 @@ def test_invalid_grad_on_execution(): """Test invalid values for grad on execution raise an error.""" with pytest.raises(ValueError, match=r"grad_on_execution must be True, False,"): ExecutionConfig(grad_on_execution="forward") + + +@pytest.mark.parametrize( + "option", [MCMConfig(mcm_method="deferred"), {"mcm_method": "deferred"}, None] +) +def test_valid_mcm_config(option): + """Test that the mcm_config attribute is set correctly""" + config = ExecutionConfig(mcm_config=option) if option else ExecutionConfig() + if option is None: + assert config.mcm_config == MCMConfig() + else: + assert config.mcm_config == MCMConfig(mcm_method="deferred") + + +def test_invalid_mcm_config(): + """Test that an error is raised if mcm_config is set incorrectly""" + option = "foo" + with pytest.raises(ValueError, match="Got invalid type"): + _ = ExecutionConfig(mcm_config=option) diff --git a/tests/interfaces/test_jacobian_products.py b/tests/interfaces/test_jacobian_products.py index 67f245f8dcb..07f3ccb60ab 100644 --- a/tests/interfaces/test_jacobian_products.py +++ b/tests/interfaces/test_jacobian_products.py @@ -152,7 +152,7 @@ def test_device_jacobians_repr(self): r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}," r" device_options={}, interface=None, derivative_order=1," - r" mcm_config={'postselect_mode': None, 'mcm_method': None})>" + r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>" ) assert repr(jpc) == expected @@ -171,7 +171,7 @@ def test_device_jacobian_products_repr(self): r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}, device_options={}," r" interface=None, derivative_order=1," - r" mcm_config={'postselect_mode': None, 'mcm_method': None})>" + r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>" ) assert repr(jpc) == expected From 73ee6e2137fcc1f3c522e79044533793adbd947d Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 27 May 2024 16:58:28 -0400 Subject: [PATCH 24/48] Fix linting after merge issues --- pennylane/transforms/dynamic_one_shot.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 10ddf66ba5c..0840b0cec4a 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -289,7 +289,7 @@ def measurement_with_no_shots(measurement): # as it assumes all elements of the input are of builtin python types and not belonging # to any particular interface result = qml.math.stack(result, like=interface) - meas = gather_non_mcm(m, result, is_valid) + meas = gather_non_mcm(m, result, is_valid, postselect_shots) m_count += 1 if isinstance(m, SampleMP): meas = qml.math.squeeze(meas) @@ -370,7 +370,7 @@ def gather_mcm(measurement, samples, is_valid, postselect_shots): return counts / qml.math.sum(counts) if isinstance(measurement, CountsMP): mcm_samples = [{"".join(str(int(v)) for v in tuple(s)): 1} for s in mcm_samples] - return gather_non_mcm(measurement, mcm_samples, is_valid) + return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_shots) mcm_samples = qml.math.ravel(qml.math.array(mv.concretize(samples), like=interface)) if isinstance(measurement, ProbabilityMP): counts = [qml.math.sum((mcm_samples == v) * is_valid) for v in list(mv.branches.values())] @@ -378,4 +378,4 @@ def gather_mcm(measurement, samples, is_valid, postselect_shots): return counts / qml.math.sum(counts) if isinstance(measurement, CountsMP): mcm_samples = [{float(s): 1} for s in mcm_samples] - return gather_non_mcm(measurement, mcm_samples, is_valid) + return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_shots) From 9218ec1e9f444f110794c2b5253cc364fff64845 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 27 May 2024 17:06:52 -0400 Subject: [PATCH 25/48] Fixing MCMConfig intialization --- pennylane/workflow/execution.py | 1 + pennylane/workflow/qnode.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 4918f7abff1..77c0f3160e7 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -545,6 +545,7 @@ def cost_fn(params, x): ) gradient_kwargs = gradient_kwargs or {} + mcm_config = mcm_config or {} config = config or _get_execution_config( gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config ) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 523138131fc..9f461a831e6 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -73,7 +73,7 @@ def _make_execution_config( grad_on_execution = False elif grad_on_execution == "best": grad_on_execution = None - mcm_config = getattr(circuit, "execute_kwargs", {}).get("mcm_config", None) + mcm_config = getattr(circuit, "execute_kwargs", {}).get("mcm_config", {}) return qml.devices.ExecutionConfig( interface=getattr(circuit, "interface", None), From 8548502d9802f8a8b183a66390e3721b063c1eeb Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 28 May 2024 16:30:32 -0400 Subject: [PATCH 26/48] Added fill-shots support to dynamic_one_shot; fixing type hints --- pennylane/devices/preprocess.py | 20 ++++----- pennylane/devices/qubit/apply_operation.py | 25 +++++++---- pennylane/devices/qubit/simulate.py | 48 ++++++++-------------- pennylane/transforms/defer_measurements.py | 2 +- pennylane/transforms/dynamic_one_shot.py | 37 +++++------------ pennylane/workflow/execution.py | 8 ---- 6 files changed, 55 insertions(+), 85 deletions(-) diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index 12e51f9880e..f9fcb7715f6 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -82,7 +82,7 @@ def _operator_decomposition_gen( @transform def no_sampling( tape: qml.tape.QuantumTape, name: str = "device" -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Raises an error if the tape has finite shots. Args: @@ -106,7 +106,7 @@ def no_sampling( @transform def validate_device_wires( tape: qml.tape.QuantumTape, wires: Optional[qml.wires.Wires] = None, name: str = "device" -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Validates that all wires present in the tape are in the set of provided wires. Adds the device wires to measurement processes like :class:`~.measurements.StateMP` that are broadcasted across all available wires. @@ -148,7 +148,7 @@ def validate_device_wires( @transform def mid_circuit_measurements( tape: qml.tape.QuantumTape, device, mcm_config=MCMConfig -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Provide the transform to handle mid-circuit measurements. If the tape or device uses finite-shot, use the native implementation (i.e. no transform), @@ -162,18 +162,18 @@ def mid_circuit_measurements( if mcm_method is not None: if mcm_method == "one-shot": - return qml.dynamic_one_shot(tape, postselect_mode=postselect_mode) + return qml.dynamic_one_shot(tape) return qml.defer_measurements(tape, device=device, postselect_mode=postselect_mode) if tape.shots: - return qml.dynamic_one_shot(tape, postselect_mode=postselect_mode) + return qml.dynamic_one_shot(tape) return qml.defer_measurements(tape, device=device, postselect_mode=postselect_mode) @transform def validate_multiprocessing_workers( tape: qml.tape.QuantumTape, max_workers: int, device -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Validates the number of workers for multiprocessing. Checks that the CPU is not oversubscribed and warns user if it is, @@ -232,7 +232,7 @@ def validate_multiprocessing_workers( @transform def validate_adjoint_trainable_params( tape: qml.tape.QuantumTape, -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Raises a warning if any of the observables is trainable, and raises an error if any trainable parameters belong to state-prep operations. Can be used in validating circuits for adjoint differentiation. @@ -268,7 +268,7 @@ def decompose( max_expansion: Union[int, None] = None, name: str = "device", error: Exception = DeviceError, -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Decompose operations until the stopping condition is met. Args: @@ -382,7 +382,7 @@ def validate_observables( tape: qml.tape.QuantumTape, stopping_condition: Callable[[qml.operation.Operator], bool], name: str = "device", -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Validates the observables and measurements for a circuit. Args: @@ -424,7 +424,7 @@ def validate_observables( @transform def validate_measurements( tape: qml.tape.QuantumTape, analytic_measurements=None, sample_measurements=None, name="device" -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Validates the supported state and sample based measurement processes. Args: diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py index ae0c9117c38..80d636393fd 100644 --- a/pennylane/devices/qubit/apply_operation.py +++ b/pennylane/devices/qubit/apply_operation.py @@ -284,6 +284,9 @@ def apply_mid_measure( is_state_batched (bool): Boolean representing whether the state is batched or not debugger (_Debugger): The debugger to use mid_measurements (dict, None): Mid-circuit measurement dictionary mutated to record the sampled value + postselect_mode (str): Configuration for handling shots with mid-circuit measurement + postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to + keep the same number of shots. rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator. prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. @@ -295,6 +298,8 @@ def apply_mid_measure( mid_measurements = execution_kwargs.get("mid_measurements", None) rng = execution_kwargs.get("rng", None) prng_key = execution_kwargs.get("prng_key", None) + postselect_mode = execution_kwargs.get("postselect_mode", None) + if is_state_batched: raise ValueError("MidMeasureMP cannot be applied to batched states.") wire = op.wires @@ -303,16 +308,20 @@ def apply_mid_measure( slices[axis] = 0 prob0 = qml.math.norm(state[tuple(slices)]) ** 2 interface = qml.math.get_deep_interface(state) - if prng_key is not None: - # pylint: disable=import-outside-toplevel - from jax.random import binomial - - def binomial_fn(n, p): - return binomial(prng_key, n, p).astype(int) + if postselect_mode == "fill-shots" and op.postselect is not None: + sample = op.postselect else: - binomial_fn = np.random.binomial if rng is None else rng.binomial - sample = binomial_fn(1, 1 - prob0) + if prng_key is not None: + # pylint: disable=import-outside-toplevel + from jax.random import binomial + + def binomial_fn(n, p): + return binomial(prng_key, n, p).astype(int) + + else: + binomial_fn = np.random.binomial if rng is None else rng.binomial + sample = binomial_fn(1, 1 - prob0) mid_measurements[op] = sample # Using apply_operation(qml.QubitUnitary,...) instead of apply_operation(qml.Projector([sample], wire),...) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 1ea6830602b..1a24c0da5e1 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -140,11 +140,8 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): whether the state has a batch dimension. """ - rng = execution_kwargs.get("rng", None) - prng_key = execution_kwargs.get("prng_key", None) + prng_key = execution_kwargs.pop("prng_key", None) interface = execution_kwargs.get("interface", None) - mid_measurements = execution_kwargs.get("mid_measurements", None) - postselect_mode = execution_kwargs.get("postselect_mode", None) prep = None if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase): @@ -164,20 +161,14 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): state, is_state_batched=is_state_batched, debugger=debugger, - mid_measurements=mid_measurements, - rng=rng, prng_key=key, + **execution_kwargs, ) # Handle postselection on mid-circuit measurements if isinstance(op, qml.Projector): prng_key, key = jax_random_split(prng_key) state, new_shots = _postselection_postprocess( - state, - is_state_batched, - circuit.shots, - rng=rng, - prng_key=key, - postselect_mode=postselect_mode, + state, is_state_batched, circuit.shots, prng_key=key, **execution_kwargs ) circuit._shots = circuit._shots = new_shots @@ -293,11 +284,7 @@ def simulate( tensor([0.68117888, 0. , 0.31882112, 0. ], requires_grad=True)) """ - rng = execution_kwargs.get("rng", None) - prng_key = execution_kwargs.get("prng_key", None) - interface = execution_kwargs.get("interface", None) - postselect_mode = execution_kwargs.get("postselect_mode", None) - + prng_key = execution_kwargs.pop("prng_key", None) circuit = circuit.map_to_standard_wires() has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations) @@ -316,7 +303,7 @@ def simulate( def simulate_partial(k): return simulate_one_shot_native_mcm( - aux_circ, debugger=debugger, rng=rng, prng_key=k, interface=interface + aux_circ, debugger=debugger, prng_key=k, **execution_kwargs ) results = jax.vmap(simulate_partial, in_axes=(0,))(keys) @@ -325,23 +312,20 @@ def simulate_partial(k): for i in range(circuit.shots.total_shots): results.append( simulate_one_shot_native_mcm( - aux_circ, debugger=debugger, rng=rng, prng_key=keys[i], interface=interface + aux_circ, debugger=debugger, prng_key=keys[i], **execution_kwargs ) ) return tuple(results) ops_key, meas_key = jax_random_split(prng_key) state, is_state_batched = get_final_state( - circuit, - debugger=debugger, - rng=rng, - prng_key=ops_key, - interface=interface, - postselect_mode=postselect_mode, + circuit, debugger=debugger, prng_key=ops_key, **execution_kwargs ) if state_cache is not None: state_cache[circuit.hash] = state - return measure_final_state(circuit, state, is_state_batched, rng=rng, prng_key=meas_key) + return measure_final_state( + circuit, state, is_state_batched, prng_key=meas_key, **execution_kwargs + ) def simulate_one_shot_native_mcm( @@ -357,30 +341,30 @@ def simulate_one_shot_native_mcm( the key to the JAX pseudo random number generator. If None, a random key will be generated. Only for simulation using JAX. interface (str): The machine learning interface to create the initial state with + postselect_mode (str): Configuration for handling shots with mid-circuit measurement + postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to + keep the same number of shots. Returns: tuple(TensorLike): The results of the simulation dict: The mid-circuit measurement results of the simulation """ - rng = execution_kwargs.get("rng", None) - prng_key = execution_kwargs.get("prng_key", None) - interface = execution_kwargs.get("interface", None) + prng_key = execution_kwargs.pop("prng_key", None) ops_key, meas_key = jax_random_split(prng_key) mid_measurements = {} state, is_state_batched = get_final_state( circuit, debugger=debugger, - interface=interface, mid_measurements=mid_measurements, - rng=rng, prng_key=ops_key, + **execution_kwargs, ) return measure_final_state( circuit, state, is_state_batched, - rng=rng, prng_key=meas_key, mid_measurements=mid_measurements, + **execution_kwargs, ) diff --git a/pennylane/transforms/defer_measurements.py b/pennylane/transforms/defer_measurements.py index 46253b55388..b1ecaafa49c 100644 --- a/pennylane/transforms/defer_measurements.py +++ b/pennylane/transforms/defer_measurements.py @@ -104,7 +104,7 @@ def null_postprocessing(results): @transform def defer_measurements( tape: QuantumTape, reduce_postselected: bool = True, **kwargs -) -> (Sequence[QuantumTape], Callable): +) -> tuple[Sequence[QuantumTape], Callable]: """Quantum function transform that substitutes operations conditioned on measurement outcomes to controlled operations. diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 0840b0cec4a..14a448eb8cd 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -50,7 +50,7 @@ def null_postprocessing(results): @transform def dynamic_one_shot( tape: qml.tape.QuantumTape, **kwargs -) -> (Sequence[qml.tape.QuantumTape], Callable): +) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Transform a QNode to into several one-shot tapes to support dynamic circuit execution. Args: @@ -99,11 +99,6 @@ def func(x, y): ) _ = kwargs.get("device", None) - postselect_mode = kwargs.get("postselect_mode", None) - if qml.compiler.active() and postselect_mode == "hw-like": - raise ValueError("Cannot discard invalid shots while using qml.qjit") - postselect_shots = postselect_mode in (None, "hw-like") - if not tape.shots: raise qml.QuantumFunctionError("dynamic_one_shot is only supported with finite shots.") @@ -146,7 +141,7 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None): return tuple(final_results) if not tape.shots.has_partitioned_shots: results = results[0] - return parse_native_mid_circuit_measurements(tape, aux_tapes, results, postselect_shots) + return parse_native_mid_circuit_measurements(tape, aux_tapes, results) return aux_tapes, processing_fn @@ -213,10 +208,7 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): def parse_native_mid_circuit_measurements( - circuit: qml.tape.QuantumScript, - aux_tapes: qml.tape.QuantumScript, - results: TensorLike, - postselect_shots: bool, + circuit: qml.tape.QuantumScript, aux_tapes: qml.tape.QuantumScript, results: TensorLike ): """Combines, gathers and normalizes the results of native mid-circuit measurement runs. @@ -224,8 +216,6 @@ def parse_native_mid_circuit_measurements( circuit (QuantumTape): Initial ``QuantumScript`` aux_tapes (List[QuantumTape]): List of auxilary ``QuantumScript`` objects results (TensorLike): Array of measurement results - postselect_shots (bool): Whether or not to discard shots that don't match the - postselection criteria. Returns: tuple(TensorLike): The results of the simulation @@ -278,7 +268,7 @@ def measurement_with_no_shots(measurement): if interface != "jax" and m.mv and not has_valid: meas = measurement_with_no_shots(m) elif m.mv: - meas = gather_mcm(m, mcm_samples, is_valid, postselect_shots) + meas = gather_mcm(m, mcm_samples, is_valid) elif interface != "jax" and not has_valid: meas = measurement_with_no_shots(m) m_count += 1 @@ -289,7 +279,7 @@ def measurement_with_no_shots(measurement): # as it assumes all elements of the input are of builtin python types and not belonging # to any particular interface result = qml.math.stack(result, like=interface) - meas = gather_non_mcm(m, result, is_valid, postselect_shots) + meas = gather_non_mcm(m, result, is_valid) m_count += 1 if isinstance(m, SampleMP): meas = qml.math.squeeze(meas) @@ -298,7 +288,7 @@ def measurement_with_no_shots(measurement): return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0] -def gather_non_mcm(measurement, samples, is_valid, postselect_shots): +def gather_non_mcm(measurement, samples, is_valid): """Combines, gathers and normalizes several measurements with trivial measurement values. Args: @@ -306,8 +296,6 @@ def gather_non_mcm(measurement, samples, is_valid, postselect_shots): samples (TensorLike): measurement samples is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at each index specifies whether or not the respective sample is valid. - postselect_shots (bool): Whether or not to discard shots that don't match the - postselection criteria. Returns: TensorLike: The combined measurement outcome @@ -326,12 +314,11 @@ def gather_non_mcm(measurement, samples, is_valid, postselect_shots): return qml.math.sum(samples * is_valid.reshape((-1, 1)), axis=0) / qml.math.sum(is_valid) if isinstance(measurement, SampleMP): is_interface_jax = qml.math.get_deep_interface(is_valid) == "jax" - postselect_shots = postselect_shots and not is_interface_jax - if not postselect_shots and samples.ndim == 2: + if is_interface_jax and samples.ndim == 2: is_valid = is_valid.reshape((-1, 1)) return ( qml.math.where(is_valid, samples, fill_in_value) - if not postselect_shots + if is_interface_jax else samples[is_valid] ) # VarianceMP @@ -339,7 +326,7 @@ def gather_non_mcm(measurement, samples, is_valid, postselect_shots): return qml.math.sum((samples - expval) ** 2 * is_valid) / qml.math.sum(is_valid) -def gather_mcm(measurement, samples, is_valid, postselect_shots): +def gather_mcm(measurement, samples, is_valid): """Combines, gathers and normalizes several measurements with non-trivial measurement values. Args: @@ -347,8 +334,6 @@ def gather_mcm(measurement, samples, is_valid, postselect_shots): samples (List[dict]): Mid-circuit measurement samples is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at each index specifies whether or not the respective sample is valid. - postselect_shots (bool): Whether or not to discard shots that don't match the - postselection criteria. Returns: TensorLike: The combined measurement outcome @@ -370,7 +355,7 @@ def gather_mcm(measurement, samples, is_valid, postselect_shots): return counts / qml.math.sum(counts) if isinstance(measurement, CountsMP): mcm_samples = [{"".join(str(int(v)) for v in tuple(s)): 1} for s in mcm_samples] - return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_shots) + return gather_non_mcm(measurement, mcm_samples, is_valid) mcm_samples = qml.math.ravel(qml.math.array(mv.concretize(samples), like=interface)) if isinstance(measurement, ProbabilityMP): counts = [qml.math.sum((mcm_samples == v) * is_valid) for v in list(mv.branches.values())] @@ -378,4 +363,4 @@ def gather_mcm(measurement, samples, is_valid, postselect_shots): return counts / qml.math.sum(counts) if isinstance(measurement, CountsMP): mcm_samples = [{float(s): 1} for s in mcm_samples] - return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_shots) + return gather_non_mcm(measurement, mcm_samples, is_valid) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 77c0f3160e7..ffd9e47f87a 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -550,14 +550,6 @@ def cost_fn(params, x): gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config ) - if interface in {"jax", "jax-jit"} and config.mcm_config.postselect_mode == "hw-like": - warnings.warn( - "Cannot discard invalid shots with postselection when using the 'jax' interface. " - "Ignoring requested mid-circuit measurement configuration.", - UserWarning, - ) - config.mcm_config.postselect_mode = "fill-shots" - if any(not tape.shots for tape in tapes) and config.mcm_config.mcm_method == "one-shot": warnings.warn( "Cannot use the 'one-shot' method for mid-circuit measurements with " From 8e6cb550746437efd2333878a16d30ded79bc6cf Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 28 May 2024 17:00:40 -0400 Subject: [PATCH 27/48] Added test for jax with postselect_mode='hw-like' and one-shot --- tests/transforms/test_dynamic_one_shot.py | 52 ++++++++++++----------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py index 32a3b3c3911..18caefff1ff 100644 --- a/tests/transforms/test_dynamic_one_shot.py +++ b/tests/transforms/test_dynamic_one_shot.py @@ -45,7 +45,7 @@ def test_parse_native_mid_circuit_measurements_unsupported_meas(measurement): circuit = qml.tape.QuantumScript([qml.RX(1.0, 0)], [measurement]) with pytest.raises(TypeError, match="Native mid-circuit measurement mode does not support"): - parse_native_mid_circuit_measurements(circuit, [circuit], [[]], None) + parse_native_mid_circuit_measurements(circuit, [circuit], [[]]) def test_postselection_error_with_wrong_device(): @@ -61,28 +61,6 @@ def _(): return qml.probs(wires=[0]) -def test_qjit_postselection_error(monkeypatch): - """Test that an error is raised if qjit is active with `postselect_mode="hw-like"`""" - # TODO: Update test once defer_measurements can be used with qjit - # catalyst = pytest.importorskip("catalyst") - # dev = qml.device("lightning.qubit", wires=3, shots=10) - dev = qml.device("default.qubit", wires=3, shots=10) - - # @qml.qjit - @qml.qnode(dev, postselect_mode="hw-like", mcm_method="one-shot") - def func(x): - qml.RX(x, 0) - _ = qml.measure(0, postselect=0) - # _ = catalyst.measure(0, postselect=0) - return qml.sample(wires=[0, 1]) - - # Mocking qml.compiler.active() to always return True - with monkeypatch.context() as m: - m.setattr(qml.compiler, "active", lambda: True) - with pytest.raises(ValueError, match="Cannot discard invalid shots while using qml.qjit"): - _ = func(1.8) - - @pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) def test_postselect_mode(postselect_mode, mocker): """Test that invalid shots are discarded if requested""" @@ -103,7 +81,33 @@ def f(x): assert len(res) < shots else: assert len(res) == shots - assert np.any(res == np.iinfo(np.int32).min) + assert np.all(res != np.iinfo(np.int32).min) + + +@pytest.mark.jax +@pytest.mark.parametrize("use_jit", [True, False]) +@pytest.mark.parametrize("diff_method", [None, "best"]) +def test_hw_like_with_jax(use_jit, diff_method): + """Test that invalid shots are replaced with INTEGER_MIN_VAL if + postselect_mode="hw-like" with JAX""" + import jax # pylint: disable=import-outside-toplevel + + shots = 10 + dev = qml.device("default.qubit", shots=shots, seed=jax.random.PRNGKey(123)) + + @qml.qnode(dev, postselect_mode="hw-like", diff_method=diff_method) + def f(x): + qml.RX(x, 0) + _ = qml.measure(0, postselect=1) + return qml.sample(wires=[0, 1]) + + if use_jit: + f = jax.jit(f) + + res = f(jax.numpy.array(np.pi / 2)) + + assert len(res) == shots + assert np.any(res == np.iinfo(np.int32).min) def test_unsupported_measurements(): From 3056cb8cc803b6a7708cb6fa20faf7ad6827ac98 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 28 May 2024 17:19:44 -0400 Subject: [PATCH 28/48] Updated docs --- doc/introduction/measurements.rst | 79 ++++++++++++++++++------------- 1 file changed, 45 insertions(+), 34 deletions(-) diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index d03e413fa4e..4942845e67c 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -538,46 +538,57 @@ PennyLane. For ease of use, we provide the following configuration options to us .. warning:: If the ``mcm_method`` argument is provided, the :func:`~pennylane.defer_measurements` or - :func:`~pennylane.dynamic_one_shot` transforms should not be applied directly to a :class:`~pennylane.QNode` + :func:`~pennylane.dynamic_one_shot` transforms must not be applied directly to a :class:`~pennylane.QNode` as it can lead to incorrect behaviour. * ``postselect_mode``: To configure how invalid shots are handled when postselecting mid-circuit measurements - with finite-shot circuits. Use ``postselect_mode="hw-like"`` to discard invalid shots. In this case, the number - of samples that are used for processing results will be less than or equal to the total number of shots. Use - ``postselect_mode="fill-shots"`` to keep invalid shots. + with finite-shot circuits. Use ``postselect_mode="hw-like"`` to discard invalid samples. In this case, the number + of samples that are used for processing results will be less than or equal to the total number of shots. If + ``postselect_mode="fill-shots"`` is used, then the postselected value will be picked unconditionally, and all + samples will be considered valid. + + .. code-block:: python3 + + import pennylane as qml + import numpy as np + + dev = qml.device("default.qubit", wires=3, shots=10) + + @qml.qnode(dev, mcm_method="one-shot", postselect_mode="fill-shots") + def circuit(x): + qml.RX(x, 0) + m0 = qml.measure(0, postselect=1) + qml.CNOT([0, 1]) + return qml.sample(qml.PauliZ(0)) + + >>> circuit(np.pi / 2) + array([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.]) .. note:: - If ``postselect_mode="fill-shots"``, the specified ``mcm_method`` will impact the results due to the particular - features/limitations of the requested ``mcm_method``. - - * If using ``mcm_method="hw-like"``, invalid samples will be replaced with ``np.iinfo(np.int32).min``, - and these invalid values will not be used for processing final results. - * If using ``mcm_method="deferred"``, all shots will be projected to the postselected value, so all - shots will be considered valid. - - .. note:: - - When using the ``jax`` interface or while using :func:`~pennylane.qjit`, the results will reflect - ``postselect_mode="fill-shots"`` regardless of the specified value. - -.. code-block:: python3 - - import pennylane as qml - import numpy as np - - dev = qml.device("default.qubit", wires=3, shots=10) - - @qml.qnode(dev, mcm_method="one-shot", postselect_mode="fill-shots") - def circuit(x): - qml.RX(x, 0) - m0 = qml.measure(0, postselect=1) - qml.CNOT([0, 1]) - return qml.expval(qml.PauliZ(0)) - ->>> circuit(np.pi / 2) -array([-2147483648, 1, -2147483648, -2147483648, 1, - 1, 1, 1, 1, 1]) + If ``postselect_mode="hw-like"``, invalid shots will not be discarded when using the ``jax`` interface. + Instead, invalid samples will be replaced by ``np.iinfo(np.int32).min``. These invalid samples will not be + used for processing final results. Consider the circuit below: + + .. code-block:: python3 + + import pennylane as qml + import jax + import jax.numpy as jnp + + dev = qml.device("default.qubit", wires=3, shots=10, seed=jax.random.PRNGKey(123)) + + @qml.qnode(dev, postselect_mode="hw-like", mcm_method="one-shot") + def circuit(x): + qml.RX(x, 0) + qml.measure(0, postselect=1) + return qml.sample(qml.PauliZ(0)) + + >>> x = jnp.array(1.8) + >>> f(x) + Array([-2.1474836e+09, -1.0000000e+00, -2.1474836e+09, -2.1474836e+09, + -1.0000000e+00, -2.1474836e+09, -1.0000000e+00, -2.1474836e+09, + -1.0000000e+00, -1.0000000e+00], dtype=float32, weak_type=True) Changing the number of shots ---------------------------- From 6069b1041fa08db90c2b3da2b5d155f1be7db357 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 28 May 2024 17:22:52 -0400 Subject: [PATCH 29/48] Add info about defaults to docs --- doc/introduction/measurements.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index 4942845e67c..b462acd88bd 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -533,7 +533,8 @@ PennyLane. For ease of use, we provide the following configuration options to us * ``mcm_method``: To set the method used for applying mid-circuit measurements. Use ``mcm_method="deferred"`` to use the deferred measurements principle or ``mcm_method="one-shot"`` to use the one-shot transform as - described above. + described above. When executing with finite shots, ``mcm_method="one-shot"`` will be the default, and + ``mcm_method="deferred"`` otherwise. .. warning:: @@ -545,7 +546,7 @@ PennyLane. For ease of use, we provide the following configuration options to us with finite-shot circuits. Use ``postselect_mode="hw-like"`` to discard invalid samples. In this case, the number of samples that are used for processing results will be less than or equal to the total number of shots. If ``postselect_mode="fill-shots"`` is used, then the postselected value will be picked unconditionally, and all - samples will be considered valid. + samples will be considered valid. The default behaviour is ``postselect_mode="hw-like"``. .. code-block:: python3 From a4f40e51641ceb1a934839f86090cd9283c5c9ac Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 28 May 2024 17:28:18 -0400 Subject: [PATCH 30/48] Removed failing test --- tests/test_qnode.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 658733b3381..3b987c3cce3 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -1701,34 +1701,6 @@ def circuit(): class TestMCMConfiguration: """Tests for MCM configuration arguments""" - @pytest.mark.jax - @pytest.mark.parametrize("use_jit", [True, False]) - @pytest.mark.parametrize("interface", ["jax", "auto"]) - def test_jax_warning_with_postselect_mode_hw_like(self, use_jit, interface): - """Test that a warning is raised when postselect_mode="hw-like" with jax""" - import jax # pylint: disable=import-outside-toplevel - - shots = 100 - dev = qml.device("default.qubit", wires=3, shots=shots) - - @qml.qnode(dev, postselect_mode="hw-like", interface=interface) - def f(x): - qml.RX(x, 0) - _ = qml.measure(0, postselect=1) - return qml.sample(wires=[0, 1]) - - if use_jit: - f = jax.jit(f) - param = jax.numpy.array(np.pi / 4) - - with pytest.warns( - UserWarning, - match="Cannot discard invalid shots with postselection when using the 'jax' interface", - ): - res = f(param) - - assert len(res) == shots - @pytest.mark.parametrize("dev_name", ["default.qubit", "default.qubit.legacy"]) def test_one_shot_warning_without_shots(self, dev_name, mocker): """Test that a warning is raised if mcm_method="one-shot" with no shots""" From c6d2a4c9a2d89dce8e9e895348d1a2c7be6d2ef6 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 29 May 2024 14:56:47 -0400 Subject: [PATCH 31/48] Updated docs --- doc/introduction/measurements.rst | 53 +++++++++++++++++-------------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index b462acd88bd..808eaa49519 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -567,29 +567,36 @@ PennyLane. For ease of use, we provide the following configuration options to us .. note:: - If ``postselect_mode="hw-like"``, invalid shots will not be discarded when using the ``jax`` interface. - Instead, invalid samples will be replaced by ``np.iinfo(np.int32).min``. These invalid samples will not be - used for processing final results. Consider the circuit below: - - .. code-block:: python3 - - import pennylane as qml - import jax - import jax.numpy as jnp - - dev = qml.device("default.qubit", wires=3, shots=10, seed=jax.random.PRNGKey(123)) - - @qml.qnode(dev, postselect_mode="hw-like", mcm_method="one-shot") - def circuit(x): - qml.RX(x, 0) - qml.measure(0, postselect=1) - return qml.sample(qml.PauliZ(0)) - - >>> x = jnp.array(1.8) - >>> f(x) - Array([-2.1474836e+09, -1.0000000e+00, -2.1474836e+09, -2.1474836e+09, - -1.0000000e+00, -2.1474836e+09, -1.0000000e+00, -2.1474836e+09, - -1.0000000e+00, -1.0000000e+00], dtype=float32, weak_type=True) + When using the ``jax`` interface, ``postselect_mode="hw-like"`` will have different behaviour based on the + chosen ``mcm_method``. + + * If ``mcm_method="one-shot"``, invalid shots will not be discarded. Instead, invalid samples will be replaced + by ``np.iinfo(np.int32).min``. These invalid samples will not be used for processing final results. Consider + the circuit below: + + .. code-block:: python3 + + import pennylane as qml + import jax + import jax.numpy as jnp + + dev = qml.device("default.qubit", wires=3, shots=10, seed=jax.random.PRNGKey(123)) + + @qml.qnode(dev, postselect_mode="hw-like", mcm_method="one-shot") + def circuit(x): + qml.RX(x, 0) + qml.measure(0, postselect=1) + return qml.sample(qml.PauliZ(0)) + + >>> x = jnp.array(1.8) + >>> f(x) + Array([-2.1474836e+09, -1.0000000e+00, -2.1474836e+09, -2.1474836e+09, + -1.0000000e+00, -2.1474836e+09, -1.0000000e+00, -2.1474836e+09, + -1.0000000e+00, -1.0000000e+00], dtype=float32, weak_type=True) + + * If ``mcm_method="deferred"``, then using ``postselect_mode="hw-like"`` will have the same behaviour as when + ``postselect_mode="fill-shots"``. This is due to the limitations of the :func:`~pennylane.defer_measurements` + transform, and this behaviour will change in the future to be more consistent with ``mcm_method="one-shot"``. Changing the number of shots ---------------------------- From d8d60bb731aec94afb1e8fdf0935368737eb8dd0 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 29 May 2024 15:05:47 -0400 Subject: [PATCH 32/48] Update pennylane/devices/execution_config.py Co-authored-by: Christina Lee --- pennylane/devices/execution_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index c6c8b6bb6f3..4d8c4f4586a 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -83,7 +83,7 @@ class ExecutionConfig: derivative_order: int = 1 """The derivative order to compute while evaluating a gradient""" - mcm_config: Optional[MCMConfig] = MCMConfig() + mcm_config: MCMConfig = MCMConfig() """Configuration options for handling mid-circuit measurements""" def __post_init__(self): From 70bd23d58eeee2d311935883053e7c878081cc05 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 29 May 2024 15:01:26 -0400 Subject: [PATCH 33/48] Remove qjit checks --- pennylane/devices/preprocess.py | 7 +++---- pennylane/transforms/defer_measurements.py | 3 --- tests/transforms/test_defer_measurements.py | 22 --------------------- 3 files changed, 3 insertions(+), 29 deletions(-) diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index f9fcb7715f6..75277eaa76a 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -147,7 +147,7 @@ def validate_device_wires( @transform def mid_circuit_measurements( - tape: qml.tape.QuantumTape, device, mcm_config=MCMConfig + tape: qml.tape.QuantumTape, device, mcm_config=MCMConfig() ) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Provide the transform to handle mid-circuit measurements. @@ -157,17 +157,16 @@ def mid_circuit_measurements( if isinstance(mcm_config, dict): mcm_config = MCMConfig(**mcm_config) - postselect_mode = mcm_config.postselect_mode mcm_method = mcm_config.mcm_method if mcm_method is not None: if mcm_method == "one-shot": return qml.dynamic_one_shot(tape) - return qml.defer_measurements(tape, device=device, postselect_mode=postselect_mode) + return qml.defer_measurements(tape, device=device) if tape.shots: return qml.dynamic_one_shot(tape) - return qml.defer_measurements(tape, device=device, postselect_mode=postselect_mode) + return qml.defer_measurements(tape, device=device) @transform diff --git a/pennylane/transforms/defer_measurements.py b/pennylane/transforms/defer_measurements.py index b1ecaafa49c..69de5f88a2c 100644 --- a/pennylane/transforms/defer_measurements.py +++ b/pennylane/transforms/defer_measurements.py @@ -272,9 +272,6 @@ def node(x): if not any(isinstance(o, MidMeasureMP) for o in tape.operations): return (tape,), null_postprocessing - if qml.compiler.active() and kwargs.get("postselect_mode", None) == "hw-like": - raise ValueError("Cannot discard invalid shots while using qml.qjit") - _check_tape_validity(tape) device = kwargs.get("device", None) diff --git a/tests/transforms/test_defer_measurements.py b/tests/transforms/test_defer_measurements.py index 0356727dd77..85f0bc43202 100644 --- a/tests/transforms/test_defer_measurements.py +++ b/tests/transforms/test_defer_measurements.py @@ -105,28 +105,6 @@ def circ(): _ = circ() -def test_qjit_postselection_error(monkeypatch): - """Test that an error is raised if qjit is active with `postselect=True`""" - # TODO: Update test once defer_measurements can be used with qjit - # catalyst = pytest.importorskip("catalyst") - # dev = qml.device("lightning.qubit", wires=3, shots=10) - dev = qml.device("default.qubit", wires=3, shots=10) - - # @qml.qjit - @qml.qnode(dev, postselect_mode="hw-like", mcm_method="deferred") - def func(x): - qml.RX(x, 0) - _ = qml.measure(0, postselect=0) - # _ = catalyst.measure(0, postselect=0) - return qml.sample(wires=[0, 1]) - - # Mocking qml.compiler.active() to always return True - with monkeypatch.context() as m: - m.setattr(qml.compiler, "active", lambda: True) - with pytest.raises(ValueError, match="Cannot discard invalid shots while using qml.qjit"): - _ = func(1.8) - - @pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) def test_postselect_mode(postselect_mode, mocker): """Test that invalid shots are discarded if requested""" From c19ddddfd5b5847d3f2c9d744495504a02bda843 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 29 May 2024 15:08:46 -0400 Subject: [PATCH 34/48] Fixing linting error --- pennylane/devices/execution_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index 4d8c4f4586a..f26660c468c 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -15,7 +15,7 @@ Contains the :class:`ExecutionConfig` data class. """ from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union from pennylane.workflow import SUPPORTED_INTERFACES @@ -83,7 +83,7 @@ class ExecutionConfig: derivative_order: int = 1 """The derivative order to compute while evaluating a gradient""" - mcm_config: MCMConfig = MCMConfig() + mcm_config: MCMConfig = Union[MCMConfig(), dict] """Configuration options for handling mid-circuit measurements""" def __post_init__(self): From 9dde1ee7cae91152d7075571a3b697a1fd4533d4 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 29 May 2024 17:01:44 -0400 Subject: [PATCH 35/48] Fix execution config; add docs per code review --- doc/introduction/measurements.rst | 19 ++++++++++++------- doc/releases/changelog-dev.md | 3 +-- pennylane/devices/execution_config.py | 2 +- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index 808eaa49519..110cd3805bf 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -549,21 +549,25 @@ PennyLane. For ease of use, we provide the following configuration options to us samples will be considered valid. The default behaviour is ``postselect_mode="hw-like"``. .. code-block:: python3 - + import pennylane as qml import numpy as np - + dev = qml.device("default.qubit", wires=3, shots=10) - - @qml.qnode(dev, mcm_method="one-shot", postselect_mode="fill-shots") + def circuit(x): qml.RX(x, 0) m0 = qml.measure(0, postselect=1) qml.CNOT([0, 1]) return qml.sample(qml.PauliZ(0)) - - >>> circuit(np.pi / 2) + + fill_shots_qnode = qml.QNode(circuit, dev, mcm_method="one-shot", postselect_mode="fill-shots") + hw_like_qnode = qml.QNode(circuit, dev, mcm_method="one-shot", postselect_mode="hw-like") + + >>> fill_shots_qnode(np.pi / 2) array([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.]) + >>> hw_like_qnode(np.pi / 2) + array([-1., -1., -1., -1., -1., -1., -1.]) .. note:: @@ -571,7 +575,8 @@ PennyLane. For ease of use, we provide the following configuration options to us chosen ``mcm_method``. * If ``mcm_method="one-shot"``, invalid shots will not be discarded. Instead, invalid samples will be replaced - by ``np.iinfo(np.int32).min``. These invalid samples will not be used for processing final results. Consider + by ``np.iinfo(np.int32).min``. These invalid samples will not be used for processing final results (like + expectation values), but will appear in the ``QNode`` output if samples are requested directly. Consider the circuit below: .. code-block:: python3 diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 10ccfd8a267..ebd6af7aedc 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -10,8 +10,7 @@ [(#5679)](https://github.com/PennyLaneAI/pennylane/pull/5679) * `postselect_mode="hw-like"` will indicate to devices to discard invalid shots when postselecting - mid-circuit measurements. Use `postselect_mode="fill-shots"` to return invalid shots, which will be replaced - by dummy values. + mid-circuit measurements. Use `postselect_mode="fill-shots"` to treat all shots as valid. * `mcm_method` will indicate which strategy to use for running circuits with mid-circuit measurements. Use `mcm_method="deferred"` to use the deferred measurements principle, or `mcm_method="one-shot"` to execute once for each shot. diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index f26660c468c..4dab2a83cfb 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -83,7 +83,7 @@ class ExecutionConfig: derivative_order: int = 1 """The derivative order to compute while evaluating a gradient""" - mcm_config: MCMConfig = Union[MCMConfig(), dict] + mcm_config: Union[MCMConfig, dict] = MCMConfig() """Configuration options for handling mid-circuit measurements""" def __post_init__(self): From 88d98342737f460d06610f14fe9d8e7b8df5431e Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 30 May 2024 14:56:15 -0400 Subject: [PATCH 36/48] Addressing code review; replacing some warnings with errors --- doc/introduction/measurements.rst | 9 +++-- doc/releases/changelog-dev.md | 6 ++- pennylane/devices/__init__.py | 3 +- pennylane/devices/execution_config.py | 5 ++- pennylane/devices/preprocess.py | 9 ++--- pennylane/devices/qubit/apply_operation.py | 33 ++++++++++++---- pennylane/devices/qubit/simulate.py | 17 ++++---- pennylane/workflow/execution.py | 8 ---- pennylane/workflow/qnode.py | 43 ++++++++++++--------- tests/devices/qubit/test_simulate.py | 2 +- tests/devices/test_preprocess.py | 2 +- tests/test_qnode.py | 22 ++++++++--- tests/transforms/test_defer_measurements.py | 4 +- tests/transforms/test_dynamic_one_shot.py | 2 +- 14 files changed, 96 insertions(+), 69 deletions(-) diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index 110cd3805bf..872f71419ad 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -539,14 +539,15 @@ PennyLane. For ease of use, we provide the following configuration options to us .. warning:: If the ``mcm_method`` argument is provided, the :func:`~pennylane.defer_measurements` or - :func:`~pennylane.dynamic_one_shot` transforms must not be applied directly to a :class:`~pennylane.QNode` + :func:`~pennylane.dynamic_one_shot` transforms must not be applied directly to the :class:`~pennylane.QNode` as it can lead to incorrect behaviour. * ``postselect_mode``: To configure how invalid shots are handled when postselecting mid-circuit measurements with finite-shot circuits. Use ``postselect_mode="hw-like"`` to discard invalid samples. In this case, the number - of samples that are used for processing results will be less than or equal to the total number of shots. If - ``postselect_mode="fill-shots"`` is used, then the postselected value will be picked unconditionally, and all - samples will be considered valid. The default behaviour is ``postselect_mode="hw-like"``. + of samples that are used for processing results can be less than the total number of shots. If + ``postselect_mode="fill-shots"`` is used, then the postselected value will be sampled unconditionally, and all + samples will be valid. This is equivalent to sampling until the number of valid samples matches the total number + of shots. The default behaviour is ``postselect_mode="hw-like"``. .. code-block:: python3 diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index ebd6af7aedc..a091ae173a6 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -4,13 +4,15 @@

New features since last release

-* `qml.QNode` and `qml.qnode` now accept two new keyword arguments: `postselect_mode` and `mcm_method`. +* `qml.QNode` and `qml.qnode` now accept two new keyword arguments: `postselect_mode` and `mcm_method`. These keyword arguments can be used to configure how the device should behave when running circuits with mid-circuit measurements. [(#5679)](https://github.com/PennyLaneAI/pennylane/pull/5679) * `postselect_mode="hw-like"` will indicate to devices to discard invalid shots when postselecting - mid-circuit measurements. Use `postselect_mode="fill-shots"` to treat all shots as valid. + mid-circuit measurements. Use `postselect_mode="fill-shots"` to unconditionally sample the postselected + value, thus making all samples valid. This is equivalent to sampling until the number of valid samples + matches the total number of shots. * `mcm_method` will indicate which strategy to use for running circuits with mid-circuit measurements. Use `mcm_method="deferred"` to use the deferred measurements principle, or `mcm_method="one-shot"` to execute once for each shot. diff --git a/pennylane/devices/__init__.py b/pennylane/devices/__init__.py index 0ef6d6dd7b5..137619584ae 100644 --- a/pennylane/devices/__init__.py +++ b/pennylane/devices/__init__.py @@ -53,6 +53,7 @@ :toctree: api ExecutionConfig + MCMConfig Device DefaultQubit NullQubit @@ -145,7 +146,7 @@ def execute(self, circuits, execution_config = qml.devices.DefaultExecutionConfi """ -from .execution_config import ExecutionConfig, DefaultExecutionConfig +from .execution_config import ExecutionConfig, DefaultExecutionConfig, MCMConfig from .device_api import Device from .default_qubit import DefaultQubit diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index 4dab2a83cfb..70db91e8b88 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -27,9 +27,10 @@ class MCMConfig: mcm_method: Optional[str] = None """Which mid-circuit measurement strategy to use. Use ``deferred`` for the deferred measurements principle and "one-shot" if using finite shots to execute the circuit - for each shot separately.""" + for each shot separately. If not specified, the device will decide which method to + use.""" - postselect_mode: Optional[str] = None + postselect_mode: str = "hw-like" """Configuration for handling shots with mid-circuit measurement postselection. If ``"hw-like"``, invalid shots will be discarded and only results for valid shots will be returned. If ``"fill-shots"``, results corresponding to the original number of diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index 75277eaa76a..23fed7614e6 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -158,13 +158,10 @@ def mid_circuit_measurements( if isinstance(mcm_config, dict): mcm_config = MCMConfig(**mcm_config) mcm_method = mcm_config.mcm_method + if mcm_method is None: + mcm_method = "one-shot" if tape.shots else "deferred" - if mcm_method is not None: - if mcm_method == "one-shot": - return qml.dynamic_one_shot(tape) - return qml.defer_measurements(tape, device=device) - - if tape.shots: + if mcm_method == "one-shot": return qml.dynamic_one_shot(tape) return qml.defer_measurements(tape, device=device) diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py index 80d636393fd..dff2e530fb6 100644 --- a/pennylane/devices/qubit/apply_operation.py +++ b/pennylane/devices/qubit/apply_operation.py @@ -165,7 +165,18 @@ def apply_operation( is_state_batched (bool): Boolean representing whether the state is batched or not debugger (_Debugger): The debugger to use **execution_kwargs (Optional[dict]): Optional keyword arguments needed for applying - some operations + some operations described below. + + Keyword Arguments: + mid_measurements (dict, None): Mid-circuit measurement dictionary mutated to record the sampled value + interface (str): The machine learning interface of the state + postselect_mode (str): Configuration for handling shots with mid-circuit measurement + postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to + keep the same number of shots. ``"hw-like"`` by default. + rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator. + prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is + the key to the JAX pseudo random number generator. Only for simulation using JAX. + If None, a ``numpy.random.default_rng`` will be used for sampling. Returns: ndarray: output state @@ -233,6 +244,11 @@ def apply_conditional( is_state_batched (bool): Boolean representing whether the state is batched or not debugger (_Debugger): The debugger to use mid_measurements (dict, None): Mid-circuit measurement dictionary mutated to record the sampled value + interface (str): The machine learning interface of the state + rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator. + prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is + the key to the JAX pseudo random number generator. Only for simulation using JAX. + If None, a ``numpy.random.default_rng`` will be used for sampling. Returns: ndarray: output state @@ -286,11 +302,11 @@ def apply_mid_measure( mid_measurements (dict, None): Mid-circuit measurement dictionary mutated to record the sampled value postselect_mode (str): Configuration for handling shots with mid-circuit measurement postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to - keep the same number of shots. + keep the same number of shots. ``"hw-like"`` by default. rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator. prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. - If None, a ``numpy.random.default_rng`` will be for sampling. + If None, a ``numpy.random.default_rng`` will be used for sampling. Returns: ndarray: output state @@ -298,20 +314,21 @@ def apply_mid_measure( mid_measurements = execution_kwargs.get("mid_measurements", None) rng = execution_kwargs.get("rng", None) prng_key = execution_kwargs.get("prng_key", None) - postselect_mode = execution_kwargs.get("postselect_mode", None) + postselect_mode = execution_kwargs.get("postselect_mode", "hw-like") if is_state_batched: raise ValueError("MidMeasureMP cannot be applied to batched states.") wire = op.wires - axis = wire.toarray()[0] - slices = [slice(None)] * qml.math.ndim(state) - slices[axis] = 0 - prob0 = qml.math.norm(state[tuple(slices)]) ** 2 interface = qml.math.get_deep_interface(state) if postselect_mode == "fill-shots" and op.postselect is not None: sample = op.postselect else: + axis = wire.toarray()[0] + slices = [slice(None)] * qml.math.ndim(state) + slices[axis] = 0 + prob0 = qml.math.norm(state[tuple(slices)]) ** 2 + if prng_key is not None: # pylint: disable=import-outside-toplevel from jax.random import binomial diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 1a24c0da5e1..657c5d7a497 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -52,8 +52,6 @@ class _FlexShots(qml.measurements.Shots): """Shots class that allows zero shots.""" - _frozen = False - # pylint: disable=super-init-not-called def __init__(self, shots=None): if isinstance(shots, int): @@ -79,7 +77,6 @@ def _postselection_postprocess(state, is_state_batched, shots, **execution_kwarg rng = execution_kwargs.get("rng", None) prng_key = execution_kwargs.get("prng_key", None) - postselect_mode = execution_kwargs.get("postselect_mode", None) # The floor function is being used here so that a norm very close to zero becomes exactly # equal to zero so that the state can become invalid. This way, execution can continue, and @@ -103,7 +100,7 @@ def _postselection_postprocess(state, is_state_batched, shots, **execution_kwarg postselected_shots = ( [int(binomial_fn(s, float(norm**2))) for s in shots] - if postselect_mode in (None, "hw-like") and not qml.math.is_abstract(norm) + if not qml.math.is_abstract(norm) else shots ) @@ -130,10 +127,10 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator. prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. - If None, a ``numpy.random.default_rng`` will be for sampling. + If None, a ``numpy.random.default_rng`` will be used for sampling. postselect_mode (str): Configuration for handling shots with mid-circuit measurement postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to - keep the same number of shots. + keep the same number of shots. Default is ``"hw-like"``. Returns: Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and @@ -170,7 +167,7 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): state, new_shots = _postselection_postprocess( state, is_state_batched, circuit.shots, prng_key=key, **execution_kwargs ) - circuit._shots = circuit._shots = new_shots + circuit._shots = new_shots # new state is batched if i) the old state is batched, or ii) the new op adds a batch dim is_state_batched = is_state_batched or (op.batch_size is not None) @@ -201,7 +198,7 @@ def measure_final_state(circuit, state, is_state_batched, **execution_kwargs) -> prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. If None, the default ``sample_state`` function and a ``numpy.random.default_rng`` - will be for sampling. + will be used for sampling. mid_measurements (None, dict): Dictionary of mid-circuit measurements Returns: @@ -269,7 +266,7 @@ def simulate( interface (str): The machine learning interface to create the initial state with postselect_mode (str): Configuration for handling shots with mid-circuit measurement postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to - keep the same number of shots. + keep the same number of shots. Default is ``"hw-like"``. Returns: tuple(TensorLike): The results of the simulation @@ -343,7 +340,7 @@ def simulate_one_shot_native_mcm( interface (str): The machine learning interface to create the initial state with postselect_mode (str): Configuration for handling shots with mid-circuit measurement postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to - keep the same number of shots. + keep the same number of shots. Default is ``"hw-like"``. Returns: tuple(TensorLike): The results of the simulation diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index ffd9e47f87a..8a3e5cae0d5 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -550,14 +550,6 @@ def cost_fn(params, x): gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config ) - if any(not tape.shots for tape in tapes) and config.mcm_config.mcm_method == "one-shot": - warnings.warn( - "Cannot use the 'one-shot' method for mid-circuit measurements with " - "analytic mode. Using deferred measurements.", - UserWarning, - ) - config.mcm_config.mcm_method = "deferred" - if transform_program is None: if isinstance(device, qml.devices.Device): transform_program = device.preprocess(config)[0] diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 9f461a831e6..ef3a32dbfdd 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -68,20 +68,19 @@ def _make_execution_config( _gradient_method = diff_method else: _gradient_method = "gradient-transform" - grad_on_execution = getattr(circuit, "execute_kwargs", {}).get("grad_on_execution") + execute_kwargs = getattr(circuit, "execute_kwargs", {}) + grad_on_execution = execute_kwargs.get("grad_on_execution") if getattr(circuit, "interface", "") == "jax": grad_on_execution = False elif grad_on_execution == "best": grad_on_execution = None - mcm_config = getattr(circuit, "execute_kwargs", {}).get("mcm_config", {}) + mcm_config = execute_kwargs.get("mcm_config", {}) return qml.devices.ExecutionConfig( interface=getattr(circuit, "interface", None), gradient_method=_gradient_method, grad_on_execution=grad_on_execution, - use_device_jacobian_product=getattr(circuit, "execute_kwargs", {"device_vjp": False})[ - "device_vjp" - ], + use_device_jacobian_product=execute_kwargs.get("device_vjp", False), mcm_config=mcm_config, ) @@ -219,14 +218,15 @@ class QNode: (classical) computational overhead during the backwards pass. device_vjp (bool): Whether or not to use the device-provided Vector Jacobian Product (VJP). A value of ``None`` indicates to use it if the device provides it, but use the full jacobian otherwise. - postselect_mode (str): Configuration for handling shots with mid-circuit measurement - postselection. If ``"hw-like"``, invalid shots will be discarded and only results for valid shots will be returned. If - ``"fill-shots"``, results corresponding to the original number of shots will be returned. For usage details, - please refer to the :doc:`main measurements page `. + postselect_mode (str): Configuration for handling shots with mid-circuit measurement postselection. If + ``"hw-like"``, invalid shots will be discarded and only results for valid shots will be returned. + If ``"fill-shots"``, results corresponding to the original number of shots will be returned. The + default is ``"hw-like"``. For usage details, please refer to the + :doc:`main measurements page `. mcm_method (str): Strategy to use when executing circuits with mid-circuit measurements. Use ``"deferred"`` - to execute using the deferred measurements principle (applied using the - :func:`~pennylane.defer_measurements` transform), or ``"one-shot"`` if using finite shots to execute the - circuit for each shot separately. For usage details, please refer to the + to apply the deferred measurements principle (using the :func:`~pennylane.defer_measurements` transform), + or ``"one-shot"`` if using finite shots to execute the circuit for each shot separately. If not provided, + the device will determine the best choice automatically. For usage details, please refer to the :doc:`main measurements page `. Keyword Args: @@ -451,7 +451,7 @@ def __init__( cachesize=10000, max_diff=1, device_vjp=False, - postselect_mode=None, + postselect_mode="hw-like", mcm_method=None, **gradient_kwargs, ): @@ -519,6 +519,12 @@ def __init__( self.max_expansion = max_expansion cache = (max_diff > 1) if cache == "auto" else cache + postselect_mode = postselect_mode or "hw-like" + if mcm_method not in ("deferred", "one-shot", None): + raise ValueError(f"Invalid mid-circuit measurements method '{mcm_method}'.") + if postselect_mode not in ("hw-like", "fill-shots"): + raise ValueError(f"Invalid postselection mode '{postselect_mode}'.") + # execution keyword arguments self.execute_kwargs = { "grad_on_execution": grad_on_execution, @@ -1022,14 +1028,14 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml ) self._tape_cached = using_custom_cache and self.tape.hash in cache - mcm_method = self.execute_kwargs["mcm_config"].get("mcm_method", None) - if mcm_method not in ("deferred", "one-shot", None): + finite_shots = _get_device_shots if override_shots is False else override_shots + if not finite_shots and self.execute_kwargs["mcm_config"]["mcm_method"] == "one-shot": warnings.warn( - f"Invalid mid-circuit measurements method '{mcm_method}'. Automatically " - "detecting optimal method.", + "Cannot use the 'one-shot' method for mid-circuit measurements with " + "analytic mode. Using deferred measurements.", UserWarning, ) - mcm_method = self.execute_kwargs["mcm_config"]["mcm_method"] = None + self.execute_kwargs["mcm_config"]["mcm_method"] = "deferred" # Add the device program to the QNode program if isinstance(self.device, qml.devices.Device): @@ -1054,7 +1060,6 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml elif hasattr(self.device, "capabilities"): full_transform_program.add_transform( qml.defer_measurements, - postselect_mode=self.execute_kwargs["mcm_config"]["postselect_mode"], device=self.device, ) diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index 861d938000a..634c8fed2e5 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -428,7 +428,7 @@ def test_init_with_zero_shots(self, shots, expected_shot_vector): assert shot_vector == expected_shot_vector def test_init_with_other_shots(self): - """Test that a new _FlexShots object is not created if the input is a Shots object.""" + """Test that a new _FlexShots object is not created if the input is a _FlexShots object.""" shots = _FlexShots(10) new_shots = _FlexShots(shots) assert new_shots is shots diff --git a/tests/devices/test_preprocess.py b/tests/devices/test_preprocess.py index b0ceecb5839..f6f6aab592a 100644 --- a/tests/devices/test_preprocess.py +++ b/tests/devices/test_preprocess.py @@ -455,7 +455,7 @@ class TestMidCircuitMeasurements: ], ) def test_mcm_method(self, mcm_method, shots, expected_transform, mocker): - """Test that the transform adheres to the specified transform""" + """Test that the preprocessing transform adheres to the specified transform""" dev = qml.device("default.qubit") mcm_config = {"postselect_mode": None, "mcm_method": mcm_method} tape = QuantumScript([qml.measurements.MidMeasureMP(0)], [], shots=shots) diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 3b987c3cce3..2c2184f442b 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -1724,19 +1724,31 @@ def f(x): assert spy.call_count != 0 one_shot_spy.assert_not_called() - def test_invalid_mcm_method_warning(self): - """Test that a warning is raised if the requested mcm_method is invalid""" + def test_invalid_mcm_method_error(self): + """Test that an error is raised if the requested mcm_method is invalid""" shots = 100 dev = qml.device("default.qubit", wires=3, shots=shots) - @qml.qnode(dev, mcm_method="foo") def f(x): qml.RX(x, 0) _ = qml.measure(0, postselect=1) return qml.sample(wires=[0, 1]) - with pytest.warns(UserWarning, match="Invalid mid-circuit measurements method 'foo'"): - _ = f(1.8) + with pytest.raises(ValueError, match="Invalid mid-circuit measurements method 'foo'"): + _ = qml.QNode(f, dev, mcm_method="foo") + + def test_invalid_postselect_mode_error(self): + """Test that an error is raised if the requested postselect_mode is invalid""" + shots = 100 + dev = qml.device("default.qubit", wires=3, shots=shots) + + def f(x): + qml.RX(x, 0) + _ = qml.measure(0, postselect=1) + return qml.sample(wires=[0, 1]) + + with pytest.raises(ValueError, match="Invalid postselection mode 'foo'"): + _ = qml.QNode(f, dev, postselect_mode="foo") class TestTapeExpansion: diff --git a/tests/transforms/test_defer_measurements.py b/tests/transforms/test_defer_measurements.py index 85f0bc43202..4f22cc5166a 100644 --- a/tests/transforms/test_defer_measurements.py +++ b/tests/transforms/test_defer_measurements.py @@ -109,13 +109,14 @@ def circ(): def test_postselect_mode(postselect_mode, mocker): """Test that invalid shots are discarded if requested""" shots = 100 + postselect_value = 1 dev = qml.device("default.qubit", shots=shots) spy = mocker.spy(qml.defer_measurements, "_transform") @qml.qnode(dev, postselect_mode=postselect_mode, mcm_method="deferred") def f(x): qml.RX(x, 0) - _ = qml.measure(0, postselect=1) + _ = qml.measure(0, postselect=postselect_value) return qml.sample(wires=[0, 1]) res = f(np.pi / 4) @@ -125,6 +126,7 @@ def f(x): assert len(res) < shots else: assert len(res) == shots + assert np.allclose(res, postselect_value) @pytest.mark.parametrize( diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py index 18caefff1ff..6c908840359 100644 --- a/tests/transforms/test_dynamic_one_shot.py +++ b/tests/transforms/test_dynamic_one_shot.py @@ -81,7 +81,7 @@ def f(x): assert len(res) < shots else: assert len(res) == shots - assert np.all(res != np.iinfo(np.int32).min) + assert np.all(res != np.iinfo(np.int32).min) @pytest.mark.jax From b809cd2ef9fa8e34b606cb40a172ce06e2e3a115 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 30 May 2024 16:07:58 -0400 Subject: [PATCH 37/48] Fixed defer_measurements fill-shots; added functionality for jax --- pennylane/devices/qubit/simulate.py | 7 ++-- pennylane/workflow/execution.py | 5 +++ .../experimental/test_execution_config.py | 2 +- tests/interfaces/test_jacobian_products.py | 4 +-- tests/test_qnode.py | 33 +++++++++++++++++++ tests/transforms/test_defer_measurements.py | 2 +- 6 files changed, 46 insertions(+), 7 deletions(-) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 657c5d7a497..c89875b4937 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -77,6 +77,7 @@ def _postselection_postprocess(state, is_state_batched, shots, **execution_kwarg rng = execution_kwargs.get("rng", None) prng_key = execution_kwargs.get("prng_key", None) + postselect_mode = execution_kwargs.get("postselect_mode", "hw-like") # The floor function is being used here so that a norm very close to zero becomes exactly # equal to zero so that the state can become invalid. This way, execution can continue, and @@ -99,9 +100,9 @@ def _postselection_postprocess(state, is_state_batched, shots, **execution_kwarg binomial_fn = np.random.binomial if rng is None else rng.binomial postselected_shots = ( - [int(binomial_fn(s, float(norm**2))) for s in shots] - if not qml.math.is_abstract(norm) - else shots + shots + if postselect_mode == "fill-shots" or qml.math.is_abstract(norm) + else [int(binomial_fn(s, float(norm**2))) for s in shots] ) # _FlexShots is used here since the binomial distribution could result in zero diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 8a3e5cae0d5..cdfc37e8b4b 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -550,6 +550,11 @@ def cost_fn(params, x): gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config ) + if interface == "jax-jit" and config.mcm_config.mcm_method == "deferred": + # This is a current limitation of defer_measurements. "hw-like" behaviour is + # not yet accessible. + config.mcm_config.postselect_mode = "fill-shots" + if transform_program is None: if isinstance(device, qml.devices.Device): transform_program = device.preprocess(config)[0] diff --git a/tests/devices/experimental/test_execution_config.py b/tests/devices/experimental/test_execution_config.py index 9352b582734..b273b4823b6 100644 --- a/tests/devices/experimental/test_execution_config.py +++ b/tests/devices/experimental/test_execution_config.py @@ -36,7 +36,7 @@ def test_default_values(): def test_mcm_config_default_values(): """Test that the default values of MCMConfig are correct""" mcm_config = MCMConfig() - assert mcm_config.postselect_mode is None + assert mcm_config.postselect_mode == "hw-like" assert mcm_config.mcm_method is None diff --git a/tests/interfaces/test_jacobian_products.py b/tests/interfaces/test_jacobian_products.py index 07f3ccb60ab..58ffdd05ee7 100644 --- a/tests/interfaces/test_jacobian_products.py +++ b/tests/interfaces/test_jacobian_products.py @@ -152,7 +152,7 @@ def test_device_jacobians_repr(self): r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}," r" device_options={}, interface=None, derivative_order=1," - r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>" + r" mcm_config=MCMConfig(mcm_method=None, postselect_mode='hw-like'))>" ) assert repr(jpc) == expected @@ -171,7 +171,7 @@ def test_device_jacobian_products_repr(self): r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}, device_options={}," r" interface=None, derivative_order=1," - r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>" + r" mcm_config=MCMConfig(mcm_method=None, postselect_mode='hw-like'))>" ) assert repr(jpc) == expected diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 2c2184f442b..94a5a643805 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -1750,6 +1750,39 @@ def f(x): with pytest.raises(ValueError, match="Invalid postselection mode 'foo'"): _ = qml.QNode(f, dev, postselect_mode="foo") + @pytest.mark.jax + @pytest.mark.parametrize("diff_method", [None, "best"]) + def test_defer_measurements_hw_like_with_jit(self, diff_method, mocker): + """Test that using mcm_method="deferred" with postselect_mode="hw-like" defaults + to behaviour like postselect_mode="fill-shots" when using jax jit.""" + import jax # pylint: disable=import-outside-toplevel + + shots = 100 + postselect = 1 + param = jax.numpy.array(np.pi / 2) + spy = mocker.spy(qml.defer_measurements, "_transform") + spy_one_shot = mocker.spy(qml.dynamic_one_shot, "_transform") + + dev = qml.device("default.qubit", wires=4, shots=shots, seed=jax.random.PRNGKey(123)) + + @qml.qnode(dev, diff_method=diff_method, postselect_mode="hw-like", mcm_method="deferred") + def f(x): + qml.RX(x, 0) + qml.measure(0, postselect=postselect) + return qml.sample(wires=0) + + f_jit = jax.jit(f) + res = f(param) + res_jit = f_jit(param) + + assert spy.call_count > 0 + spy_one_shot.assert_not_called() + + assert len(res) < shots + assert len(res_jit) == shots + assert qml.math.allclose(res, postselect) + assert qml.math.allclose(res_jit, postselect) + class TestTapeExpansion: """Test that tape expansion within the QNode works correctly""" diff --git a/tests/transforms/test_defer_measurements.py b/tests/transforms/test_defer_measurements.py index 4f22cc5166a..e1c7353e024 100644 --- a/tests/transforms/test_defer_measurements.py +++ b/tests/transforms/test_defer_measurements.py @@ -117,7 +117,7 @@ def test_postselect_mode(postselect_mode, mocker): def f(x): qml.RX(x, 0) _ = qml.measure(0, postselect=postselect_value) - return qml.sample(wires=[0, 1]) + return qml.sample(wires=[0]) res = f(np.pi / 4) spy.assert_called_once() From 85258d45f25a41cc7548e2ffa773e97b2cc1da75 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 3 Jun 2024 11:25:13 -0400 Subject: [PATCH 38/48] Update old device API MCM config support --- pennylane/_qubit_device.py | 7 ++- pennylane/workflow/execution.py | 10 +++- pennylane/workflow/qnode.py | 1 + tests/devices/test_default_qubit_legacy.py | 30 ----------- tests/test_qubit_device.py | 60 ++++++++++++++++++++++ 5 files changed, 75 insertions(+), 33 deletions(-) diff --git a/pennylane/_qubit_device.py b/pennylane/_qubit_device.py index aa8cad8ca87..7239a250215 100644 --- a/pennylane/_qubit_device.py +++ b/pennylane/_qubit_device.py @@ -468,7 +468,7 @@ def _multi_meas_with_counts_shot_vec(self, circuit: QuantumTape, shot_tuple, r): return new_r - def batch_execute(self, circuits): + def batch_execute(self, circuits, **kwargs): """Execute a batch of quantum circuits on the device. The circuits are represented by tapes, and they are executed one-by-one using the @@ -492,13 +492,16 @@ def batch_execute(self, circuits): ), ) + if self.capabilities().get("supports_mid_measure", False): + kwargs.setdefault("postselect_mode", "hw-like") + results = [] for circuit in circuits: # we need to reset the device here, else it will # not start the next computation in the zero state self.reset() - res = self.execute(circuit) + res = self.execute(circuit, **kwargs) results.append(res) if self.tracker.active: diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index cdfc37e8b4b..07071f1557f 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -267,7 +267,15 @@ def _make_inner_execute( """ if isinstance(device, qml.devices.LegacyDevice): - device_execution = set_shots(device, override_shots)(device.batch_execute) + dev_execute = ( + device.batch_execute + if execution_config is None + else partial( + device.batch_execute, + postselect_mode=execution_config.mcm_config.postselect_mode, + ) + ) + device_execution = set_shots(device, override_shots)(dev_execute) else: device_execution = partial(device.execute, execution_config=execution_config) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 942b69555a2..6296f992fd2 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -1064,6 +1064,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml device=self.device, mcm_config=self.execute_kwargs["mcm_config"], ) + override_shots = 1 elif hasattr(self.device, "capabilities"): full_transform_program.add_transform( qml.defer_measurements, diff --git a/tests/devices/test_default_qubit_legacy.py b/tests/devices/test_default_qubit_legacy.py index a5b4df81cec..d3059b8ab38 100644 --- a/tests/devices/test_default_qubit_legacy.py +++ b/tests/devices/test_default_qubit_legacy.py @@ -17,7 +17,6 @@ # pylint: disable=too-many-arguments,too-few-public-methods # pylint: disable=protected-access,cell-var-from-loop import cmath -import copy import math from functools import partial @@ -89,35 +88,6 @@ VARPHI = np.linspace(0.02, 1, 3) -def test_qnode_native_mcm(mocker): - """Tests that the legacy devices may support native MCM execution via the dynamic_one_shot transform.""" - - class MCMDevice(DefaultQubitLegacy): - def apply(self, *args, **kwargs): - for op in args[0]: - if isinstance(op, qml.measurements.MidMeasureMP): - kwargs["mid_measurements"][op] = 0 - - @classmethod - def capabilities(cls): - default_capabilities = copy.copy(DefaultQubitLegacy.capabilities()) - default_capabilities["supports_mid_measure"] = True - return default_capabilities - - dev = MCMDevice(wires=1, shots=100) - dev.operations.add("MidMeasureMP") - spy = mocker.spy(qml.dynamic_one_shot, "_transform") - - @qml.qnode(dev, interface=None, diff_method=None) - def func(): - _ = qml.measure(0) - return qml.expval(op=qml.PauliZ(0)) - - res = func() - assert spy.call_count == 1 - assert isinstance(res, float) - - def test_analytic_deprecation(): """Tests if the kwarg `analytic` is used and displays error message.""" msg = "The analytic argument has been replaced by shots=None. " diff --git a/tests/test_qubit_device.py b/tests/test_qubit_device.py index 611bae8e93f..73687090d31 100644 --- a/tests/test_qubit_device.py +++ b/tests/test_qubit_device.py @@ -14,6 +14,7 @@ """ Unit tests for the :mod:`pennylane` :class:`QubitDevice` class. """ +import copy from random import random import numpy as np @@ -1147,6 +1148,65 @@ def test_defines_correct_capabilities(self): assert capabilities == QubitDevice.capabilities() +class TestNativeMidCircuitMeasurements: + """Unit tests for mid-circuit measurements related functionality""" + + class MCMDevice(qml.devices.DefaultQubitLegacy): + def apply(self, *args, **kwargs): + for op in args[0]: + if isinstance(op, qml.measurements.MidMeasureMP): + kwargs["mid_measurements"][op] = 0 + + @classmethod + def capabilities(cls): + default_capabilities = copy.copy(qml.devices.DefaultQubitLegacy.capabilities()) + default_capabilities["supports_mid_measure"] = True + return default_capabilities + + def test_qnode_native_mcm(self, mocker): + """Tests that the legacy devices may support native MCM execution via the dynamic_one_shot transform.""" + + dev = self.MCMDevice(wires=1, shots=100) + dev.operations.add("MidMeasureMP") + spy = mocker.spy(qml.dynamic_one_shot, "_transform") + + @qml.qnode(dev, interface=None, diff_method=None) + def func(): + _ = qml.measure(0) + return qml.expval(op=qml.PauliZ(0)) + + res = func() + assert spy.call_count == 1 + assert isinstance(res, float) + + @pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"]) + def test_postselect_mode_propagates_to_execute(self, monkeypatch, postselect_mode): + """Test that the specified postselect mode propagates to execution as expected.""" + dev = self.MCMDevice(wires=1, shots=100) + dev.operations.add("MidMeasureMP") + pm_propagated = False + + def new_apply(*args, **kwargs): # pylint: disable=unused-argument + nonlocal pm_propagated + if kwargs.get("postselect_mode", -1) == postselect_mode: + pm_propagated = True + else: + pm_propagated = False + + @qml.qnode(dev, postselect_mode=postselect_mode) + def func(): + _ = qml.measure(0, postselect=1) + return qml.expval(op=qml.PauliZ(0)) + + with monkeypatch.context() as m: + m.setattr(dev, "apply", new_apply) + with pytest.raises(Exception): + # Error expected as mocked apply method does not adhere to expected output. + func() + + assert pm_propagated is True + + class TestExecution: """Tests for the execute method""" From e0555e59263b43d2f225ee557ec19c448c45ef14 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 3 Jun 2024 16:42:19 -0400 Subject: [PATCH 39/48] Default postselect_mode=None, raise error with DM+jax_jit+hw_like, docs --- doc/introduction/measurements.rst | 7 ++-- pennylane/_qubit_device.py | 2 +- pennylane/devices/execution_config.py | 4 +-- pennylane/devices/qubit/apply_operation.py | 6 ++-- pennylane/devices/qubit/simulate.py | 11 ++++--- pennylane/workflow/execution.py | 2 ++ pennylane/workflow/qnode.py | 9 +++-- tests/test_qnode.py | 38 +++++++++++++++++++--- 8 files changed, 57 insertions(+), 22 deletions(-) diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index 872f71419ad..e0958538e3c 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -600,9 +600,10 @@ PennyLane. For ease of use, we provide the following configuration options to us -1.0000000e+00, -2.1474836e+09, -1.0000000e+00, -2.1474836e+09, -1.0000000e+00, -1.0000000e+00], dtype=float32, weak_type=True) - * If ``mcm_method="deferred"``, then using ``postselect_mode="hw-like"`` will have the same behaviour as when - ``postselect_mode="fill-shots"``. This is due to the limitations of the :func:`~pennylane.defer_measurements` - transform, and this behaviour will change in the future to be more consistent with ``mcm_method="one-shot"``. + * When using ``jax.jit``, using ``mcm_method="deferred"`` is not supported with ``postselect_mode="hw-like"``. + Therefore, the default behaviour will be to use ``postselect_mode="fill-shots"``. This is due to limitations + of the :func:`~pennylane.defer_measurements` transform, and this behaviour will change in the future to be + more consistent with ``mcm_method="one-shot"``. Changing the number of shots ---------------------------- diff --git a/pennylane/_qubit_device.py b/pennylane/_qubit_device.py index 7239a250215..963683a4107 100644 --- a/pennylane/_qubit_device.py +++ b/pennylane/_qubit_device.py @@ -493,7 +493,7 @@ def batch_execute(self, circuits, **kwargs): ) if self.capabilities().get("supports_mid_measure", False): - kwargs.setdefault("postselect_mode", "hw-like") + kwargs.setdefault("postselect_mode", None) results = [] for circuit in circuits: diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index 70db91e8b88..1a0788826d2 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -30,11 +30,11 @@ class MCMConfig: for each shot separately. If not specified, the device will decide which method to use.""" - postselect_mode: str = "hw-like" + postselect_mode: Optional[str] = None """Configuration for handling shots with mid-circuit measurement postselection. If ``"hw-like"``, invalid shots will be discarded and only results for valid shots will be returned. If ``"fill-shots"``, results corresponding to the original number of - shots will be returned.""" + shots will be returned. If not specified, the device will decide which mode to use.""" # pylint: disable=too-many-instance-attributes diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py index dff2e530fb6..e5a3ed5464f 100644 --- a/pennylane/devices/qubit/apply_operation.py +++ b/pennylane/devices/qubit/apply_operation.py @@ -172,7 +172,7 @@ def apply_operation( interface (str): The machine learning interface of the state postselect_mode (str): Configuration for handling shots with mid-circuit measurement postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to - keep the same number of shots. ``"hw-like"`` by default. + keep the same number of shots. ``None`` by default. rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator. prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. @@ -302,7 +302,7 @@ def apply_mid_measure( mid_measurements (dict, None): Mid-circuit measurement dictionary mutated to record the sampled value postselect_mode (str): Configuration for handling shots with mid-circuit measurement postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to - keep the same number of shots. ``"hw-like"`` by default. + keep the same number of shots. ``None`` by default. rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator. prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. @@ -314,7 +314,7 @@ def apply_mid_measure( mid_measurements = execution_kwargs.get("mid_measurements", None) rng = execution_kwargs.get("rng", None) prng_key = execution_kwargs.get("prng_key", None) - postselect_mode = execution_kwargs.get("postselect_mode", "hw-like") + postselect_mode = execution_kwargs.get("postselect_mode", None) if is_state_batched: raise ValueError("MidMeasureMP cannot be applied to batched states.") diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index eb21877492e..23adc03ffe0 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -83,7 +83,10 @@ def _postselection_postprocess(state, is_state_batched, shots, **execution_kwarg rng = execution_kwargs.get("rng", None) prng_key = execution_kwargs.get("prng_key", None) - postselect_mode = execution_kwargs.get("postselect_mode", "hw-like") + postselect_mode = execution_kwargs.get("postselect_mode", None) + + if postselect_mode == "hw-like" and qml.math.is_abstract(state): + raise ValueError("Using postselect_mode='hw-like' is not supported with jax-jit.") # The floor function is being used here so that a norm very close to zero becomes exactly # equal to zero so that the state can become invalid. This way, execution can continue, and @@ -138,7 +141,7 @@ def get_final_state(circuit, debugger=None, **execution_kwargs): If None, a ``numpy.random.default_rng`` will be used for sampling. postselect_mode (str): Configuration for handling shots with mid-circuit measurement postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to - keep the same number of shots. Default is ``"hw-like"``. + keep the same number of shots. Default is ``None``. Returns: Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and @@ -276,7 +279,7 @@ def simulate( interface (str): The machine learning interface to create the initial state with postselect_mode (str): Configuration for handling shots with mid-circuit measurement postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to - keep the same number of shots. Default is ``"hw-like"``. + keep the same number of shots. Default is ``None``. Returns: tuple(TensorLike): The results of the simulation @@ -351,7 +354,7 @@ def simulate_one_shot_native_mcm( interface (str): The machine learning interface to create the initial state with postselect_mode (str): Configuration for handling shots with mid-circuit measurement postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to - keep the same number of shots. Default is ``"hw-like"``. + keep the same number of shots. Default is ``None``. Returns: tuple(TensorLike): The results of the simulation diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 07071f1557f..c87ffe2600d 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -561,6 +561,8 @@ def cost_fn(params, x): if interface == "jax-jit" and config.mcm_config.mcm_method == "deferred": # This is a current limitation of defer_measurements. "hw-like" behaviour is # not yet accessible. + if config.mcm_config.postselect_mode == "hw-like": + raise ValueError("Using postselect_mode='hw-like' is not supported with jax-jit.") config.mcm_config.postselect_mode = "fill-shots" if transform_program is None: diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 6296f992fd2..60b0dd13cfb 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -222,8 +222,8 @@ class QNode: postselect_mode (str): Configuration for handling shots with mid-circuit measurement postselection. If ``"hw-like"``, invalid shots will be discarded and only results for valid shots will be returned. If ``"fill-shots"``, results corresponding to the original number of shots will be returned. The - default is ``"hw-like"``. For usage details, please refer to the - :doc:`main measurements page `. + default is ``None``, in which case the device will automatically choose the best configuration. For + usage details, please refer to the :doc:`main measurements page `. mcm_method (str): Strategy to use when executing circuits with mid-circuit measurements. Use ``"deferred"`` to apply the deferred measurements principle (using the :func:`~pennylane.defer_measurements` transform), or ``"one-shot"`` if using finite shots to execute the circuit for each shot separately. If not provided, @@ -452,7 +452,7 @@ def __init__( cachesize=10000, max_diff=1, device_vjp=False, - postselect_mode="hw-like", + postselect_mode=None, mcm_method=None, **gradient_kwargs, ): @@ -520,10 +520,9 @@ def __init__( self.max_expansion = max_expansion cache = (max_diff > 1) if cache == "auto" else cache - postselect_mode = postselect_mode or "hw-like" if mcm_method not in ("deferred", "one-shot", None): raise ValueError(f"Invalid mid-circuit measurements method '{mcm_method}'.") - if postselect_mode not in ("hw-like", "fill-shots"): + if postselect_mode not in ("hw-like", "fill-shots", None): raise ValueError(f"Invalid postselection mode '{postselect_mode}'.") # execution keyword arguments diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 94a5a643805..fd9d81a7c27 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -1752,9 +1752,9 @@ def f(x): @pytest.mark.jax @pytest.mark.parametrize("diff_method", [None, "best"]) - def test_defer_measurements_hw_like_with_jit(self, diff_method, mocker): - """Test that using mcm_method="deferred" with postselect_mode="hw-like" defaults - to behaviour like postselect_mode="fill-shots" when using jax jit.""" + def test_defer_measurements_with_jit(self, diff_method, mocker): + """Test that using mcm_method="deferred" defaults to behaviour like + postselect_mode="fill-shots" when using jax jit.""" import jax # pylint: disable=import-outside-toplevel shots = 100 @@ -1765,7 +1765,7 @@ def test_defer_measurements_hw_like_with_jit(self, diff_method, mocker): dev = qml.device("default.qubit", wires=4, shots=shots, seed=jax.random.PRNGKey(123)) - @qml.qnode(dev, diff_method=diff_method, postselect_mode="hw-like", mcm_method="deferred") + @qml.qnode(dev, diff_method=diff_method, mcm_method="deferred") def f(x): qml.RX(x, 0) qml.measure(0, postselect=postselect) @@ -1783,6 +1783,36 @@ def f(x): assert qml.math.allclose(res, postselect) assert qml.math.allclose(res_jit, postselect) + @pytest.mark.jax + # @pytest.mark.parametrize("diff_method", [None, "best"]) + @pytest.mark.parametrize("diff_method", ["best"]) + def test_hw_like_error_with_jit(self, diff_method): + """Test that an error is raised if attempting to use postselect_mode="hw-like" + with jax jit.""" + import jax # pylint: disable=import-outside-toplevel + + shots = 100 + postselect = 1 + param = jax.numpy.array(np.pi / 2) + + dev = qml.device("default.qubit", wires=4, shots=shots, seed=jax.random.PRNGKey(123)) + + @qml.qnode(dev, diff_method=diff_method, mcm_method="deferred", postselect_mode="hw-like") + def f(x): + qml.RX(x, 0) + qml.measure(0, postselect=postselect) + return qml.sample(wires=0) + + f_jit = jax.jit(f) + + # Checking that an error is not raised without jit + _ = f(param) + + with pytest.raises( + ValueError, match="Using postselect_mode='hw-like' is not supported with jax-jit." + ): + _ = f_jit(param) + class TestTapeExpansion: """Test that tape expansion within the QNode works correctly""" From bea44ba8cdbd0a7766084c8aebbd9c5737c54daf Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 3 Jun 2024 17:04:28 -0400 Subject: [PATCH 40/48] Found a way to raise error with diff_method=None --- pennylane/workflow/execution.py | 62 ++++++++++++++++++++++----------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index c87ffe2600d..a3f7dc845d9 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -368,6 +368,36 @@ def execution_function_with_caching(tapes): return execution_function_with_caching +def _get_interface_name(tapes, interface): + """Helper function to get the interface name of a list of tapes + + Args: + tapes (list[.QuantumScript]): Quantum tapes + interface (Optional[str]): Original interface to use as reference. + + Returns: + str: Interface name""" + if interface == "auto": + params = [] + for tape in tapes: + params.extend(tape.get_parameters(trainable_only=False)) + interface = qml.math.get_interface(*params) + if INTERFACE_MAP.get(interface, "") == "tf" and _use_tensorflow_autograph(): + interface = "tf-autograph" + if interface == "jax": + try: # pragma: no-cover + from .interfaces.jax import get_jax_interface_name + except ImportError as e: # pragma: no-cover + raise qml.QuantumFunctionError( # pragma: no-cover + "jax not found. Please install the latest " # pragma: no-cover + "version of jax to enable the 'jax' interface." # pragma: no-cover + ) from e # pragma: no-cover + + interface = get_jax_interface_name(tapes) + + return interface + + def execute( tapes: Sequence[QuantumTape], device: device_type, @@ -522,26 +552,10 @@ def cost_fn(params, x): ### Specifying and preprocessing variables #### - if interface == "auto": - params = [] - for tape in tapes: - params.extend(tape.get_parameters(trainable_only=False)) - interface = qml.math.get_interface(*params) - if INTERFACE_MAP.get(interface, "") == "tf" and _use_tensorflow_autograph(): - interface = "tf-autograph" - if interface == "jax": - try: # pragma: no-cover - from .interfaces.jax import get_jax_interface_name - except ImportError as e: # pragma: no-cover - raise qml.QuantumFunctionError( # pragma: no-cover - "jax not found. Please install the latest " # pragma: no-cover - "version of jax to enable the 'jax' interface." # pragma: no-cover - ) from e # pragma: no-cover - - interface = get_jax_interface_name(tapes) - # Only need to calculate derivatives with jax when we know it will be executed later. - if interface in {"jax", "jax-jit"}: - grad_on_execution = grad_on_execution if isinstance(gradient_fn, Callable) else False + interface = _get_interface_name(tapes, interface) + # Only need to calculate derivatives with jax when we know it will be executed later. + if interface in {"jax", "jax-jit"}: + grad_on_execution = grad_on_execution if isinstance(gradient_fn, Callable) else False if ( device_vjp @@ -558,7 +572,13 @@ def cost_fn(params, x): gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config ) - if interface == "jax-jit" and config.mcm_config.mcm_method == "deferred": + # Mid-circuit measurement configuration validation + if interface is None: + mcm_interface = "auto" + mcm_interface = _get_interface_name(tapes, mcm_interface) + else: + mcm_interface = interface + if mcm_interface == "jax-jit" and config.mcm_config.mcm_method == "deferred": # This is a current limitation of defer_measurements. "hw-like" behaviour is # not yet accessible. if config.mcm_config.postselect_mode == "hw-like": From ac78db5081ceab81162583ce23bb9c9c9c0b81f4 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 3 Jun 2024 17:05:39 -0400 Subject: [PATCH 41/48] Update test doc --- tests/test_qnode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_qnode.py b/tests/test_qnode.py index fd9d81a7c27..647a32385fa 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -1786,9 +1786,9 @@ def f(x): @pytest.mark.jax # @pytest.mark.parametrize("diff_method", [None, "best"]) @pytest.mark.parametrize("diff_method", ["best"]) - def test_hw_like_error_with_jit(self, diff_method): + def test__deferred_hw_like_error_with_jit(self, diff_method): """Test that an error is raised if attempting to use postselect_mode="hw-like" - with jax jit.""" + with jax jit with mcm_method="deferred".""" import jax # pylint: disable=import-outside-toplevel shots = 100 From 62e2d4b6c24039b4ce09c0b6eb01fae349edf352 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 3 Jun 2024 17:43:30 -0400 Subject: [PATCH 42/48] Added additional condition to qml.execute --- pennylane/workflow/execution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index a3f7dc845d9..e38f0c5d89a 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -270,6 +270,7 @@ def _make_inner_execute( dev_execute = ( device.batch_execute if execution_config is None + or not device.capabilities().get("supports_mid_measure", False) else partial( device.batch_execute, postselect_mode=execution_config.mcm_config.postselect_mode, From cbdbd5cfed1709f60980ae134c5ac775ba6c4004 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 3 Jun 2024 17:46:20 -0400 Subject: [PATCH 43/48] Added dev comment --- pennylane/workflow/execution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index e38f0c5d89a..fbfcf382079 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -269,6 +269,8 @@ def _make_inner_execute( if isinstance(device, qml.devices.LegacyDevice): dev_execute = ( device.batch_execute + # If this condition is not met, then dev.batch_execute likely also doesn't include + # any kwargs in its signature, hence why we use partial if execution_config is None or not device.capabilities().get("supports_mid_measure", False) else partial( From 47399e2d7b3a3ca8bcde36f84edf1eda5b53c017 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 4 Jun 2024 14:11:45 -0400 Subject: [PATCH 44/48] [skip ci] Skip CI From 31857c8b7612fe997eed276e183d1c77e3b9d9d8 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 4 Jun 2024 14:46:44 -0400 Subject: [PATCH 45/48] Addressing code review; fixing tests --- pennylane/capture/capture_qnode.py | 5 ++++- pennylane/devices/execution_config.py | 11 ++++++++++ pennylane/transforms/dynamic_one_shot.py | 2 +- pennylane/workflow/execution.py | 8 ++------ pennylane/workflow/qnode.py | 7 ++----- tests/capture/test_capture_qnode.py | 3 +++ .../default_qubit/test_default_qubit.py | 7 ++++++- .../experimental/test_execution_config.py | 20 ++++++++++++++++--- tests/interfaces/test_jacobian_products.py | 4 ++-- tests/test_qnode.py | 16 +++++---------- 10 files changed, 53 insertions(+), 30 deletions(-) diff --git a/pennylane/capture/capture_qnode.py b/pennylane/capture/capture_qnode.py index bbcd731e934..41525355214 100644 --- a/pennylane/capture/capture_qnode.py +++ b/pennylane/capture/capture_qnode.py @@ -15,6 +15,7 @@ This submodule defines a capture compatible call to QNodes. """ +from copy import copy from functools import lru_cache, partial import pennylane as qml @@ -159,7 +160,9 @@ def f(x): qfunc = partial(qnode.func, **kwargs) if kwargs else qnode.func qfunc_jaxpr = jax.make_jaxpr(qfunc)(*args) - qnode_kwargs = {"diff_method": qnode.diff_method, **qnode.execute_kwargs} + execute_kwargs = copy(qnode.execute_kwargs) + mcm_config = execute_kwargs.pop("mcm_config") + qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config} qnode_prim = _get_qnode_prim() return qnode_prim.bind( diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index 1a0788826d2..083f2880b6f 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -36,6 +36,17 @@ class MCMConfig: be returned. If ``"fill-shots"``, results corresponding to the original number of shots will be returned. If not specified, the device will decide which mode to use.""" + def __post_init__(self): + """ + Validate the configured mid-circuit measurement options. + + Note that this hook is automatically called after init via the dataclass integration. + """ + if self.mcm_method not in ("deferred", "one-shot", None): + raise ValueError(f"Invalid mid-circuit measurements method '{self.mcm_method}'.") + if self.postselect_mode not in ("hw-like", "fill-shots", None): + raise ValueError(f"Invalid postselection mode '{self.postselect_mode}'.") + # pylint: disable=too-many-instance-attributes @dataclass diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 14a448eb8cd..831e30031af 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -293,7 +293,7 @@ def gather_non_mcm(measurement, samples, is_valid): Args: measurement (MeasurementProcess): measurement - samples (TensorLike): measurement samples + samples (TensorLike): Post-processed measurement samples is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at each index specifies whether or not the respective sample is valid. diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index fbfcf382079..da7239ac921 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -270,7 +270,7 @@ def _make_inner_execute( dev_execute = ( device.batch_execute # If this condition is not met, then dev.batch_execute likely also doesn't include - # any kwargs in its signature, hence why we use partial + # any kwargs in its signature, hence why we use partial conditionally if execution_config is None or not device.capabilities().get("supports_mid_measure", False) else partial( @@ -576,11 +576,7 @@ def cost_fn(params, x): ) # Mid-circuit measurement configuration validation - if interface is None: - mcm_interface = "auto" - mcm_interface = _get_interface_name(tapes, mcm_interface) - else: - mcm_interface = interface + mcm_interface = _get_interface_name(tapes, "auto") if interface is None else interface if mcm_interface == "jax-jit" and config.mcm_config.mcm_method == "deferred": # This is a current limitation of defer_measurements. "hw-like" behaviour is # not yet accessible. diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 87bd6c22f23..e0fd75458c4 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -1042,12 +1042,9 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml finite_shots = _get_device_shots if override_shots is False else override_shots if not finite_shots and self.execute_kwargs["mcm_config"]["mcm_method"] == "one-shot": - warnings.warn( - "Cannot use the 'one-shot' method for mid-circuit measurements with " - "analytic mode. Using deferred measurements.", - UserWarning, + raise ValueError( + "Cannot use the 'one-shot' method for mid-circuit measurements with analytic mode." ) - self.execute_kwargs["mcm_config"]["mcm_method"] = "deferred" # Add the device program to the QNode program if isinstance(self.device, qml.devices.Device): diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index 1fe57c65e23..ec5387647d2 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -121,6 +121,7 @@ def circuit(x): assert eqn0.params["shots"] == qml.measurements.Shots(None) expected_kwargs = {"diff_method": "best"} expected_kwargs.update(circuit.execute_kwargs) + expected_kwargs.update(expected_kwargs.pop("mcm_config")) assert eqn0.params["qnode_kwargs"] == expected_kwargs qfunc_jaxpr = eqn0.params["qfunc_jaxpr"] @@ -294,5 +295,7 @@ def circuit(): "max_diff": 2, "max_expansion": 10, "device_vjp": False, + "mcm_method": None, + "postselect_mode": None, } assert jaxpr.eqns[0].params["qnode_kwargs"] == expected diff --git a/tests/devices/default_qubit/test_default_qubit.py b/tests/devices/default_qubit/test_default_qubit.py index b743f8fa940..4ee14716f77 100644 --- a/tests/devices/default_qubit/test_default_qubit.py +++ b/tests/devices/default_qubit/test_default_qubit.py @@ -1884,7 +1884,12 @@ def circ_expected(): @pytest.mark.parametrize( "mp, expected_shape", - [(qml.sample(wires=[0, 2]), (5, 2)), (qml.classical_shadow(wires=[0, 2]), (2, 5, 2))], + [ + (qml.sample(wires=[0, 2]), (5, 2)), + (qml.classical_shadow(wires=[0, 2]), (2, 5, 2)), + (qml.sample(wires=[0]), (5,)), + (qml.classical_shadow(wires=[0]), (2, 5, 1)), + ], ) @pytest.mark.parametrize("param", np.linspace(np.pi / 4, 3 * np.pi / 4, 3)) @pytest.mark.parametrize("shots", [10, (10, 10)]) diff --git a/tests/devices/experimental/test_execution_config.py b/tests/devices/experimental/test_execution_config.py index b273b4823b6..361712c112e 100644 --- a/tests/devices/experimental/test_execution_config.py +++ b/tests/devices/experimental/test_execution_config.py @@ -36,7 +36,7 @@ def test_default_values(): def test_mcm_config_default_values(): """Test that the default values of MCMConfig are correct""" mcm_config = MCMConfig() - assert mcm_config.postselect_mode == "hw-like" + assert mcm_config.postselect_mode is None assert mcm_config.mcm_method is None @@ -62,7 +62,7 @@ def test_invalid_grad_on_execution(): @pytest.mark.parametrize( "option", [MCMConfig(mcm_method="deferred"), {"mcm_method": "deferred"}, None] ) -def test_valid_mcm_config(option): +def test_valid_execution_config_mcm_config(option): """Test that the mcm_config attribute is set correctly""" config = ExecutionConfig(mcm_config=option) if option else ExecutionConfig() if option is None: @@ -71,8 +71,22 @@ def test_valid_mcm_config(option): assert config.mcm_config == MCMConfig(mcm_method="deferred") -def test_invalid_mcm_config(): +def test_invalid_execution_config_mcm_config(): """Test that an error is raised if mcm_config is set incorrectly""" option = "foo" with pytest.raises(ValueError, match="Got invalid type"): _ = ExecutionConfig(mcm_config=option) + + +def test_mcm_config_invalid_mcm_method(): + """Test that an error is raised if creating MCMConfig with invalid mcm_method""" + option = "foo" + with pytest.raises(ValueError, match="Invalid mid-circuit measurements method"): + _ = MCMConfig(mcm_method=option) + + +def test_mcm_config_invalid_postselect_mode(): + """Test that an error is raised if creating MCMConfig with invalid postselect_mode""" + option = "foo" + with pytest.raises(ValueError, match="Invalid postselection mode"): + _ = MCMConfig(postselect_mode=option) diff --git a/tests/interfaces/test_jacobian_products.py b/tests/interfaces/test_jacobian_products.py index 58ffdd05ee7..07f3ccb60ab 100644 --- a/tests/interfaces/test_jacobian_products.py +++ b/tests/interfaces/test_jacobian_products.py @@ -152,7 +152,7 @@ def test_device_jacobians_repr(self): r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}," r" device_options={}, interface=None, derivative_order=1," - r" mcm_config=MCMConfig(mcm_method=None, postselect_mode='hw-like'))>" + r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>" ) assert repr(jpc) == expected @@ -171,7 +171,7 @@ def test_device_jacobian_products_repr(self): r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}, device_options={}," r" interface=None, derivative_order=1," - r" mcm_config=MCMConfig(mcm_method=None, postselect_mode='hw-like'))>" + r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>" ) assert repr(jpc) == expected diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 1f84a4543b3..cf840c47766 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -1718,11 +1718,10 @@ class TestMCMConfiguration: """Tests for MCM configuration arguments""" @pytest.mark.parametrize("dev_name", ["default.qubit", "default.qubit.legacy"]) - def test_one_shot_warning_without_shots(self, dev_name, mocker): - """Test that a warning is raised if mcm_method="one-shot" with no shots""" + def test_one_shot_error_without_shots(self, dev_name): + """Test that an error is raised if mcm_method="one-shot" with no shots""" dev = qml.device(dev_name, wires=3) - spy = mocker.spy(qml.defer_measurements, "_transform") - one_shot_spy = mocker.spy(qml.dynamic_one_shot, "_transform") + param = np.pi / 4 @qml.qnode(dev, mcm_method="one-shot") def f(x): @@ -1730,16 +1729,11 @@ def f(x): _ = qml.measure(0) return qml.probs(wires=[0, 1]) - param = np.pi / 4 - - with pytest.warns( - UserWarning, match="Cannot use the 'one-shot' method for mid-circuit measurements with" + with pytest.raises( + ValueError, match="Cannot use the 'one-shot' method for mid-circuit measurements with" ): _ = f(param) - assert spy.call_count != 0 - one_shot_spy.assert_not_called() - def test_invalid_mcm_method_error(self): """Test that an error is raised if the requested mcm_method is invalid""" shots = 100 From ef35642f72c0ce9c29011afc0dbddbfd47f117ff Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 4 Jun 2024 15:26:57 -0400 Subject: [PATCH 46/48] Fixing code cov --- pennylane/devices/qubit/simulate.py | 3 --- pennylane/workflow/execution.py | 2 +- tests/test_qubit_device.py | 5 +---- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 23adc03ffe0..b862d590e19 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -85,9 +85,6 @@ def _postselection_postprocess(state, is_state_batched, shots, **execution_kwarg prng_key = execution_kwargs.get("prng_key", None) postselect_mode = execution_kwargs.get("postselect_mode", None) - if postselect_mode == "hw-like" and qml.math.is_abstract(state): - raise ValueError("Using postselect_mode='hw-like' is not supported with jax-jit.") - # The floor function is being used here so that a norm very close to zero becomes exactly # equal to zero so that the state can become invalid. This way, execution can continue, and # bad postselection gives results that are invalid rather than results that look valid but diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index da7239ac921..9607c3c2ef3 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -389,7 +389,7 @@ def _get_interface_name(tapes, interface): interface = "tf-autograph" if interface == "jax": try: # pragma: no-cover - from .interfaces.jax import get_jax_interface_name + from .interfaces.jax import get_jax_interface_name # pragma: no-cover except ImportError as e: # pragma: no-cover raise qml.QuantumFunctionError( # pragma: no-cover "jax not found. Please install the latest " # pragma: no-cover diff --git a/tests/test_qubit_device.py b/tests/test_qubit_device.py index 73687090d31..f8a10f26ac4 100644 --- a/tests/test_qubit_device.py +++ b/tests/test_qubit_device.py @@ -1188,10 +1188,7 @@ def test_postselect_mode_propagates_to_execute(self, monkeypatch, postselect_mod def new_apply(*args, **kwargs): # pylint: disable=unused-argument nonlocal pm_propagated - if kwargs.get("postselect_mode", -1) == postselect_mode: - pm_propagated = True - else: - pm_propagated = False + pm_propagated = kwargs.get("postselect_mode", -1) == postselect_mode @qml.qnode(dev, postselect_mode=postselect_mode) def func(): From d7eeb1d67398cbe29fce6a931c353ee6e1d563ba Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 4 Jun 2024 16:02:34 -0400 Subject: [PATCH 47/48] Add coverage; update docs --- doc/introduction/measurements.rst | 10 +++++++--- pennylane/workflow/execution.py | 14 +++++++------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index e0958538e3c..4718855171c 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -263,6 +263,8 @@ outcome of such mid-circuit measurements: qml.cond(m_0, qml.RY)(y, wires=0) return qml.probs(wires=[0]) +.. _deferred_measurements: + Deferred measurements ********************* @@ -312,6 +314,8 @@ tensor([0.90165331, 0.09834669], requires_grad=True) and quantum hardware. The one-shot transform below does not have this limitation, but has computational cost that scales with the number of shots used. +.. _one_shot_transform: + The one-shot transform ********************** @@ -532,9 +536,9 @@ PennyLane. For ease of use, we provide the following configuration options to us :class:`~pennylane.QNode`: * ``mcm_method``: To set the method used for applying mid-circuit measurements. Use ``mcm_method="deferred"`` - to use the deferred measurements principle or ``mcm_method="one-shot"`` to use the one-shot transform as - described above. When executing with finite shots, ``mcm_method="one-shot"`` will be the default, and - ``mcm_method="deferred"`` otherwise. + to use the :ref:`deferred measurements principle ` or ``mcm_method="one-shot"`` to use + the :ref:`one-shot transform `. When executing with finite shots, ``mcm_method="one-shot"`` + will be the default, and ``mcm_method="deferred"`` otherwise. .. warning:: diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 9607c3c2ef3..f27d100e752 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -388,13 +388,13 @@ def _get_interface_name(tapes, interface): if INTERFACE_MAP.get(interface, "") == "tf" and _use_tensorflow_autograph(): interface = "tf-autograph" if interface == "jax": - try: # pragma: no-cover - from .interfaces.jax import get_jax_interface_name # pragma: no-cover - except ImportError as e: # pragma: no-cover - raise qml.QuantumFunctionError( # pragma: no-cover - "jax not found. Please install the latest " # pragma: no-cover - "version of jax to enable the 'jax' interface." # pragma: no-cover - ) from e # pragma: no-cover + try: # pragma: no cover + from .interfaces.jax import get_jax_interface_name + except ImportError as e: # pragma: no cover + raise qml.QuantumFunctionError( # pragma: no cover + "jax not found. Please install the latest " # pragma: no cover + "version of jax to enable the 'jax' interface." # pragma: no cover + ) from e # pragma: no cover interface = get_jax_interface_name(tapes) From 2e3e3a03d1294155c5acedc794ad0376d1914a01 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 4 Jun 2024 16:05:19 -0400 Subject: [PATCH 48/48] doc fix --- doc/introduction/measurements.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index 4718855171c..e837d5603f7 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -604,10 +604,10 @@ PennyLane. For ease of use, we provide the following configuration options to us -1.0000000e+00, -2.1474836e+09, -1.0000000e+00, -2.1474836e+09, -1.0000000e+00, -1.0000000e+00], dtype=float32, weak_type=True) - * When using ``jax.jit``, using ``mcm_method="deferred"`` is not supported with ``postselect_mode="hw-like"``. - Therefore, the default behaviour will be to use ``postselect_mode="fill-shots"``. This is due to limitations - of the :func:`~pennylane.defer_measurements` transform, and this behaviour will change in the future to be - more consistent with ``mcm_method="one-shot"``. + * When using ``jax.jit``, using ``mcm_method="deferred"`` is not supported with ``postselect_mode="hw-like"`` and + an error will be raised if this configuration is requested. This is due to limitations of the + :func:`~pennylane.defer_measurements` transform, and this behaviour will change in the future to be more + consistent with ``mcm_method="one-shot"``. Changing the number of shots ----------------------------