MCM - tree-traversal implementation of native MCM execution (#5180)
### Before submitting

Please complete the following checklist when submitting a PR:

- [x] All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to
the
      test directory!

- [x] All new functions and code must be clearly commented and
documented.
If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [x] Ensure that the test suite passes, by running `make test`.

- [x] Add a new entry to the `doc/releases/changelog-dev.md` file,
summarizing the
      change, and including a link back to the PR.

- [x] The PennyLane source code conforms to
      [PEP8 standards](https://www.python.org/dev/peps/pep-0008/).
We check all of our code against [Pylint](https://www.pylint.org/).
      To lint modified files, simply `pip install pylint`, and then
      run `pylint pennylane/path/to/file.py`.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:**
Native MCM execution is slow because executing `n_shots` separate tapes is
largely redundant and incurs substantial overhead.

**Description of the Change:**
Introduce `simulate_tree_mcm` and make it the default execution mode
when using finite shots & MCMs. `dynamic_one_shot` can still be applied
explicitly as a transform. `simulate_tree_mcm` implements a
"high-memory" depth-first tree-traversal algorithm. It is deemed
high-memory because a copy of the state vector is made at every node in
the tree. Since this is a depth-first traversal, it incurs a memory cost
proportional to `(n_mcm + 1) * 2 ** n_qubit` to store the state vectors at
any moment.
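
To make the memory estimate concrete, here is a rough sketch comparing it with deferred measurements, which need one extra wire per MCM. The helper names are hypothetical and the 16-byte amplitude size (complex128) is an assumption, not something stated in this PR.

```python
def statevector_bytes(n_qubits, itemsize=16):
    """Bytes for one dense state vector (assumes complex128 amplitudes)."""
    return itemsize * 2**n_qubits


def tree_traversal_bytes(n_qubits, n_mcm):
    """Upper bound quoted above: (n_mcm + 1) copies of a 2**n_qubits state vector."""
    return (n_mcm + 1) * statevector_bytes(n_qubits)


def deferred_bytes(n_qubits, n_mcm):
    """Deferred measurements: one state vector on n_qubits + n_mcm wires."""
    return statevector_bytes(n_qubits + n_mcm)


if __name__ == "__main__":
    n_qubits, n_mcm = 20, 10
    print("tree-traversal:", tree_traversal_bytes(n_qubits, n_mcm), "bytes")
    print("deferred:      ", deferred_bytes(n_qubits, n_mcm), "bytes")
```

Under these assumptions, 20 qubits and 10 MCMs give 176 MiB for tree traversal against 16 GiB for deferred measurements.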

**Benefits:**
Much faster execution in almost any case. Opens avenues for improvement
and other features, for example low-memory depth-first tree-traversal,
high-prob-first traversal, quantum noise simulations.

Here are some benchmarks to illustrate the gains. The following
synthetic workload shows that for small circuits with few MCMs,
deferred measurements are fastest. The tree-traversal approach is slower
than deferred measurements, but much faster than the one-shot
implementation.

![synthetic_time_vs_shots](https://github.com/PennyLaneAI/pennylane/assets/8711156/ad32a445-68b7-4823-b1ff-a14d87c020bf)

A more meaningful example is to run iterative QPE for 10 iterations with
a varying number of shots. The one-shot implementation is again
sluggish. The tree-traversal implementation does better, but appears to
scale worse than deferred measurements, which is again the fastest.

![iterqpe_time_vs_shots](https://github.com/PennyLaneAI/pennylane/assets/8711156/8cf92fe5-84dd-43fe-9ac1-70c990080910)

The picture changes when running iterative QPE with 1e6 samples for
varying iterations. We do not perform one-shot benchmarks since it is
too slow. The tree-traversal implementation is indeed much slower in the
10-20 iteration range, but starts winning over deferred measurements
beyond that. It thus appears to have a larger prefactor which is
eventually compensated by slightly better scaling.

![iterqpe_time_vs_iters](https://github.com/PennyLaneAI/pennylane/assets/8711156/836fb41f-492b-4c59-ba00-4d23a191e111)

Finally, we perform few-shots calculations to illustrate regimes where
one-shot could be useful. There is indeed an observable cross-over
between one-shot and deferred-measurements. The tree traversal
implementation however is usually faster even with few shots because it
then has a limited number of branches to explore before running out of
shots.
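
The few-shot behaviour can be illustrated with a toy model. This is not the `simulate_tree_mcm` implementation: no quantum state is simulated, each MCM is assumed to return 1 with a fixed probability `p_one`, and the function name is hypothetical. It only counts how many tree nodes a depth-first traversal visits when branches that receive zero shots are skipped.

```python
import random


def count_visited_nodes(n_mcm, shots, p_one=0.5, seed=0):
    """Count nodes visited by a depth-first traversal that prunes empty branches."""
    rng = random.Random(seed)

    def visit(depth, shots_here):
        if shots_here == 0:
            return 0  # no shots reach this branch: skip the whole sub-tree
        if depth == n_mcm:
            return 1  # leaf: terminal measurements would be sampled here
        # Split the shots at this MCM between the 1-branch and the 0-branch.
        ones = sum(rng.random() < p_one for _ in range(shots_here))
        zeros = shots_here - ones
        return 1 + visit(depth + 1, zeros) + visit(depth + 1, ones)

    return visit(0, shots)


if __name__ == "__main__":
    for shots in (1, 10, 100, 1000):
        print(shots, count_visited_nodes(n_mcm=10, shots=shots))
```

With a single shot the traversal follows exactly one root-to-leaf path (`n_mcm + 1` nodes), and the number of visited nodes can never exceed `shots * (n_mcm + 1)` or the full tree size `2 ** (n_mcm + 1) - 1`.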

![iterqpe_time_vs_iters_all](https://github.com/PennyLaneAI/pennylane/assets/8711156/2bf338a7-e79c-46fc-929e-9e1f326e6915)


**Possible Drawbacks:**
Some features not tested yet:
- `jax.jit`
- Catalyst `qjit`

**Related GitHub Issues:**
Mid circuit Measurements tree traversal implementation [sc-56035]
[sc-65242]

---------

Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
5 people committed Jun 17, 2024
1 parent d616c3e commit 2da3b69
Showing 23 changed files with 1,324 additions and 225 deletions.
794 changes: 688 additions & 106 deletions .github/workflows/core_tests_durations.json

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions .github/workflows/format.yml
@@ -29,8 +29,8 @@ jobs:
- name: Run isort
run: |
isort --py 311 --profile black -l 100 -p ./pennylane --skip __init__.py --filter-files ./pennylane --check
isort --py 311 --profile black -l 100 -p ./pennylane --skip __init__.py --filter-files ./tests --check
isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./pennylane --check
isort --py 311 --profile black -l 100 -o autoray -p ./pennylane --skip __init__.py --filter-files ./tests --check
- name: Run Pylint (source files)
if: always()
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
@@ -22,6 +22,8 @@ repos:
"black",
"-l",
"100",
"-o",
"autoray",
"-p",
"./pennylane",
"--skip",
61 changes: 58 additions & 3 deletions doc/introduction/measurements.rst
@@ -333,7 +333,7 @@ analytic calculations.

The :func:`~.pennylane.dynamic_one_shot` transform is usually advantageous compared
with the :func:`~.pennylane.defer_measurements` transform in the
large-number-of-mid-circuit-measurements and small-number-of-shots limit. This is because, unlike the
large-number-of-mid-circuit-measurements and small-number-of-shots limit. This is because, unlike the
deferred measurement principle, the method does not need an additional wire for every
mid-circuit measurement present in the circuit. Otherwise, one generally gets
equivalent results, so you may try both in an attempt to improve performance without
@@ -354,6 +354,57 @@ The transform can be applied to a QNode as follows:
If the ``defer_measurements`` transform is used in analytic mode, ``backprop`` is also a viable
option.

.. _tree_traversal:

The tree-traversal algorithm
****************************

Dynamic circuit execution is akin to traversing a binary tree where each MCM
corresponds to a node and groups of gates between the MCMs correspond to edges.
The :func:`~.pennylane.dynamic_one_shot` approach picks a branch of the tree randomly
and simulates it from beginning to end.
This is wasteful in many cases; the same branch is simulated many times
when there are more shots than branches for example.
The tree-traversal algorithm does away with such redundancy while retaining the
exponential gains in memory of the one-shot approach compared with the deferred
measurement principle, among other advantages.

Briefly, it proceeds by cutting a circuit with :math:`n_{MCM}` mid-circuit measurements
into :math:`n_{MCM}+1` circuit segments. Each segment can be executed on either the 0- or 1-branch,
which gives rise to a binary tree with :math:`2^{n_{MCM}}` leaves. Terminal
measurements are obtained at the leaves, then propagated and combined at each
node back up the tree. The tree is visited using a depth-first pattern. The tree-traversal
method improves on :func:`~.pennylane.dynamic_one_shot` by taking all samples at a
node or leaf at once. Neglecting overheads, simulating all branches requires the same
amount of computation as :func:`~.pennylane.defer_measurements`, but without the
:math:`O(2^{n_{MCM}})` memory requirement. To save time, a copy of the state vector
is made at every branching point, or MCM, requiring at most :math:`n_{MCM}+1` state
vectors at any instant, an exponential improvement compared with
:func:`~.pennylane.defer_measurements`. Since the counts of many nodes come out to be zero in practice,
it is often possible to ignore entire sub-trees, thereby reducing the computational
burden.

To summarize, this algorithm gives us the best of both worlds. In the limit of few
shots and/or many mid-circuit measurements, it is as fast as the naive shot-by-shot implementation
because few sub-trees are explored. In the limit of many shots and/or few mid-circuit measurements, it is
equal to or faster than the deferred measurement algorithm (albeit with more
overheads in practice) because each tree edge is visited at most once, all while
reducing the memory requirements exponentially.

The tree-traversal algorithm is not a transform. Its usage is therefore specified
by passing an ``mcm_method`` option to a QNode (see section
:ref:`"Configuring mid-circuit measurements" <mcm_config>`). For example,

.. code-block:: python

    @qml.qnode(dev, mcm_method="tree-traversal")
    def my_quantum_function(x, y):
        (...)

.. warning::

    The tree-traversal algorithm is only implemented in the :class:`~.pennylane.devices.DefaultQubit` device.

Resetting wires
***************

@@ -528,6 +579,8 @@ Collecting statistics for sequences of mid-circuit measurements is supported wit
When collecting statistics for a list of mid-circuit measurements, values manipulated using
arithmetic operators should not be used as this behaviour is not supported.

.. _mcm_config:

Configuring mid-circuit measurements
************************************

@@ -536,8 +589,10 @@ PennyLane. For ease of use, we provide the following configuration options to us
:class:`~pennylane.QNode`:

* ``mcm_method``: To set the method used for applying mid-circuit measurements. Use ``mcm_method="deferred"``
to use the :ref:`deferred measurements principle <deferred_measurements>` or ``mcm_method="one-shot"`` to use
the :ref:`one-shot transform <one_shot_transform>`. When executing with finite shots, ``mcm_method="one-shot"``
to apply the :ref:`deferred measurements principle <deferred_measurements>`, ``mcm_method="one-shot"`` to apply
the :ref:`one-shot transform <one_shot_transform>` or ``mcm_method="tree-traversal"`` to execute the
:ref:`tree-traversal algorithm <tree_traversal>`.
When executing with finite shots, ``mcm_method="one-shot"``
will be the default, and ``mcm_method="deferred"`` otherwise. Additionally, if using :func:`~pennylane.qjit`,
``mcm_method="single-branch-statistics"`` can also be used and will be the default. Using this method, a single
branch of the execution tree will be randomly explored.
6 changes: 6 additions & 0 deletions doc/releases/changelog-dev.md
@@ -115,6 +115,12 @@

<h4>Mid-circuit measurements and dynamic circuits</h4>

* The `default.qubit` device implements a depth-first tree-traversal algorithm to
accelerate native mid-circuit measurement execution. The new implementation
supports classical control, collecting statistics, and post-selection, along
with all measurements enabled with `qml.dynamic_one_shot`.
[(#5180)](https://github.com/PennyLaneAI/pennylane/pull/5180)

* `qml.QNode` and `qml.qnode` now accept two new keyword arguments: `postselect_mode` and `mcm_method`.
These keyword arguments can be used to configure how the device should behave when running circuits with
mid-circuit measurements.
5 changes: 4 additions & 1 deletion pennylane/devices/default_qubit.py
@@ -519,7 +519,8 @@ def preprocess(
         transform_program.add_transform(
             validate_observables, stopping_condition=observable_stopping_condition, name=self.name
         )
-
+        if config.mcm_config.mcm_method == "tree-traversal":
+            transform_program.add_transform(qml.transforms.broadcast_expand)
         # Validate multi processing
         max_workers = config.device_options.get("max_workers", self._max_workers)
         if max_workers:
@@ -602,6 +603,7 @@ def execute(
                         "interface": interface,
                         "state_cache": self._state_cache,
                         "prng_key": _key,
+                        "mcm_method": execution_config.mcm_config.mcm_method,
                         "postselect_mode": execution_config.mcm_config.postselect_mode,
                     },
                 )
@@ -614,6 +616,7 @@ def execute(
                 {
                     "rng": _rng,
                     "prng_key": _key,
+                    "mcm_method": execution_config.mcm_config.mcm_method,
                     "postselect_mode": execution_config.mcm_config.postselect_mode,
                 }
                 for _rng, _key in zip(seeds, prng_keys)
8 changes: 7 additions & 1 deletion pennylane/devices/execution_config.py
@@ -42,7 +42,13 @@ def __post_init__(self):
         Note that this hook is automatically called after init via the dataclass integration.
         """
-        if self.mcm_method not in ("deferred", "one-shot", "single-branch-statistics", None):
+        if self.mcm_method not in (
+            "deferred",
+            "one-shot",
+            "single-branch-statistics",
+            "tree-traversal",
+            None,
+        ):
             raise ValueError(f"Invalid mid-circuit measurements method '{self.mcm_method}'.")
         if self.postselect_mode not in ("hw-like", "fill-shots", None):
             raise ValueError(f"Invalid postselection mode '{self.postselect_mode}'.")
2 changes: 2 additions & 0 deletions pennylane/devices/preprocess.py
@@ -166,6 +166,8 @@ def mid_circuit_measurements(

     if mcm_method == "one-shot":
         return qml.dynamic_one_shot(tape, interface=interface)
+    if mcm_method == "tree-traversal":
+        return (tape,), null_postprocessing
     return qml.defer_measurements(tape, device=device)

