From 5dccf5d7f939b5e4356f2f593948d2337b06981d Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 20 Jun 2024 15:55:12 -0400 Subject: [PATCH] Store measurement value in MidMeasureMP.mv --- pennylane/measurements/mid_measure.py | 10 +++++++++- pennylane/transforms/dynamic_one_shot.py | 10 ++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py index 971eb8bcb87..a941f2b9e59 100644 --- a/pennylane/measurements/mid_measure.py +++ b/pennylane/measurements/mid_measure.py @@ -222,7 +222,7 @@ def func(x): raise NotImplementedError( "Capture cannot currently handle classical output from mid circuit measurements." ) - return MeasurementValue([mp], processing_fn=lambda v: v) + return mp.mv T = TypeVar("T") @@ -261,6 +261,14 @@ def __init__( super().__init__(wires=Wires(wires), id=id) self.reset = reset self.postselect = postselect + self.mv = MeasurementValue([self], processing_fn=lambda v: v) + + @property + def wires(self): + # Overriden wires property as MeasurementProcess.wires uses mv.wires when mv is not None, + # and MeasurementValue.wires uses the wires of the mid-circuit measurements, which would + # lead to infinite recursion. + return self._wires # pylint: disable=arguments-renamed, arguments-differ @classmethod diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index bbd83c1bada..40c91b1679f 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -27,7 +27,6 @@ from pennylane.measurements import ( CountsMP, ExpectationMP, - MeasurementValue, MidMeasureMP, ProbabilityMP, SampleMP, @@ -210,12 +209,7 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): new_measurements.append(m) for op in circuit.operations: if isinstance(op, MidMeasureMP): - mv = ( - op.mcm_tracer - if "MidCircuitMeasure" in str(type(op)) - else MeasurementValue([op], lambda res: res) - ) - new_measurements.append(qml.sample(op=mv)) + new_measurements.append(qml.sample(op=op.mv)) return qml.tape.QuantumScript( circuit.operations, new_measurements, @@ -326,7 +320,7 @@ def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover """ found, meas = False, None for k, meas in samples.items(): - if measurement.mv is k.mcm_tracer: + if measurement.mv is k.mv: found = True break if not found: