Skip to content

Commit

Permalink
dynamic_one_shot supports broadcasting; broadcast_expand supports…
Browse files Browse the repository at this point in the history
… shot vectors (#5473)

**Context:**
Native mid-circuit measurements with `default.qubit` are not compatible
with parameter broadcasting. Due to the complexity of a "native"
implementation, I decided to use `broadcast_expand`, but realized that
it does not work with shot vectors.

**Description of the Change:**
This PR does two things:
* Update `dynamic_one_shot` transform to use `broadcast_expand` and
process batched results correctly.
* Update `broadcast_expand` to support shot vectors.
* Raise error when postselecting with broadcasting and returning
samples. This change was made to both `dynamic_one_shot` and
`defer_measurements` because both transforms use `broadcast_expand` for
broadcasting, although `defer_measurements` only uses `broadcast_expand`
with postselection.

Note that broadcasting with `qml.sample` and postselection will still
not work due to ragged dimensions. If reviewers are okay with it, I
would like to merge this and leave that improvement as technical debt.
cc @trbromley @isaacdevlugt.
Edit about note: Talked offline, decided to raise a more informative
error if a user requests postselection with broadcasting when returning
samples.

**Benefits:**
Both transforms are more capable.

**Possible Drawbacks:**
Because of the stacking and squeezing happening in the post-processing
function of `broadcast_expand`, counts dictionaries get wrapped inside
0-D numpy arrays, which makes indexing into the dict impossible. To
access the dictionary and its contents, one has to use `array.item()` to
extract the single item inside the array.

**Related GitHub Issues:**
#5443
  • Loading branch information
mudit2812 authored Apr 17, 2024
1 parent a47d9bc commit 2b47a00
Show file tree
Hide file tree
Showing 13 changed files with 29,932 additions and 27,651 deletions.
57,102 changes: 29,515 additions & 27,587 deletions .github/workflows/core_tests_durations.json

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,13 @@

<h3>Improvements 🛠</h3>

* `qml.ops.Conditional` now stores the `data`, `num_params`, and `ndim_param` attributes of
the operator it wraps.
[(#5473)](https://github.com/PennyLaneAI/pennylane/pull/5473)

* `qml.transforms.broadcast_expand` now supports shot vectors when returning `qml.sample()`.
[(#5473)](https://github.com/PennyLaneAI/pennylane/pull/5473)

* `LightningVJPs` is now compatible with Lightning devices using the new device API.
[(#5469)](https://github.com/PennyLaneAI/pennylane/pull/5469)

Expand Down Expand Up @@ -368,6 +375,9 @@

<h3>Bug fixes 🐛</h3>

* The `dynamic_one_shot` transform now works with broadcasting.
[(#5473)](https://github.com/PennyLaneAI/pennylane/pull/5473)

* Diagonalize the state around `ProbabilityMP` measurements in `statistics` when executing on a Lightning device.
[(#5529)](https://github.com/PennyLaneAI/pennylane/pull/5529)

Expand Down
2 changes: 1 addition & 1 deletion pennylane/devices/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def mid_circuit_measurements(
and use the ``qml.defer_measurements`` transform otherwise.
"""

if tape.shots and tape.batch_size is None:
if tape.shots:
return qml.dynamic_one_shot(tape)
return qml.defer_measurements(tape, device=device)

Expand Down
8 changes: 8 additions & 0 deletions pennylane/ops/functions/bind_new_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,11 @@ def bind_new_parameters_tensor(op: Tensor, params: Sequence[TensorLike]):
new_obs.append(bind_new_parameters(obs, sub_params))

return Tensor(*new_obs)


@bind_new_parameters.register
def bind_new_parameters_conditional(op: qml.ops.Conditional, params: Sequence[TensorLike]):
then_op = bind_new_parameters(op.then_op, params)
mv = copy.deepcopy(op.meas_val)

return qml.ops.Conditional(mv, then_op)
10 changes: 9 additions & 1 deletion pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,19 @@ class Conditional(Operation):
def __init__(self, expr, then_op: Type[Operation], id=None):
self.meas_val = expr
self.then_op = then_op
super().__init__(wires=then_op.wires, id=id)
super().__init__(*then_op.data, wires=then_op.wires, id=id)

def label(self, decimals=None, base_label=None, cache=None):
return self.then_op.label(decimals=decimals, base_label=base_label, cache=cache)

@property
def num_params(self):
return self.then_op.num_params

@property
def ndim_params(self):
return self.then_op.ndim_params

def map_wires(self, wire_map):
meas_val = self.meas_val.map_wires(wire_map)
then_op = self.then_op.map_wires(wire_map)
Expand Down
35 changes: 29 additions & 6 deletions pennylane/transforms/broadcast_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def broadcast_expand(tape: qml.tape.QuantumTape) -> (Sequence[qml.tape.QuantumTa
>>> fn(qml.execute(tapes, qml.device("default.qubit", wires=1), None))
tensor([0.98006658, 0.82533561, 0.54030231], requires_grad=True)
"""
# pylint: disable=protected-access
if tape.batch_size is None:
output_tapes = [tape]

Expand All @@ -143,14 +142,38 @@ def null_postprocessing(results):
output_tapes.append(new_tape)

def processing_fn(results: qml.typing.ResultBatch) -> qml.typing.Result:
# The shape of the results should be as follows: results[s][m][b], where s is the shot
# vector index, m is the measurement index, and b is the batch index. The shape that
# the processing function receives is results[b][s][m].

if tape.shots.has_partitioned_shots:
if len(tape.measurements) > 1:
return tuple(
tuple(
qml.math.squeeze(
qml.math.stack([results[b][s][m] for b in range(tape.batch_size)])
)
for m in range(len(tape.measurements))
)
for s in range(tape.shots.num_copies)
)

# Only need to transpose results[b][s] -> results[s][b]
return tuple(
qml.math.squeeze(
qml.math.stack([results[b][s] for b in range(tape.batch_size)])
)
for s in range(tape.shots.num_copies)
)

if len(tape.measurements) > 1:
processed_results = [
# Only need to transpose results[b][m] -> results[m][b]
return tuple(
qml.math.squeeze(
qml.math.stack([results[b][i] for b in range(tape.batch_size)])
qml.math.stack([results[b][m] for b in range(tape.batch_size)])
)
for i in range(len(tape.measurements))
]
return tuple(processed_results)
for m in range(len(tape.measurements))
)
return qml.math.squeeze(qml.math.stack(results))

return output_tapes, processing_fn
10 changes: 10 additions & 0 deletions pennylane/transforms/defer_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ def _check_tape_validity(tape: QuantumTape):
"measurements on a device that does not support them."
)

samples_present = any(isinstance(mp, SampleMP) for mp in tape.measurements)
postselect_present = any(
op.postselect is not None for op in tape.operations if isinstance(op, MidMeasureMP)
)
if postselect_present and samples_present and tape.batch_size is not None:
raise ValueError(
"Returning qml.sample is not supported when postselecting mid-circuit "
"measurements with broadcasting"
)


def _collect_mid_measure_info(tape: QuantumTape):
"""Helper function to collect information related to mid-circuit measurements in the tape."""
Expand Down
68 changes: 55 additions & 13 deletions pennylane/transforms/dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def func(x, y):
calculations, where in each calculation the ``qml.measure`` operations dynamically
measures the 0-wire and collapse the state vector stochastically. This transforms
contrasts with ``qml.defer_measurements``, which instead introduces an extra wire
for each mid-circuit measurement. The ``qml.dynamic_one_shot`` transform is favorable in the few-shots
several-mid-circuit-measurement limit, whereas ``qml.defer_measurements`` is favorable
for each mid-circuit measurement. The ``qml.dynamic_one_shot`` transform is favorable in the
few-shots several-mid-circuit-measurement limit, whereas ``qml.defer_measurements`` is favorable
in the opposite limit.
"""

Expand All @@ -87,31 +87,72 @@ def func(x, y):
for m in tape.measurements:
if not isinstance(m, (CountsMP, ExpectationMP, ProbabilityMP, SampleMP, VarianceMP)):
raise TypeError(
f"Native mid-circuit measurement mode does not support {type(m).__name__} measurements."
f"Native mid-circuit measurement mode does not support {type(m).__name__} "
"measurements."
)

aux_tape = init_auxiliary_tape(tape)
output_tapes = [aux_tape] * tape.shots.total_shots
samples_present = any(isinstance(mp, SampleMP) for mp in tape.measurements)
postselect_present = any(
op.postselect is not None for op in tape.operations if isinstance(op, MidMeasureMP)
)
if postselect_present and samples_present and tape.batch_size is not None:
raise ValueError(
"Returning qml.sample is not supported when postselecting mid-circuit "
"measurements with broadcasting"
)

if (batch_size := tape.batch_size) is not None:
tapes, broadcast_fn = qml.transforms.broadcast_expand(tape)
else:
tapes = [tape]
broadcast_fn = None

aux_tapes = [init_auxiliary_tape(t) for t in tapes]
# Shape of output_tapes is (batch_size * total_shots,) with broadcasting,
# and (total_shots,) otherwise
output_tapes = [at for at in aux_tapes for _ in range(tape.shots.total_shots)]

def processing_fn(results, has_partitioned_shots=None, batched_results=None):
if batched_results is None and batch_size is not None:
# If broadcasting, recursively process the results for each batch. For each batch
# there are tape.shots.total_shots results. The length of the first axis of final_results
# will be batch_size.
results = list(results)
final_results = []
for _ in range(batch_size):
final_results.append(
processing_fn(results[0 : tape.shots.total_shots], batched_results=False)
)
del results[0 : tape.shots.total_shots]
return broadcast_fn(final_results)

def processing_fn(results, has_partitioned_shots=None):
if has_partitioned_shots is None and tape.shots.has_partitioned_shots:
# If using shot vectors, recursively process the results for each shot bin. The length
# of the first axis of final_results will be the length of the shot vector.
results = list(results)
final_results = []
for s in tape.shots:
final_results.append(processing_fn(results[0:s], has_partitioned_shots=False))
final_results.append(
processing_fn(results[0:s], has_partitioned_shots=False, batched_results=False)
)
del results[0:s]
return tuple(final_results)

# The following code assumes no broadcasting and no shot vectors. The above code should
# handle those cases
all_shot_meas, list_mcm_values_dict, valid_shots = None, [], 0
for res in results:
one_shot_meas, mcm_values_dict = res
if one_shot_meas is None:
continue
valid_shots += 1
all_shot_meas = accumulate_native_mcm(aux_tape, all_shot_meas, one_shot_meas)
all_shot_meas = accumulate_native_mcm(aux_tapes[0], all_shot_meas, one_shot_meas)
list_mcm_values_dict.append(mcm_values_dict)
if not valid_shots:
warnings.warn(
"All shots were thrown away as invalid. This can happen for example when post-selecting the 1-branch of a 0-state. Make sure your circuit has some probability of producing a valid shot.",
"All shots were thrown away as invalid. This can happen for example when "
"post-selecting the 1-branch of a 0-state. Make sure your circuit has some "
"probability of producing a valid shot.",
UserWarning,
)
return parse_native_mid_circuit_measurements(tape, all_shot_meas, list_mcm_values_dict)
Expand All @@ -131,12 +172,13 @@ def _dynamic_one_shot_qnode(self, qnode, targs, tkwargs):
support_mcms = hasattr(qnode.device, "capabilities") and qnode.device.capabilities().get(
"supports_mid_measure", False
)
support_mcms = support_mcms or isinstance(
qnode.device, qml.devices.default_qubit.DefaultQubit
)
support_mcms = support_mcms or qnode.device.name in ("default.qubit", "lightning.qubit")
if not support_mcms:
raise TypeError(
f"Device {qnode.device.name} does not support mid-circuit measurements natively, and hence it does not support the dynamic_one_shot transform. `default.qubit` and `lightning.qubit` currently support mid-circuit measurements and the dynamic_one_shot transform."
f"Device {qnode.device.name} does not support mid-circuit measurements "
"natively, and hence it does not support the dynamic_one_shot transform. "
"'default.qubit' and 'lightning.qubit' currently support mid-circuit "
"measurements and the dynamic_one_shot transform."
)
tkwargs.setdefault("device", qnode.device)
return self.default_qnode_transform(qnode, targs, tkwargs)
Expand Down
Loading

0 comments on commit 2b47a00

Please sign in to comment.