Skip to content

Commit

Permalink
Rename
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentmr committed Jun 3, 2024
1 parent 7dd313d commit 7fd6447
Showing 1 changed file with 31 additions and 31 deletions.
62 changes: 31 additions & 31 deletions pennylane/transforms/dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,10 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript):
new_operations.append(new_op)
else:
new_operations.append(op)
if isinstance(op, MidMeasureMP):
new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res)))
elif "MidCircuitMeasure" in str(type(op)):
if "MidCircuitMeasure" in str(type(op)):
new_measurements.append(qml.sample(op.out_classical_tracers[0]))
elif isinstance(op, MidMeasureMP):
new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res)))
return qml.tape.QuantumScript(
new_operations,
new_measurements,
Expand Down Expand Up @@ -276,7 +276,7 @@ def measurement_with_no_shots(measurement):
)
if interface != "jax" and m.mv and not has_valid:
meas = measurement_with_no_shots(m)
elif m.mv and active_jit:
elif m.mv and active_qjit:
meas = gather_mcm_jit(m, mcm_samples, is_valid)
elif m.mv:
meas = gather_mcm(m, mcm_samples, is_valid)
Expand All @@ -290,7 +290,7 @@ def measurement_with_no_shots(measurement):
# as it assumes all elements of the input are of builtin python types and not belonging
# to any particular interface
result = qml.math.array(result, like=interface)
if active_jit:
if active_qjit:
if isinstance(m, CountsMP):
normalized_meas.append(
(result[0][0], qml.math.sum(result[1] * is_valid.reshape((-1, 1)), axis=0))
Expand All @@ -307,76 +307,76 @@ def measurement_with_no_shots(measurement):
return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0]


def gather_mcm_jit(circuit_measurement, measurement, is_valid):
"""Combines, gathers and normalizes several measurements with trivial measurement values
when the Catalyst compiler is active.
def gather_mcm_jit(measurement, samples, is_valid):
"""Process MCM measurements when the Catalyst compiler is active.
Args:
circuit_measurement (MeasurementProcess): measurement
measurement (TensorLike): measurement results
samples (List[dict]): Mid-circuit measurement samples
measurement (MeasurementProcess): measurement
samples (dict): Mid-circuit measurement samples
is_valid (TensorLike): Mask of valid samples
Returns:
TensorLike: The combined measurement outcome
"""
found, meas = False, None
for k, meas in measurement.items():
if circuit_measurement.mv is k.out_classical_tracers[0]:
for k, meas in samples.items():
if measurement.mv is k.out_classical_tracers[0]:
found = True
break
if not found:
raise LookupError("MCM not found")
meas = qml.math.squeeze(meas)
sum_valid = qml.math.sum(is_valid)
count_1 = qml.math.sum(meas * is_valid)
if isinstance(circuit_measurement, CountsMP):
if isinstance(measurement, (CountsMP, ProbabilityMP)):
sum_valid = qml.math.sum(is_valid)
count_1 = qml.math.sum(meas * is_valid)
if isinstance(measurement, CountsMP):
return {0: sum_valid - count_1, 1: count_1}
if isinstance(circuit_measurement, ProbabilityMP):
if isinstance(measurement, ProbabilityMP):
counts = qml.math.array(
[sum_valid - count_1, count_1], like=qml.math.get_deep_interface(is_valid)
)
return counts / sum_valid
return gather_non_mcm(circuit_measurement, meas, is_valid)
return gather_non_mcm(measurement, meas, is_valid)


def gather_non_mcm(circuit_measurement, measurement, is_valid):
"""Combines, gathers and normalizes several measurements with trivial measurement values.
def gather_non_mcm(circuit_measurement, measurements, is_valid):
"""Combines, gathers and normalizes an array of terminal measurements.
Args:
circuit_measurement (MeasurementProcess): measurement
measurement (TensorLike): measurement results
samples (List[dict]): Mid-circuit measurement samples
circuit_measurement (MeasurementProcess): Measurement
measurements (TensorLike): Stacked measurement results
is_valid (TensorLike): Mask of valid samples
Returns:
TensorLike: The combined measurement outcome
"""
if isinstance(circuit_measurement, CountsMP):
tmp = Counter()
for i, d in enumerate(measurement):
for i, d in enumerate(measurements):
tmp.update(
dict((k if isinstance(k, str) else float(k), v * is_valid[i]) for k, v in d.items())
)
if not circuit_measurement.all_outcomes:
tmp = Counter({k: v for k, v in tmp.items() if v > 0})
return dict(sorted(tmp.items()))
if isinstance(circuit_measurement, ExpectationMP):
return qml.math.sum(measurement * is_valid) / qml.math.sum(is_valid)
return qml.math.sum(measurements * is_valid) / qml.math.sum(is_valid)
if isinstance(circuit_measurement, ProbabilityMP):
return qml.math.sum(measurement * is_valid.reshape((-1, 1)), axis=0) / qml.math.sum(
return qml.math.sum(measurements * is_valid.reshape((-1, 1)), axis=0) / qml.math.sum(
is_valid
)
if isinstance(circuit_measurement, SampleMP):
is_interface_jax = qml.math.get_deep_interface(is_valid) == "jax"
if is_interface_jax and measurement.ndim == 2:
if is_interface_jax and measurements.ndim == 2:
is_valid = is_valid.reshape((-1, 1))
return (
qml.math.where(is_valid, measurement, fill_in_value)
qml.math.where(is_valid, measurements, fill_in_value)
if is_interface_jax
else measurement[is_valid]
else measurements[is_valid]
)
# VarianceMP
expval = qml.math.sum(measurement * is_valid) / qml.math.sum(is_valid)
return qml.math.sum((measurement - expval) ** 2 * is_valid) / qml.math.sum(is_valid)
expval = qml.math.sum(measurements * is_valid) / qml.math.sum(is_valid)
return qml.math.sum((measurements - expval) ** 2 * is_valid) / qml.math.sum(is_valid)


def gather_mcm(measurement, samples, is_valid):
Expand Down

0 comments on commit 7fd6447

Please sign in to comment.