Skip to content

Commit

Permalink
Store measurement value in MidMeasureMP.mv
Browse files Browse the repository at this point in the history
  • Loading branch information
mudit2812 committed Jun 20, 2024
1 parent 493f37a commit 5dccf5d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
10 changes: 9 additions & 1 deletion pennylane/measurements/mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def func(x):
raise NotImplementedError(
"Capture cannot currently handle classical output from mid circuit measurements."
)
return MeasurementValue([mp], processing_fn=lambda v: v)
return mp.mv


T = TypeVar("T")
Expand Down Expand Up @@ -261,6 +261,14 @@ def __init__(
super().__init__(wires=Wires(wires), id=id)
self.reset = reset
self.postselect = postselect
self.mv = MeasurementValue([self], processing_fn=lambda v: v)

@property
def wires(self):
# Overriden wires property as MeasurementProcess.wires uses mv.wires when mv is not None,
# and MeasurementValue.wires uses the wires of the mid-circuit measurements, which would
# lead to infinite recursion.
return self._wires

# pylint: disable=arguments-renamed, arguments-differ
@classmethod
Expand Down
10 changes: 2 additions & 8 deletions pennylane/transforms/dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from pennylane.measurements import (
CountsMP,
ExpectationMP,
MeasurementValue,
MidMeasureMP,
ProbabilityMP,
SampleMP,
Expand Down Expand Up @@ -210,12 +209,7 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript):
new_measurements.append(m)
for op in circuit.operations:
if isinstance(op, MidMeasureMP):
mv = (
op.mcm_tracer
if "MidCircuitMeasure" in str(type(op))
else MeasurementValue([op], lambda res: res)
)
new_measurements.append(qml.sample(op=mv))
new_measurements.append(qml.sample(op=op.mv))
return qml.tape.QuantumScript(
circuit.operations,
new_measurements,
Expand Down Expand Up @@ -326,7 +320,7 @@ def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover
"""
found, meas = False, None
for k, meas in samples.items():
if measurement.mv is k.mcm_tracer:
if measurement.mv is k.mv:
found = True
break
if not found:
Expand Down

0 comments on commit 5dccf5d

Please sign in to comment.