[BUG] `jax.grad` + `jax.jit` does not work with `AmplitudeEmbedding` and finite shots #5541
Comments
The issue seems to be with taking the parameter shift of a global phase. This gives the exact same error: …
We may need a custom gradient recipe, or we could simply recognize that the gradient of a global phase is always zero.
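A quick numerical check of the "always zero" remark: a global phase multiplies the state by a scalar of modulus one, which cancels in every expectation value, so any shift-based derivative with respect to it vanishes. This is a plain NumPy sketch, not PennyLane code; the `exp(-1j * phi)` convention and the helper names are assumptions for illustration.

```python
import numpy as np

# Assumed convention (illustration only): a global phase gate is U(phi) = exp(-1j * phi) * I.
def global_phase(phi, dim):
    return np.exp(-1j * phi) * np.eye(dim)


Z = np.diag([1.0, -1.0])
ZZ = np.kron(Z, Z)  # the Z(0) @ Z(1) observable from this issue

coeffs = np.array([0.4, 0.5, 0.1, 0.3])
psi = coeffs / np.linalg.norm(coeffs)  # amplitude-embedded state with normalize=True


def expval(phi):
    """<psi| U(phi)^dagger ZZ U(phi) |psi>: the phase factors cancel."""
    state = global_phase(phi, 4) @ psi
    return float(np.real(np.vdot(state, ZZ @ state)))


def param_shift(f, phi, shift=np.pi / 2):
    """Two-term shift difference, as used for Pauli-rotation gates."""
    return (f(phi + shift) - f(phi - shift)) / 2


grad = param_shift(expval, 0.7)
print(grad)  # zero up to floating-point rounding
```

Because `expval` does not depend on `phi` at all, the result is zero for any shift value, not only for `pi / 2`.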
Hey @albi3ro and @KetpuntoG, I was looking for issues to contribute to. Can I tackle this one?
In the operator I'm working on, I differentiate with respect to the global phase parameter. Would that stop working, @albi3ro?
Are you using a controlled global phase or something similar?
Correct 👌
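Why the controlled case differs: once the phase is controlled on another qubit, it becomes a relative phase between the control's |0⟩ and |1⟩ branches, so the derivative is generally nonzero and a blanket "zero gradient" rule would be wrong for a controlled global phase. A NumPy sketch under the same assumed `exp(-1j * phi)` convention, with the control prepared in |+⟩ and the target wires omitted because they only see the identity; the helper names are hypothetical.

```python
import numpy as np

# Assumed convention: GlobalPhase(phi) = exp(-1j * phi) * I. Controlled on one qubit,
# it acts on the control as diag(1, exp(-1j * phi)), i.e. a relative phase.
def controlled_global_phase(phi):
    return np.diag([1.0, np.exp(-1j * phi)])


X = np.array([[0.0, 1.0], [1.0, 0.0]])
plus = np.array([1.0, 1.0]) / np.sqrt(2)  # control qubit prepared in |+>


def expval_x(phi):
    """<X> on the control qubit; analytically this equals cos(phi)."""
    state = controlled_global_phase(phi) @ plus
    return float(np.real(np.vdot(state, X @ state)))


def param_shift(f, phi, shift=np.pi / 2):
    """Two-term parameter-shift rule; exact here because the generator has eigenvalues 0 and 1."""
    return (f(phi + shift) - f(phi - shift)) / 2


phi = 0.7
print(expval_x(phi), np.cos(phi))
print(param_shift(expval_x, phi), -np.sin(phi))  # nonzero gradient
```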
We could potentially update … to: …
Hey @albi3ro and @KetpuntoG, I want to work on this. Could you please assign it to me?
Assigned. Note that our next release is in two weeks, on May 6th, so we may take it over next week to make sure the fix gets in.
Hey @albi3ro, I am unsure how to add tests for this. The test below fails because the two Jacobians are far apart:

```python
import pytest
import pennylane as qml


@pytest.mark.jax
def test_jacobian_with_and_without_jit_has_same_output():
    import jax

    dev = qml.device("default.qubit", wires=4, shots=100)

    @qml.qnode(dev)
    def circuit(coeffs):
        qml.AmplitudeEmbedding(coeffs, normalize=True, wires=[0, 1])
        return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

    params = jax.numpy.array([0.4, 0.5, 0.1, 0.3])
    jac_fn = jax.jacobian(circuit)
    jac_jit_fn = jax.jit(jac_fn)

    # Array([ 1.67562983, -1.88046271, -0.48002058, 1.05993827], dtype=float64)
    jac = jac_fn(params)
    # Array([ 1.77878858, -2.02651428, -0.11825829, 1.04522512], dtype=float64)
    jac_jit = jac_jit_fn(params)

    assert qml.math.allclose(jac, jac_jit)
```

I just made …
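For this particular circuit the exact (infinite-shot) Jacobian has a closed form, which gives a reference to compare shot-based results against. A NumPy sketch, assuming real coefficients and that `normalize=True` divides by the Euclidean norm: with s = (1, −1, −1, 1) the eigenvalues of Z(0) @ Z(1), f(c) = Σ s_k c_k² / Σ c_j², so ∂f/∂c_k = 2 c_k (s_k − f) / ‖c‖². Wire ordering does not matter here because s is symmetric. The function names are illustrative, not PennyLane API.

```python
import numpy as np

# Eigenvalues of Z(0) @ Z(1) on the computational basis |00>, |01>, |10>, |11>.
SIGNS = np.array([1.0, -1.0, -1.0, 1.0])


def expval_zz(coeffs):
    """Exact <Z0 Z1> for the state coeffs / ||coeffs|| (normalize=True)."""
    probs = coeffs**2 / np.sum(coeffs**2)
    return float(np.dot(SIGNS, probs))


def exact_jacobian(coeffs):
    """Closed-form gradient: df/dc_k = 2 c_k (s_k - f) / ||c||^2."""
    f = expval_zz(coeffs)
    return 2 * coeffs * (SIGNS - f) / np.sum(coeffs**2)


def finite_diff(fun, x, eps=1e-6):
    """Central finite differences, used only to cross-check the formula."""
    grad = np.zeros_like(x)
    for k in range(x.size):
        step = np.zeros_like(x)
        step[k] = eps
        grad[k] = (fun(x + step) - fun(x - step)) / (2 * eps)
    return grad


params = np.array([0.4, 0.5, 0.1, 0.3])
jac = exact_jacobian(params)
print(jac)
print(np.allclose(jac, finite_diff(expval_zz, params), atol=1e-6))
```

With these parameters the exact Jacobian is roughly [1.60, −1.92, −0.38, 1.20]. Both arrays quoted in the test scatter around these values, which is consistent with the discrepancy coming from shot noise rather than from `jit` itself.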
Sorry for not getting back to you earlier, @Tarun-Kumar07; I've been a bit busy the last few days. This looks like a case of the shot count being too low. I bumped it up to … So there are two options: …
Potentially we can just test both.
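To see why 100 shots is too few: an expectation value of a ±1-valued observable estimated from N shots has standard error √((1 − ⟨O⟩²)/N). The sketch below checks this by direct binomial sampling in NumPy (it does not call PennyLane) for the ⟨Z0 Z1⟩ value of this circuit; the seed, shot counts, and repeat count are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(seed=1234)

# Exact <Z0 Z1> for coeffs [0.4, 0.5, 0.1, 0.3] after normalization: (0.16 - 0.25 - 0.01 + 0.09) / 0.51.
true_expval = -0.01 / 0.51
p_plus = (1 + true_expval) / 2  # probability of the +1 eigenvalue


def shot_estimates(shots, repeats):
    """`repeats` independent finite-shot estimates of <Z0 Z1>, each from `shots` samples."""
    plus_counts = rng.binomial(shots, p_plus, size=repeats)
    return (2 * plus_counts - shots) / shots  # mean of the +/-1 outcomes


def empirical_std(shots, repeats=2000):
    return float(np.std(shot_estimates(shots, repeats)))


for shots in (100, 10_000):
    predicted = np.sqrt((1 - true_expval**2) / shots)
    print(f"shots={shots}: empirical std {empirical_std(shots):.4f}, predicted {predicted:.4f}")
```

At 100 shots a single expectation value already carries a standard error of about 0.1, and a parameter-shift derivative typically combines several such noisy estimates, so two runs that draw different samples can disagree by far more than a tight `allclose` tolerance.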
Hey @albi3ro, I tried writing tests for … I have opened a draft PR, #5620, and these are the only failing tests; I'm not sure how to fix them.
Fixed by #5620. Thank you so much, @Tarun-Kumar07! 😎
Expected behavior
`qml.AmplitudeEmbedding` should work with `jit`, `grad`, and finite shots.
Actual behavior
ValueError: need at least one array to stack
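The message itself is what NumPy's `np.stack` raises when it is given an empty sequence. One hypothesis, consistent with the remark in this thread about taking the parameter shift of a global phase but not verified against the PennyLane source, is that a gradient step ended up with no per-parameter results to stack. The sketch reproduces the message and shows a guard of the kind that would avoid it; `stack_or_zeros` is a hypothetical helper, not a PennyLane function.

```python
import numpy as np


def stack_or_zeros(results, fallback_shape):
    """Hypothetical guard: return zeros (a zero gradient) instead of stacking an empty list."""
    if len(results) == 0:
        return np.zeros(fallback_shape)
    return np.stack(results)


message = None
try:
    np.stack([])
except ValueError as err:
    message = str(err)
print(message)  # need at least one array to stack

print(stack_or_zeros([], fallback_shape=(2,)))
```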
Additional information
Same issue with `qml.StatePrep` and `qml.MottonenStatePreparation`.
Source code
No response
Tracebacks
No response
System information
Name: PennyLane
Version: 0.36.0.dev0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /usr/local/lib/python3.10/dist-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane_Lightning

Platform info: Linux-6.1.58+-x86_64-with-glibc2.35
Python version: 3.10.12
Numpy version: 1.25.2
Scipy version: 1.11.4
Installed devices:
- default.clifford (PennyLane-0.36.0.dev0)
- default.gaussian (PennyLane-0.36.0.dev0)
- default.mixed (PennyLane-0.36.0.dev0)
- default.qubit (PennyLane-0.36.0.dev0)
- default.qubit.autograd (PennyLane-0.36.0.dev0)
- default.qubit.jax (PennyLane-0.36.0.dev0)
- default.qubit.legacy (PennyLane-0.36.0.dev0)
- default.qubit.tf (PennyLane-0.36.0.dev0)
- default.qubit.torch (PennyLane-0.36.0.dev0)
- default.qutrit (PennyLane-0.36.0.dev0)
- null.qubit (PennyLane-0.36.0.dev0)
- lightning.qubit (PennyLane_Lightning-0.35.1)