Skip to content

Commit

Permalink
Make GlobalPhase not differentiable (#5620)
Browse files Browse the repository at this point in the history
**Context:**
When using the following state preparation methods
(`AmplitudeEmbedding`, `StatePrep`, `MottonenStatePreparation`) with
`jit` and `grad`, the error `ValueError: need at least one array to
stack` was encountered.

**Description of the Change:**
All state preparation strategies used `GlobalPhase` under the hood,
which caused the above error. After this PR, `GlobalPhase` may not be
differentiable anymore, as its `grad_method` is set to `None`.

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**
It fixes  #5541

---------

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 30, 2024
1 parent 20eed81 commit 6e122ae
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 10 deletions.
8 changes: 8 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@

<h3>Breaking changes 💔</h3>

* `GlobalPhase` is considered non-differentiable with tape transforms.
As a consequence, `qml.gradients.finite_diff` and `qml.gradients.spsa_grad` no longer
support differentiation of `GlobalPhase` with state-based outputs.
[(#5620)](https://github.com/PennyLaneAI/pennylane/pull/5620)

* The `CircuitGraph.graph` rustworkx graph now stores indices into the circuit as the node labels,
instead of the operator/ measurement itself. This allows the same operator to occur multiple times in
the circuit.
Expand Down Expand Up @@ -218,6 +223,8 @@
[(#5974)](https://github.com/PennyLaneAI/pennylane/pull/5974)

<h3>Bug fixes 🐛</h3>
* Fix `jax.grad` + `jax.jit` not working for `AmplitudeEmbedding`, `StatePrep` and `MottonenStatePreparation`.
[(#5620)](https://github.com/PennyLaneAI/pennylane/pull/5620)

* Fixed a bug in `qml.SPSAOptimizer` that ignored keyword arguments in the objective function.
[(#6027)](https://github.com/PennyLaneAI/pennylane/pull/6027)
Expand Down Expand Up @@ -246,6 +253,7 @@

This release contains contributions from (in alphabetical order):

Tarun Kumar Allamsetty,
Guillermo Alonso,
Utkarsh Azad,
Ahmed Darwish,
Expand Down
2 changes: 1 addition & 1 deletion pennylane/ops/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def circuit():
"""

grad_method = "A"
grad_method = None
num_params = 1
num_wires = AllWires
"""int: Number of wires that the operator acts on."""
Expand Down
1 change: 1 addition & 0 deletions tests/ops/qubit/test_parametric_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2015,6 +2015,7 @@ def test_globalphase_autograd_grad(self, tol, dev_name, diff_method, wires):
@qml.qnode(dev, diff_method=diff_method)
def circuit(x):
qml.Identity(wires[0])
qml.GlobalPhase(x, wires=[0, 1]) # Does not change the derivative, but tests it
qml.Hadamard(wires[1])
qml.ctrl(qml.GlobalPhase(x), control=wires[1])
qml.Hadamard(wires[1])
Expand Down
25 changes: 25 additions & 0 deletions tests/templates/test_embeddings/test_amplitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,3 +584,28 @@ def test_torch(self, tol, features, pad_with, dtype):
res2 = circuit2(features, pad_with, normalize=True)

assert qml.math.allclose(res, res2, atol=tol, rtol=0)


@pytest.mark.jax
@pytest.mark.parametrize("shots, atol", [(10000, 0.05), (None, 1e-8)])
def test_jacobian_with_and_without_jit_has_same_output(shots, atol):
"""Test that the jacobian of AmplitudeEmbedding is the same with and without jit."""

import jax

dev = qml.device("default.qubit", shots=shots, seed=7890234)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(coeffs):
qml.AmplitudeEmbedding(coeffs, normalize=True, wires=[0, 1])
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

params = jax.numpy.array([0.4, 0.5, 0.1, 0.3])
jac_fn = jax.jacobian(circuit)
jac_jit_fn = jax.jit(jac_fn)

jac = jac_fn(params)

jac_jit = jac_jit_fn(params)

assert qml.math.allclose(jac, jac_jit, atol=atol)
Original file line number Diff line number Diff line change
Expand Up @@ -430,16 +430,21 @@ def circuit(coeffs):
return qml.probs(wires=[0, 1])

circuit_fd = qml.QNode(circuit, dev, diff_method="finite-diff", h=0.05)
circuit_ps = qml.QNode(circuit, dev, diff_method="parameter-shift")
circuit_exact = qml.QNode(circuit, dev_no_shots)

params = jax.numpy.array([0.5, 0.5, 0.5, 0.5])
jac_exact_fn = jax.jacobian(circuit_exact)
jac_fn = jax.jacobian(circuit_fd)
jac_jit_fn = jax.jit(jac_fn)
jac_fd_fn = jax.jacobian(circuit_fd)
jac_fd_fn_jit = jax.jit(jac_fd_fn)
jac_ps_fn = jax.jacobian(circuit_ps)
jac_ps_fn_jit = jax.jit(jac_ps_fn)

jac_exact = jac_exact_fn(params)
jac = jac_fn(params)
jac_jit = jac_jit_fn(params)
jac_fd = jac_fd_fn(params)
jac_fd_jit = jac_fd_fn_jit(params)
jac_ps = jac_ps_fn(params)
jac_ps_jit = jac_ps_fn_jit(params)

assert qml.math.allclose(jac_exact, jac_jit, atol=atol)
assert qml.math.allclose(jac, jac_jit, atol=atol)
for compare in [jac_fd, jac_fd_jit, jac_ps, jac_ps_jit]:
assert qml.math.allclose(jac_exact, compare, atol=atol)
5 changes: 2 additions & 3 deletions tests/transforms/test_tape_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,8 @@ class NonDiffPhaseShift(qml.PhaseShift):

assert new_tape.operations[0].name == "RZ"
assert new_tape.operations[0].grad_method == "A"
assert new_tape.operations[1].name == "GlobalPhase"
assert new_tape.operations[2].name == "RY"
assert new_tape.operations[3].name == "CNOT"
assert new_tape.operations[1].name == "RY"
assert new_tape.operations[2].name == "CNOT"

def test_nontrainable_nondiff(self, mocker):
"""Test that a circuit with non-differentiable
Expand Down

0 comments on commit 6e122ae

Please sign in to comment.