diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 73c266f2833..90dbe86a773 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -105,6 +105,9 @@
* The ``qml.Qubitization`` template now orders the ``control`` wires first and the ``hamiltonian`` wires second, which is the expected according to other templates.
[(#6229)](https://github.com/PennyLaneAI/pennylane/pull/6229)
+* The ``qml.FABLE`` template now returns the correct value when JIT is enabled.
+ [(#6263)](https://github.com/PennyLaneAI/pennylane/pull/6263)
+
*
Contributors ✍️
This release contains contributions from (in alphabetical order):
diff --git a/pennylane/templates/subroutines/fable.py b/pennylane/templates/subroutines/fable.py
index 16f1160ddd0..d9637676738 100644
--- a/pennylane/templates/subroutines/fable.py
+++ b/pennylane/templates/subroutines/fable.py
@@ -166,17 +166,19 @@ def compute_decomposition(input_matrix, wires, tol=0): # pylint:disable=argumen
for c_wire in nots:
op_list.append(qml.CNOT(wires=[c_wire] + ancilla))
op_list.append(qml.RY(2 * theta, wires=ancilla))
+ nots = {}
nots[wire_map[control_index]] = 1
+ continue
+
+ if qml.math.abs(2 * theta) > tol:
+ for c_wire in nots:
+ op_list.append(qml.CNOT(wires=[c_wire] + ancilla))
+ op_list.append(qml.RY(2 * theta, wires=ancilla))
+ nots = {}
+ if wire_map[control_index] in nots:
+ del nots[wire_map[control_index]]
else:
- if abs(2 * theta) > tol:
- for c_wire in nots:
- op_list.append(qml.CNOT(wires=[c_wire] + ancilla))
- op_list.append(qml.RY(2 * theta, wires=ancilla))
- nots = {}
- if wire_map[control_index] in nots:
- del nots[wire_map[control_index]]
- else:
- nots[wire_map[control_index]] = 1
+ nots[wire_map[control_index]] = 1
for c_wire in nots:
op_list.append(qml.CNOT([c_wire] + ancilla))
diff --git a/tests/templates/test_subroutines/test_fable.py b/tests/templates/test_subroutines/test_fable.py
index 8649fe71748..d2ba5f2496a 100644
--- a/tests/templates/test_subroutines/test_fable.py
+++ b/tests/templates/test_subroutines/test_fable.py
@@ -235,7 +235,7 @@ def circuit_jax(input_matrix):
assert np.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001)
@pytest.mark.jax
- def test_fable_grad_jax_jit(self, input_matrix):
+ def test_fable_jax_jit(self, input_matrix):
"""Test that FABLE is differentiable when using jax."""
import jax
import jax.numpy as jnp
@@ -272,18 +272,21 @@ def test_fable_grad_jax_jit(self, input_matrix):
input_jax_negative_delta = jnp.array(input_negative_delta)
input_matrix_jax = jnp.array(input_matrix)
- @jax.jit
@qml.qnode(dev, diff_method="backprop")
def circuit_jax(input_matrix):
qml.FABLE(input_matrix, wires=range(5), tol=0)
return qml.expval(qml.PauliZ(wires=0))
- grad_fn = jax.grad(circuit_jax)
+ jitted_fn = jax.jit(circuit_jax)
+
+ grad_fn = jax.grad(jitted_fn)
gradient_numeric = (
circuit_jax(input_jax_positive_delta) - circuit_jax(input_jax_negative_delta)
) / (2 * delta)
gradient_jax = grad_fn(input_matrix_jax)
- assert np.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001)
+
+ assert qml.math.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001)
+ assert qml.math.allclose(jitted_fn(input_matrix), circuit_jax(input_matrix))
@pytest.mark.jax
def test_fable_grad_jax_jit_error(self, input_matrix):