diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 73c266f2833..90dbe86a773 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -105,6 +105,9 @@ * The ``qml.Qubitization`` template now orders the ``control`` wires first and the ``hamiltonian`` wires second, which is the expected according to other templates. [(#6229)](https://github.com/PennyLaneAI/pennylane/pull/6229) +* The ``qml.FABLE`` template now returns the correct value when JIT is enabled. + [(#6263)](https://github.com/PennyLaneAI/pennylane/pull/6263) + *

Contributors ✍️

This release contains contributions from (in alphabetical order): diff --git a/pennylane/templates/subroutines/fable.py b/pennylane/templates/subroutines/fable.py index 16f1160ddd0..d9637676738 100644 --- a/pennylane/templates/subroutines/fable.py +++ b/pennylane/templates/subroutines/fable.py @@ -166,17 +166,19 @@ def compute_decomposition(input_matrix, wires, tol=0): # pylint:disable=argumen for c_wire in nots: op_list.append(qml.CNOT(wires=[c_wire] + ancilla)) op_list.append(qml.RY(2 * theta, wires=ancilla)) + nots = {} nots[wire_map[control_index]] = 1 + continue + + if qml.math.abs(2 * theta) > tol: + for c_wire in nots: + op_list.append(qml.CNOT(wires=[c_wire] + ancilla)) + op_list.append(qml.RY(2 * theta, wires=ancilla)) + nots = {} + if wire_map[control_index] in nots: + del nots[wire_map[control_index]] else: - if abs(2 * theta) > tol: - for c_wire in nots: - op_list.append(qml.CNOT(wires=[c_wire] + ancilla)) - op_list.append(qml.RY(2 * theta, wires=ancilla)) - nots = {} - if wire_map[control_index] in nots: - del nots[wire_map[control_index]] - else: - nots[wire_map[control_index]] = 1 + nots[wire_map[control_index]] = 1 for c_wire in nots: op_list.append(qml.CNOT([c_wire] + ancilla)) diff --git a/tests/templates/test_subroutines/test_fable.py b/tests/templates/test_subroutines/test_fable.py index 8649fe71748..d2ba5f2496a 100644 --- a/tests/templates/test_subroutines/test_fable.py +++ b/tests/templates/test_subroutines/test_fable.py @@ -235,7 +235,7 @@ def circuit_jax(input_matrix): assert np.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001) @pytest.mark.jax - def test_fable_grad_jax_jit(self, input_matrix): + def test_fable_jax_jit(self, input_matrix): """Test that FABLE is differentiable when using jax.""" import jax import jax.numpy as jnp @@ -272,18 +272,21 @@ def test_fable_grad_jax_jit(self, input_matrix): input_jax_negative_delta = jnp.array(input_negative_delta) input_matrix_jax = jnp.array(input_matrix) - @jax.jit @qml.qnode(dev, diff_method="backprop") def circuit_jax(input_matrix): qml.FABLE(input_matrix, wires=range(5), tol=0) return qml.expval(qml.PauliZ(wires=0)) - grad_fn = jax.grad(circuit_jax) + jitted_fn = jax.jit(circuit_jax) + + grad_fn = jax.grad(jitted_fn) gradient_numeric = ( circuit_jax(input_jax_positive_delta) - circuit_jax(input_jax_negative_delta) ) / (2 * delta) gradient_jax = grad_fn(input_matrix_jax) - assert np.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001) + + assert qml.math.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001) + assert qml.math.allclose(jitted_fn(input_matrix), circuit_jax(input_matrix)) @pytest.mark.jax def test_fable_grad_jax_jit_error(self, input_matrix):