Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make GlobalPhase not differentiable #5620

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
e2f138b
Make GlobalPhase not differentiable
Tarun-Kumar07 May 2, 2024
c974440
Add entry to changelog-dev.md
Tarun-Kumar07 May 2, 2024
b377658
In Controlled handle `grad_method` differently when base is `GlobalPh…
Tarun-Kumar07 May 2, 2024
3230596
Fix core-tests
Tarun-Kumar07 May 2, 2024
10edc7a
Add seed to unittests
Tarun-Kumar07 May 2, 2024
e4db193
Update changelog-dev.md
Tarun-Kumar07 May 3, 2024
a33fde9
Set Make GlobalPhase `grad_method` to finite difference
Tarun-Kumar07 May 3, 2024
f2959ed
Revert "In Controlled handle `grad_method` differently when base is `…
Tarun-Kumar07 May 3, 2024
477d8dd
Merge branch 'master' into amplitude-embedding-jit-grad-bug
Tarun-Kumar07 May 3, 2024
af22ece
Revert "Fix core-tests"
Tarun-Kumar07 May 3, 2024
9214385
Merge branch 'master' into amplitude-embedding-jit-grad-bug
Tarun-Kumar07 May 4, 2024
6dc48c9
Merge branch 'master' into amplitude-embedding-jit-grad-bug
dwierichs May 21, 2024
56189a0
Merge branch 'master' into amplitude-embedding-jit-grad-bug
dwierichs Jun 13, 2024
9f85c03
Merge branch 'master' into amplitude-embedding-jit-grad-bug
dwierichs Jul 25, 2024
5cec5f0
changelog update
dwierichs Jul 25, 2024
5fd3336
update test, add test elsewhere
dwierichs Jul 25, 2024
e4b4bed
changelog, bug fix
dwierichs Jul 28, 2024
4164ad2
Merge branch 'master' into amplitude-embedding-jit-grad-bug
dwierichs Jul 28, 2024
8ec5a77
Update doc/releases/changelog-dev.md
dwierichs Jul 29, 2024
218c552
Merge branch 'master' into amplitude-embedding-jit-grad-bug
dwierichs Jul 29, 2024
c3bcb67
Merge branch 'master' into amplitude-embedding-jit-grad-bug
dwierichs Jul 29, 2024
a7d2aeb
Merge branch 'master' into amplitude-embedding-jit-grad-bug
dwierichs Jul 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@

<h3>Breaking changes 💔</h3>

* `GlobalPhase` is considered non-differentiable with tape transforms.
As a consequence, `qml.gradients.finite_diff` and `qml.gradients.spsa_grad` no longer
support differentiation of `GlobalPhase` with state-based outputs.
[(#5620)](https://github.com/PennyLaneAI/pennylane/pull/5620)

* The `CircuitGraph.graph` rustworkx graph now stores indices into the circuit as the node labels,
instead of the operator/ measurement itself. This allows the same operator to occur multiple times in
the circuit.
Expand Down Expand Up @@ -218,6 +223,8 @@
[(#5974)](https://github.com/PennyLaneAI/pennylane/pull/5974)

<h3>Bug fixes 🐛</h3>
* Fix `jax.grad` + `jax.jit` not working for `AmplitudeEmbedding`, `StatePrep` and `MottonenStatePreparation`.
[(#5620)](https://github.com/PennyLaneAI/pennylane/pull/5620)

* Fixed a bug in `qml.SPSAOptimizer` that ignored keyword arguments in the objective function.
[(#6027)](https://github.com/PennyLaneAI/pennylane/pull/6027)
Expand Down Expand Up @@ -246,6 +253,7 @@

This release contains contributions from (in alphabetical order):

Tarun Kumar Allamsetty,
Guillermo Alonso,
Utkarsh Azad,
Ahmed Darwish,
Expand Down
2 changes: 1 addition & 1 deletion pennylane/ops/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def circuit():

"""

grad_method = "A"
grad_method = None
Tarun-Kumar07 marked this conversation as resolved.
Show resolved Hide resolved
num_params = 1
num_wires = AllWires
"""int: Number of wires that the operator acts on."""
Expand Down
1 change: 1 addition & 0 deletions tests/ops/qubit/test_parametric_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2015,6 +2015,7 @@ def test_globalphase_autograd_grad(self, tol, dev_name, diff_method, wires):
@qml.qnode(dev, diff_method=diff_method)
def circuit(x):
qml.Identity(wires[0])
qml.GlobalPhase(x, wires=[0, 1]) # Does not change the derivative, but tests it
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
qml.Hadamard(wires[1])
qml.ctrl(qml.GlobalPhase(x), control=wires[1])
qml.Hadamard(wires[1])
Expand Down
25 changes: 25 additions & 0 deletions tests/templates/test_embeddings/test_amplitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,3 +584,28 @@ def test_torch(self, tol, features, pad_with, dtype):
res2 = circuit2(features, pad_with, normalize=True)

assert qml.math.allclose(res, res2, atol=tol, rtol=0)


@pytest.mark.jax
@pytest.mark.parametrize("shots, atol", [(10000, 0.05), (None, 1e-8)])
def test_jacobian_with_and_without_jit_has_same_output(shots, atol):
"""Test that the jacobian of AmplitudeEmbedding is the same with and without jit."""

import jax

dev = qml.device("default.qubit", shots=shots, seed=7890234)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(coeffs):
qml.AmplitudeEmbedding(coeffs, normalize=True, wires=[0, 1])
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

params = jax.numpy.array([0.4, 0.5, 0.1, 0.3])
jac_fn = jax.jacobian(circuit)
jac_jit_fn = jax.jit(jac_fn)

jac = jac_fn(params)

jac_jit = jac_jit_fn(params)

assert qml.math.allclose(jac, jac_jit, atol=atol)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -430,16 +430,21 @@ def circuit(coeffs):
return qml.probs(wires=[0, 1])

circuit_fd = qml.QNode(circuit, dev, diff_method="finite-diff", h=0.05)
circuit_ps = qml.QNode(circuit, dev, diff_method="parameter-shift")
circuit_exact = qml.QNode(circuit, dev_no_shots)

params = jax.numpy.array([0.5, 0.5, 0.5, 0.5])
jac_exact_fn = jax.jacobian(circuit_exact)
jac_fn = jax.jacobian(circuit_fd)
jac_jit_fn = jax.jit(jac_fn)
jac_fd_fn = jax.jacobian(circuit_fd)
jac_fd_fn_jit = jax.jit(jac_fd_fn)
jac_ps_fn = jax.jacobian(circuit_ps)
jac_ps_fn_jit = jax.jit(jac_ps_fn)

jac_exact = jac_exact_fn(params)
jac = jac_fn(params)
jac_jit = jac_jit_fn(params)
jac_fd = jac_fd_fn(params)
jac_fd_jit = jac_fd_fn_jit(params)
jac_ps = jac_ps_fn(params)
jac_ps_jit = jac_ps_fn_jit(params)

assert qml.math.allclose(jac_exact, jac_jit, atol=atol)
assert qml.math.allclose(jac, jac_jit, atol=atol)
for compare in [jac_fd, jac_fd_jit, jac_ps, jac_ps_jit]:
assert qml.math.allclose(jac_exact, compare, atol=atol)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
5 changes: 2 additions & 3 deletions tests/transforms/test_tape_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,8 @@ class NonDiffPhaseShift(qml.PhaseShift):

assert new_tape.operations[0].name == "RZ"
assert new_tape.operations[0].grad_method == "A"
assert new_tape.operations[1].name == "GlobalPhase"
assert new_tape.operations[2].name == "RY"
assert new_tape.operations[3].name == "CNOT"
assert new_tape.operations[1].name == "RY"
assert new_tape.operations[2].name == "CNOT"

def test_nontrainable_nondiff(self, mocker):
"""Test that a circuit with non-differentiable
Expand Down
Loading