Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow qml.cond to fallback to the Python interpreter if a compiler is not available and there are no MCMs #6016

Merged
merged 29 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
08c5886
Allow `qml.for_loop` and `qml.while_loop` to fallback to the python i…
josh146 Jul 18, 2024
d48ba18
add tests
josh146 Jul 21, 2024
84548c2
Merge branch 'master' into josh146-patch-3
josh146 Jul 21, 2024
ea87c28
black test compiler
josh146 Jul 21, 2024
51ec3c0
Allow `qml.cond` to fallback to the Python interpreter if a compiler …
josh146 Jul 21, 2024
b8e6f9f
linting and changelog
josh146 Jul 21, 2024
93afc09
changelog
josh146 Jul 21, 2024
e6bbaf9
changelog
josh146 Jul 21, 2024
8abdf7f
merge
josh146 Jul 21, 2024
e7fbe71
linting
josh146 Jul 21, 2024
3d43fe8
coverage
josh146 Jul 21, 2024
fd03214
test
josh146 Jul 21, 2024
9f9acb7
lint
josh146 Jul 21, 2024
c2aee7d
isort
josh146 Jul 21, 2024
ba83c8d
remove device test
josh146 Jul 21, 2024
8877f38
Merge branch 'master' into josh146-patch-3
josh146 Jul 22, 2024
f31d153
Apply suggestions from code review
josh146 Jul 23, 2024
ee25322
Update pennylane/compiler/qjit_api.py
josh146 Jul 23, 2024
01428a5
Update pennylane/ops/op_math/condition.py
josh146 Jul 23, 2024
ab489ab
suggested changes
josh146 Jul 23, 2024
f833dd3
suggested changes - test
josh146 Jul 23, 2024
fcb2368
Merge branch 'master' into josh146-patch-3
josh146 Jul 23, 2024
a4b7ceb
merge main
josh146 Jul 23, 2024
2f33782
merge master
josh146 Jul 24, 2024
3daab62
merge master
josh146 Aug 1, 2024
99a1454
Incorporate the 'captured' and 'non-captured' versions of `qml.cond` …
PietropaoloFrisoni Aug 1, 2024
3e186af
Merge branch 'master' into cond-python-fallback
josh146 Aug 1, 2024
eea07ac
Merge branch 'master' into cond-python-fallback
josh146 Aug 2, 2024
a659340
suggested change
josh146 Aug 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,74 @@
* Molecules and Hamiltonians can now be constructed for all the elements present in the periodic table.
[(#5821)](https://github.com/PennyLaneAI/pennylane/pull/5821)

* `qml.for_loop` and `qml.while_loop` now fallback to standard Python control
flow if `@qjit` is not present, allowing the same code to work with and without
`@qjit` without any rewrites.
[(#6014)](https://github.com/PennyLaneAI/pennylane/pull/6014)

```python
dev = qml.device("lightning.qubit", wires=3)

@qml.qnode(dev)
def circuit(x, n):

@qml.for_loop(0, n, 1)
def init_state(i):
qml.Hadamard(wires=i)

init_state()

@qml.for_loop(0, n, 1)
def apply_operations(i, x):
qml.RX(x, wires=i)

@qml.for_loop(i + 1, n, 1)
def inner(j):
qml.CRY(x**2, [i, j])

inner()
return jnp.sin(x)

apply_operations(x)
return qml.probs()
```

```pycon
>>> print(qml.draw(circuit)(0.5, 3))
0: ──H──RX(0.50)─╭●────────╭●──────────────────────────────────────┤ Probs
1: ──H───────────╰RY(0.25)─│──────────RX(0.48)─╭●──────────────────┤ Probs
2: ──H─────────────────────╰RY(0.25)───────────╰RY(0.23)──RX(0.46)─┤ Probs
>>> circuit(0.5, 3)
array([0.125 , 0.125 , 0.09949758, 0.15050242, 0.07594666,
0.11917543, 0.08942104, 0.21545687])
>>> qml.qjit(circuit)(0.5, 3)
Array([0.125 , 0.125 , 0.09949758, 0.15050242, 0.07594666,
0.11917543, 0.08942104, 0.21545687], dtype=float64)
```

* If the conditional does not include a mid-circuit measurement, then `qml.cond`
will automatically evaluate conditionals using standard Python control flow.
[(#6016)](https://github.com/PennyLaneAI/pennylane/pull/6016)

This allows `qml.cond` to be used to represent a wider range of conditionals:
dime10 marked this conversation as resolved.
Show resolved Hide resolved

```python
dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def circuit(x):
c = qml.cond(x > 2.7, qml.RX, qml.RZ)
c(x, wires=0)
return qml.probs(wires=0)
```

```pycon
>>> print(qml.draw(circuit)(3.8))
0: ──RX(3.80)─┤ Probs
>>> print(qml.draw(circuit)(0.54))
0: ──RZ(0.54)─┤ Probs
```

<h4>Community contributions 🥳</h4>

* `DefaultQutritMixed` readout error has been added using parameters `readout_relaxation_probs` and
Expand Down Expand Up @@ -151,6 +219,7 @@ Lillian M. A. Frederiksen,
Pietropaolo Frisoni,
Emiliano Godinez,
Renke Huang,
Josh Izaac,
Christina Lee,
Austin Huang,
Christina Lee,
Expand Down
76 changes: 70 additions & 6 deletions pennylane/compiler/qjit_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,37 @@ def loop_rx(x):
ops_loader = compilers[active_jit]["ops"].load()
return ops_loader.while_loop(cond_fn)

raise CompileError("There is no active compiler package.") # pragma: no cover
# if there is no active compiler, simply interpret the while loop
# via the Python interpretor.
def _decorator(body_fn):
return WhileLoopCallable(cond_fn, body_fn)

return _decorator


class WhileLoopCallable: # pylint:disable=too-few-public-methods
"""Base class to represent a while loop. This class
when called with an initial state will execute the while
loop via the Python interpreter.

Args:
cond_fn (Callable): the condition function in the while loop
body_fn (Callable): the function that is executed within the while loop
"""

def __init__(self, cond_fn, body_fn):
self.cond_fn = cond_fn
self.body_fn = body_fn

def __call__(self, *init_state):
args = init_state
fn_res = args if len(args) > 1 else args[0] if len(args) == 1 else None

while self.cond_fn(*args):
fn_res = self.body_fn(*args)
args = fn_res if len(args) > 1 else (fn_res,) if len(args) == 1 else ()

return fn_res


def for_loop(lower_bound, upper_bound, step):
Expand Down Expand Up @@ -430,14 +460,10 @@ def for_loop(lower_bound, upper_bound, step, loop_fn, *args):
across iterations is handled automatically by the provided loop bounds, it must not be
returned from the function.

Raises:
CompileError: if the compiler is not installed

.. seealso:: :func:`~.while_loop`, :func:`~.qjit`

**Example**


.. code-block:: python

dev = qml.device("lightning.qubit", wires=1)
Expand Down Expand Up @@ -468,4 +494,42 @@ def loop_rx(i, x):
ops_loader = compilers[active_jit]["ops"].load()
return ops_loader.for_loop(lower_bound, upper_bound, step)

raise CompileError("There is no active compiler package.") # pragma: no cover
# if there is no active compiler, simply interpret the for loop
# via the Python interpretor.
def _decorator(body_fn):
return ForLoopCallable(lower_bound, upper_bound, step, body_fn)

return _decorator


class ForLoopCallable: # pylint:disable=too-few-public-methods
"""Base class to represent a for loop. This class
when called with an initial state will execute the while
loop via the Python interpreter.

Args:
lower_bound (int): starting value of the iteration index
upper_bound (int): (exclusive) upper bound of the iteration index
step (int): increment applied to the iteration index at the end of each iteration
body_fn (Callable): The function called within the for loop. Note that the loop body
function must always have the iteration index as its first
argument, which can be used arbitrarily inside the loop body. As the value of the index
across iterations is handled automatically by the provided loop bounds, it must not be
returned from the function.
"""

def __init__(self, lower_bound, upper_bound, step, body_fn):
self.lower_bound = lower_bound
self.upper_bound = upper_bound
self.step = step
self.body_fn = body_fn

def __call__(self, *init_state):
args = init_state
fn_res = args if len(args) > 1 else args[0] if len(args) == 1 else None

for i in range(self.lower_bound, self.upper_bound, self.step):
fn_res = self.body_fn(i, *args)
args = fn_res if len(args) > 1 else (fn_res,) if len(args) == 1 else ()

return fn_res
137 changes: 132 additions & 5 deletions pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from pennylane import QueuingManager
from pennylane.compiler import compiler
from pennylane.measurements import MeasurementValue
from pennylane.operation import AnyWires, Operation, Operator
from pennylane.ops.op_math.symbolicop import SymbolicOp
from pennylane.tape import make_qscript
Expand Down Expand Up @@ -100,7 +101,111 @@ def adjoint(self):
return Conditional(self.meas_val, self.base.adjoint())


def cond(condition, true_fn, false_fn=None, elifs=()):
class CondCallable: # pylint:disable=too-few-public-methods
"""Base class to represent a conditional function with boolean predicates.

Args:
condition (bool): a conditional expression
true_fn (callable): The function to apply if ``condition`` is ``True``
false_fn (callable): The function to apply if ``condition`` is ``False``
elifs (List(Tuple(bool, callable))): A list of (bool, elif_fn) clauses.
josh146 marked this conversation as resolved.
Show resolved Hide resolved

Passing ``true_fn``, ``false_fn``, and ``elifs`` on initialization
josh146 marked this conversation as resolved.
Show resolved Hide resolved
is optional; these functions can be registered post-initialization
via decorators:

.. code-block:: python

def f(x):
@qml.cond(x > 0)
def conditional(y):
return y ** 2

@conditional.else_if(x < -2)
def conditional(y):
return y

@conditional.otherwise
def conditional_false_fn(y):
return -y

return conditional(x + 1)

>>> [f(0.5), f(-3), f(-0.5)]
[2.25, -2, -0.5]
"""

def __init__(self, condition, true_fn, false_fn=None, elifs=()):
self.preds = [condition]
self.branch_fns = [true_fn]
self.otherwise_fn = false_fn

if false_fn is None:
self.otherwise_fn = lambda *args, **kwargs: None
josh146 marked this conversation as resolved.
Show resolved Hide resolved

if elifs:
elif_preds, elif_fns = list(zip(*elifs))
self.preds.extend(elif_preds)
self.branch_fns.extend(elif_fns)

def else_if(self, pred):
"""Decorator that allows else-if functions to be registered with a corresponding
boolean predicate.

Args:
pred (bool): The predicate that will determine if this branch is executed.

Returns:
callable: decorator that is applied to the else-if function
"""

def decorator(branch_fn):
self.preds.append(pred)
self.branch_fns.append(branch_fn)
return self

return decorator

def otherwise(self, otherwise_fn):
"""Decorator that registers the function to be run if all
conditional predicates (including optional) evaluates to ``False``.

Args:
otherwise_fn (callable): the function to apply if all ``self.preds`` evaluate to ``False``
"""
self.otherwise_fn = otherwise_fn
return self

@property
def false_fn(self):
"""callable: the function to apply if all ``self.preds`` evaluate to ``False``"""
return self.otherwise_fn

@property
def true_fn(self):
"""callable: the function to apply if all ``self.condition`` evaluate to ``True``"""
return self.branch_fns[0]

@property
def condition(self):
"""bool: the condition that determines if ``self.true_fn`` is applied"""
return self.preds[0]

@property
def elifs(self):
"""(List(Tuple(bool, callable))): a list of (bool, elif_fn) clauses"""
return list(zip(self.preds[1:], self.branch_fns[1:]))

def __call__(self, *args, **kwargs):
# python fallback
for pred, branch_fn in zip(self.preds, self.branch_fns):
if pred:
return branch_fn(*args, **kwargs)

return self.false_fn(*args, **kwargs) # pylint: disable=not-callable


def cond(condition, true_fn=None, false_fn=None, elifs=()):
"""Quantum-compatible if-else conditionals --- condition quantum operations
on parameters such as the results of mid-circuit qubit measurements.

Expand Down Expand Up @@ -128,14 +233,14 @@ def cond(condition, true_fn, false_fn=None, elifs=()):

Args:
condition (Union[.MeasurementValue, bool]): a conditional expression involving a mid-circuit
measurement value (see :func:`.pennylane.measure`). This can only be of type ``bool`` when
decorated by :func:`~.qjit`.
measurement value (see :func:`.pennylane.measure`).
josh146 marked this conversation as resolved.
Show resolved Hide resolved
true_fn (callable): The quantum function or PennyLane operation to
apply if ``condition`` is ``True``
false_fn (callable): The quantum function or PennyLane operation to
apply if ``condition`` is ``False``
elifs (List(Tuple(bool, callable))): A list of (bool, elif_fn) clauses. Can only
be used when decorated by :func:`~.qjit`.
be used when decorated by :func:`~.qjit` or if the condition is not
a mid-circuit measurement.

Returns:
function: A new function that applies the conditional equivalent of ``true_fn``. The returned
Expand Down Expand Up @@ -367,6 +472,10 @@ def qnode(a, x, y, z):
if active_jit := compiler.active_compiler():
available_eps = compiler.AvailableCompilers.names_entrypoints
ops_loader = available_eps[active_jit]["ops"].load()

if true_fn is None:
return ops_loader.cond(condition)

cond_func = ops_loader.cond(condition)(true_fn)

# Optional 'elif' branches
Expand All @@ -379,8 +488,26 @@ def qnode(a, x, y, z):

return cond_func

if not isinstance(condition, MeasurementValue):
# The condition is not a mid-circuit measurement.
if true_fn is None:
return lambda fn: CondCallable(condition, fn)

return CondCallable(condition, true_fn, false_fn, elifs)

if true_fn is None:
raise TypeError(
"cond missing 1 required positional argument: 'true_fn'.\n"
"Note that if the conditional includes a mid-circuit measurement, "
"qml.cond cannot be used as a decorator.\n"
"Instead, please use the form qml.cond(condition, true_fn, false_fn)."
)

if elifs:
raise ConditionalTransformError("'elif' branches are not supported in interpreted mode.")
raise ConditionalTransformError(
"'elif' branches are not supported when not using @qjit and the "
"conditional include mid-circuit measurements."
)

if callable(true_fn):
# We assume that the callable is an operation or a quantum function
Expand Down
Loading
Loading