Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow qml.for_loop and qml.while_loop to fallback to the Python interpreter if a compiler is not available. #6014

Merged
merged 12 commits into from
Jul 24, 2024
46 changes: 46 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,51 @@
* Molecules and Hamiltonians can now be constructed for all the elements present in the periodic table.
[(#5821)](https://github.com/PennyLaneAI/pennylane/pull/5821)

* `qml.for_loop` and `qml.while_loop` now fallback to standard Python control
flow if `@qjit` is not present, allowing the same code to work with and without
`@qjit` without any rewrites.
[(#6014)](https://github.com/PennyLaneAI/pennylane/pull/6014)

```python
dev = qml.device("lightning.qubit", wires=3)

@qml.qnode(dev)
def circuit(x, n):

@qml.for_loop(0, n, 1)
def init_state(i):
qml.Hadamard(wires=i)

init_state()

@qml.for_loop(0, n, 1)
def apply_operations(i, x):
qml.RX(x, wires=i)

@qml.for_loop(i + 1, n, 1)
def inner(j):
qml.CRY(x**2, [i, j])

inner()
return jnp.sin(x)

apply_operations(x)
return qml.probs()
```

```pycon
>>> print(qml.draw(circuit)(0.5, 3))
0: ──H──RX(0.50)─╭●────────╭●──────────────────────────────────────┤ Probs
1: ──H───────────╰RY(0.25)─│──────────RX(0.48)─╭●──────────────────┤ Probs
2: ──H─────────────────────╰RY(0.25)───────────╰RY(0.23)──RX(0.46)─┤ Probs
>>> circuit(0.5, 3)
array([0.125 , 0.125 , 0.09949758, 0.15050242, 0.07594666,
0.11917543, 0.08942104, 0.21545687])
>>> qml.qjit(circuit)(0.5, 3)
Array([0.125 , 0.125 , 0.09949758, 0.15050242, 0.07594666,
0.11917543, 0.08942104, 0.21545687], dtype=float64)
```

<h4>Community contributions 🥳</h4>

* `DefaultQutritMixed` readout error has been added using parameters `readout_relaxation_probs` and
Expand Down Expand Up @@ -153,6 +198,7 @@ Lillian M. A. Frederiksen,
Pietropaolo Frisoni,
Emiliano Godinez,
Renke Huang,
Josh Izaac,
Christina Lee,
Austin Huang,
Christina Lee,
Expand Down
76 changes: 70 additions & 6 deletions pennylane/compiler/qjit_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,37 @@ def loop_rx(x):
ops_loader = compilers[active_jit]["ops"].load()
return ops_loader.while_loop(cond_fn)

raise CompileError("There is no active compiler package.") # pragma: no cover
# if there is no active compiler, simply interpret the while loop
# via the Python interpretor.
def _decorator(body_fn):
josh146 marked this conversation as resolved.
Show resolved Hide resolved
return WhileLoopCallable(cond_fn, body_fn)

return _decorator


class WhileLoopCallable: # pylint:disable=too-few-public-methods
"""Base class to represent a while loop. This class
when called with an initial state will execute the while
loop via the Python interpreter.

Args:
cond_fn (Callable): the condition function in the while loop
body_fn (Callable): the function that is executed within the while loop
"""

def __init__(self, cond_fn, body_fn):
self.cond_fn = cond_fn
self.body_fn = body_fn

def __call__(self, *init_state):
args = init_state
fn_res = args if len(args) > 1 else args[0] if len(args) == 1 else None

while self.cond_fn(*args):
fn_res = self.body_fn(*args)
args = fn_res if len(args) > 1 else (fn_res,) if len(args) == 1 else ()

return fn_res


def for_loop(lower_bound, upper_bound, step):
Expand Down Expand Up @@ -430,14 +460,10 @@ def for_loop(lower_bound, upper_bound, step, loop_fn, *args):
across iterations is handled automatically by the provided loop bounds, it must not be
returned from the function.

Raises:
CompileError: if the compiler is not installed

.. seealso:: :func:`~.while_loop`, :func:`~.qjit`

**Example**


albi3ro marked this conversation as resolved.
Show resolved Hide resolved
.. code-block:: python

dev = qml.device("lightning.qubit", wires=1)
Expand Down Expand Up @@ -468,4 +494,42 @@ def loop_rx(i, x):
ops_loader = compilers[active_jit]["ops"].load()
return ops_loader.for_loop(lower_bound, upper_bound, step)

raise CompileError("There is no active compiler package.") # pragma: no cover
# if there is no active compiler, simply interpret the for loop
# via the Python interpretor.
def _decorator(body_fn):
josh146 marked this conversation as resolved.
Show resolved Hide resolved
return ForLoopCallable(lower_bound, upper_bound, step, body_fn)

return _decorator


class ForLoopCallable: # pylint:disable=too-few-public-methods
"""Base class to represent a for loop. This class
when called with an initial state will execute the while
loop via the Python interpreter.

Args:
lower_bound (int): starting value of the iteration index
upper_bound (int): (exclusive) upper bound of the iteration index
step (int): increment applied to the iteration index at the end of each iteration
body_fn (Callable): The function called within the for loop. Note that the loop body
function must always have the iteration index as its first
argument, which can be used arbitrarily inside the loop body. As the value of the index
across iterations is handled automatically by the provided loop bounds, it must not be
returned from the function.
"""

def __init__(self, lower_bound, upper_bound, step, body_fn):
self.lower_bound = lower_bound
self.upper_bound = upper_bound
self.step = step
self.body_fn = body_fn

def __call__(self, *init_state):
args = init_state
fn_res = args if len(args) > 1 else args[0] if len(args) == 1 else None

for i in range(self.lower_bound, self.upper_bound, self.step):
fn_res = self.body_fn(i, *args)
args = fn_res if len(args) > 1 else (fn_res,) if len(args) == 1 else ()

return fn_res
69 changes: 69 additions & 0 deletions tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,24 @@ def inner(j):
assert circuit(5, 6) == 30 # 5 * 6
assert circuit(4, 7) == 28 # 4 * 7

def test_while_loop_python_fallback(self):
"""Test that qml.while_loop fallsback to
Python without qjit"""

def f(n, m):
@qml.while_loop(lambda i, _: i < n)
def outer(i, sm):
@qml.while_loop(lambda j: j < m)
def inner(j):
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
return j + 1

return i + 1, sm + inner(0)

return outer(0, 0)[1]

assert f(5, 6) == 30 # 5 * 6
assert f(4, 7) == 28 # 4 * 7

def test_dynamic_wires_for_loops(self):
"""Test for loops with iteration index-dependant wires."""
dev = qml.device("lightning.qubit", wires=6)
Expand Down Expand Up @@ -405,6 +423,57 @@ def inner(j):

assert jnp.allclose(circuit(4), jnp.eye(2**4)[0])

def test_for_loop_python_fallback(self):
"""Test that qml.for_loop fallsback to Python
interpretation if Catalyst is not available"""
dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(dev)
def circuit(x, n):

# for loop with dynamic bounds
@qml.for_loop(0, n, 1)
def loop_fn(i):
qml.Hadamard(wires=i)

# nested for loops.
# outer for loop updates x
@qml.for_loop(0, n, 1)
def loop_fn_returns(i, x):
qml.RX(x, wires=i)

# inner for loop
@qml.for_loop(i + 1, n, 1)
def inner(j):
qml.CRY(x**2, [i, j])

inner()

return jnp.sin(x)

loop_fn()
loop_fn_returns(x)

return qml.expval(qml.PauliZ(0))

x = 0.5
assert jnp.allclose(circuit(x, 2), qml.qjit(circuit)(x, 2))

res = circuit.tape.operations
expected = [
qml.Hadamard(wires=[0]),
qml.Hadamard(wires=[1]),
qml.Hadamard(wires=[2]),
qml.RX(0.5, wires=[0]),
qml.CRY(0.25, wires=[0, 1]),
qml.CRY(0.25, wires=[0, 2]),
qml.RX(0.6, wires=[1]),
qml.CRY(0.36, wires=[1, 2]),
qml.RX(0.7, wires=[2]),
]

assert [qml.equal(i, j) for i, j in zip(res, expected)]
josh146 marked this conversation as resolved.
Show resolved Hide resolved

def test_cond(self):
"""Test condition with simple true_fn"""
dev = qml.device("lightning.qubit", wires=1)
Expand Down
Loading