Skip to content

Commit

Permalink
dynamic_one_shot is compatible with jax.jit (#5557)
Browse files Browse the repository at this point in the history
### Before submitting

Please complete the following checklist when submitting a PR:

- [x] All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to
the
      test directory!

- [x] All new functions and code must be clearly commented and
documented.
If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [x] Ensure that the test suite passes, by running `make test`.

- [x] Add a new entry to the `doc/releases/changelog-dev.md` file,
summarizing the
      change, and including a link back to the PR.

- [x] The PennyLane source code conforms to
      [PEP8 standards](https://www.python.org/dev/peps/pep-0008/).
We check all of our code against [Pylint](https://www.pylint.org/).
      To lint modified files, simply `pip install pylint`, and then
      run `pylint pennylane/path/to/file.py`.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:**
As a first step toward our goal to support `dynamic_one_shot` with
Catalyst, we attempt to support `jax.jit`.

**Description of the Change:**
Make all array shapes static. Perform post-selection as a filtering
operation during the post-processing step of `dynamic_one_shot`. Add
`jax.jit` tests. We also introduce the constant `fillin_value` which is
used as a placeholder for postselection mismatched samples.

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**
[sc-61173]
[sc-61315]
[sc-62094]

---------

Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Amintor Dusko <87949283+AmintorDusko@users.noreply.github.com>
  • Loading branch information
4 people authored May 9, 2024
1 parent e0d0a32 commit 9f7e8ba
Show file tree
Hide file tree
Showing 9 changed files with 277 additions and 206 deletions.
8 changes: 7 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

<h3>Improvements 🛠</h3>

<h4>Mid-circuit measurements and dynamic circuits</h4>

* The `dynamic_one_shot` transform can be compiled with `jax.jit`.
[(#5557)](https://github.com/PennyLaneAI/pennylane/pull/5557)

* When using `defer_measurements` with postselecting mid-circuit measurements, operations
that will never be active due to the postselected state are skipped in the transformed
quantum circuit. In addition, postselected controls are skipped, as they are evaluated
Expand Down Expand Up @@ -98,4 +103,5 @@ This release contains contributions from (in alphabetical order):
Pietropaolo Frisoni,
Soran Jahangiri,
Christina Lee,
David Wierichs.
Vincent Michaud-Rioux,
David Wierichs.
84 changes: 61 additions & 23 deletions pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to apply an operation to a state vector."""
# pylint: disable=unused-argument
# pylint: disable=unused-argument, too-many-arguments

from functools import singledispatch
from string import ascii_letters as alphabet
Expand All @@ -21,7 +21,7 @@

import pennylane as qml
from pennylane import math
from pennylane.measurements import MidMeasureMP, Shots
from pennylane.measurements import MidMeasureMP
from pennylane.ops import Conditional

SQRT2INV = 1 / math.sqrt(2)
Expand Down Expand Up @@ -238,14 +238,36 @@ def apply_conditional(
ndarray: output state
"""
mid_measurements = execution_kwargs.get("mid_measurements", None)

rng = execution_kwargs.get("rng", None)
prng_key = execution_kwargs.get("prng_key", None)
interface = qml.math.get_deep_interface(state)
if interface == "jax":
# pylint: disable=import-outside-toplevel
from jax.lax import cond

return cond(
op.meas_val.concretize(mid_measurements),
lambda x: apply_operation(
op.then_op,
x,
is_state_batched=is_state_batched,
debugger=debugger,
mid_measurements=mid_measurements,
rng=rng,
prng_key=prng_key,
),
lambda x: x,
state,
)
if op.meas_val.concretize(mid_measurements):
return apply_operation(
op.then_op,
state,
is_state_batched=is_state_batched,
debugger=debugger,
mid_measurements=mid_measurements,
rng=rng,
prng_key=prng_key,
)
return state

Expand Down Expand Up @@ -273,31 +295,47 @@ def apply_mid_measure(
mid_measurements = execution_kwargs.get("mid_measurements", None)
rng = execution_kwargs.get("rng", None)
prng_key = execution_kwargs.get("prng_key", None)

if is_state_batched:
raise ValueError("MidMeasureMP cannot be applied to batched states.")
if not np.allclose(np.linalg.norm(state), 1.0):
mid_measurements[op] = -1
return np.zeros_like(state)
wire = op.wires
sample = qml.devices.qubit.sampling.measure_with_samples(
[qml.sample(wires=wire)], state, Shots(1), rng=rng, prng_key=prng_key
)
sample = int(sample[0])
mid_measurements[op] = sample
if op.postselect is not None and sample != op.postselect:
mid_measurements[op] = -1
return np.zeros_like(state)
axis = wire.toarray()[0]
slices = [slice(None)] * qml.math.ndim(state)
slices[axis] = int(not sample)
state[tuple(slices)] = 0.0
state_norm = np.linalg.norm(state)
state = state / state_norm
if op.reset and sample == 1:
state = apply_operation(
qml.X(wire), state, is_state_batched=is_state_batched, debugger=debugger
)
slices[axis] = 0
prob0 = qml.math.norm(state[tuple(slices)]) ** 2
interface = qml.math.get_deep_interface(state)
if prng_key is not None:
# pylint: disable=import-outside-toplevel
from jax.random import binomial

def binomial_fn(n, p):
return binomial(prng_key, n, p).astype(int)

else:
binomial_fn = np.random.binomial if rng is None else rng.binomial
sample = binomial_fn(1, 1 - prob0)
mid_measurements[op] = sample

# Using apply_operation(qml.QubitUnitary,...) instead of apply_operation(qml.Projector([sample], wire),...)
# to select the sample branch enables jax.jit and prevents it from using Python callbacks
matrix = qml.math.array([[(sample + 1) % 2, 0.0], [0.0, (sample) % 2]], like=interface)
state = apply_operation(
qml.QubitUnitary(matrix, wire),
state,
is_state_batched=is_state_batched,
debugger=debugger,
)
state = state / qml.math.norm(state)

# Using apply_operation(qml.QubitUnitary,...) instead of apply_operation(qml.X(wire), ...)
# to reset enables jax.jit and prevents it from using Python callbacks
element = op.reset and sample == 1
matrix = qml.math.array(
[[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]], like=interface
).astype(float)
state = apply_operation(
qml.QubitUnitary(matrix, wire), state, is_state_batched=is_state_batched, debugger=debugger
)

return state


Expand Down
15 changes: 5 additions & 10 deletions pennylane/devices/qubit/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
def jax_random_split(prng_key, num: int = 2):
"""Get a new key with ``jax.random.split``."""
if prng_key is None:
return [None] * num
return (None,) * num
# pylint: disable=import-outside-toplevel
from jax.random import split

Expand Down Expand Up @@ -213,15 +213,11 @@ def measure_with_samples(
"""
# last N measurements are sampling MCMs in ``dynamic_one_shot`` execution mode
mps = measurements[0 : -len(mid_measurements)] if mid_measurements else measurements
skip_measure = any(v == -1 for v in mid_measurements.values()) if mid_measurements else False

groups, indices = _group_measurements(mps)

all_res = []
for group in groups:
if skip_measure:
all_res.extend([None] * len(group))
continue
if isinstance(group[0], ExpectationMP) and isinstance(
group[0].obs, (Hamiltonian, LinearCombination)
):
Expand Down Expand Up @@ -477,11 +473,10 @@ def sample_state(
# probabilities must be renormalized as they may not sum to one
# see https://github.com/PennyLaneAI/pennylane/issues/5444
norm = qml.math.sum(probs, axis=-1)
abs_diff = np.abs(norm - 1.0)
abs_diff = qml.math.abs(norm - 1.0)
cutoff = 1e-07

if is_state_batched:

normalize_condition = False

for s in abs_diff:
Expand All @@ -497,9 +492,9 @@ def sample_state(
# rng.choice doesn't support broadcasting
samples = np.stack([rng.choice(basis_states, shots, p=p) for p in probs])
else:

if 0 < abs_diff < cutoff:
probs /= norm
if not 0 < abs_diff < cutoff:
norm = 1.0
probs = probs / norm

samples = rng.choice(basis_states, shots, p=probs)

Expand Down
4 changes: 2 additions & 2 deletions pennylane/measurements/mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,10 @@ def __ge__(self, other):
return self._transform_bin_op(lambda a, b: a >= b, other)

def __and__(self, other):
return self._transform_bin_op(lambda a, b: a and b, other)
return self._transform_bin_op(qml.math.logical_and, other)

def __or__(self, other):
return self._transform_bin_op(lambda a, b: a or b, other)
return self._transform_bin_op(qml.math.logical_or, other)

def _apply(self, fn):
"""Apply a post computation to this measurement"""
Expand Down
Loading

0 comments on commit 9f7e8ba

Please sign in to comment.