
[BUG] jax.grad + jax.jit does not work with AmplitudeEmbedding and finite shots #5541

Closed
1 task done
KetpuntoG opened this issue Apr 18, 2024 · 13 comments · Fixed by #5620
Labels: bug 🐛 Something isn't working

Comments

@KetpuntoG
Contributor

Expected behavior

qml.AmplitudeEmbedding should work with jax.jit, jax.grad, and finite shots.

import jax
import pennylane as qml

# with shots=None it works
dev = qml.device("default.qubit", wires=4, shots=100)

@qml.qnode(dev)
def circuit(coeffs):
    qml.AmplitudeEmbedding(coeffs, normalize=True, wires=[0, 1])
    return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

params = jax.numpy.array([0.4, 0.5, 0.1, 0.3])

jac_fn = jax.jacobian(circuit)
# without jax.jit it works
jac_fn = jax.jit(jac_fn)

jac = jac_fn(params)  # raises the ValueError shown below
print(jac)

Actual behavior

ValueError: need at least one array to stack

Additional information

The same issue occurs with qml.StatePrep and qml.MottonenStatePreparation.

Source code

No response

Tracebacks

No response

System information

Name: PennyLane
Version: 0.36.0.dev0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /usr/local/lib/python3.10/dist-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane_Lightning

Platform info:           Linux-6.1.58+-x86_64-with-glibc2.35
Python version:          3.10.12
Numpy version:           1.25.2
Scipy version:           1.11.4
Installed devices:
- default.clifford (PennyLane-0.36.0.dev0)
- default.gaussian (PennyLane-0.36.0.dev0)
- default.mixed (PennyLane-0.36.0.dev0)
- default.qubit (PennyLane-0.36.0.dev0)
- default.qubit.autograd (PennyLane-0.36.0.dev0)
- default.qubit.jax (PennyLane-0.36.0.dev0)
- default.qubit.legacy (PennyLane-0.36.0.dev0)
- default.qubit.tf (PennyLane-0.36.0.dev0)
- default.qubit.torch (PennyLane-0.36.0.dev0)
- default.qutrit (PennyLane-0.36.0.dev0)
- null.qubit (PennyLane-0.36.0.dev0)
- lightning.qubit (PennyLane_Lightning-0.35.1)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
@KetpuntoG KetpuntoG added the bug 🐛 Something isn't working label Apr 18, 2024
@albi3ro
Contributor

albi3ro commented Apr 19, 2024

The issue seems to be with taking the parameter-shift gradient of a GlobalPhase.

This gives the exact same error:

import jax
import pennylane as qml

# with shots=None it works
dev = qml.device("default.qubit", wires=4, shots=100)

@qml.qnode(dev)
def circuit(phase):
    qml.GlobalPhase(phase)
    return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

params = jax.numpy.array(0.4)

jac_fn = jax.jacobian(circuit)
# without jax.jit it works
jac_fn = jax.jit(jac_fn)

jac = jac_fn(params)  # raises the same ValueError
print(jac)

We potentially need a custom gradient recipe, or we could simply recognize that the gradient of a global phase is always zero, since a global phase cancels out of every expectation value.
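To see why that gradient vanishes, here is a minimal pure-Python sketch (not PennyLane code; all helper names are hypothetical). It multiplies every amplitude of a two-qubit state by e^{-iφ}, which is assumed here to match qml.GlobalPhase (the sign convention does not affect the result), then evaluates ⟨Z⊗Z⟩ and the two-term parameter-shift formula (f(φ+π/2) − f(φ−π/2))/2. The amplitudes reuse the issue's params purely as illustrative values.

```python
import cmath
import math

# Diagonal of Z (x) Z in the computational basis |00>, |01>, |10>, |11>.
ZZ_DIAG = [1, -1, -1, 1]


def apply_global_phase(state, phi):
    """Multiply every amplitude by exp(-i*phi), i.e. apply a global phase."""
    factor = cmath.exp(-1j * phi)
    return [factor * amp for amp in state]


def expval_zz(state):
    """<psi| Z(x)Z |psi> for the diagonal observable Z(x)Z."""
    return sum(abs(amp) ** 2 * z for amp, z in zip(state, ZZ_DIAG))


def param_shift(f, phi, shift=math.pi / 2):
    """Two-term parameter-shift formula; for a constant f it returns 0."""
    return (f(phi + shift) - f(phi - shift)) / 2


# Normalized 2-qubit state built from the issue's coefficients (illustrative only).
raw = [0.4, 0.5, 0.1, 0.3]
norm = math.sqrt(sum(x * x for x in raw))
psi = [complex(x / norm) for x in raw]


def f(phi):
    """Expectation value after applying a global phase phi."""
    return expval_zz(apply_global_phase(psi, phi))


phi = 0.4
print(f(0.0), f(phi))       # equal up to rounding: the phase cancels in |amp|^2
print(param_shift(f, phi))  # zero up to floating-point error
```

Because |e^{-iφ}·a|² = |a|² for every amplitude, the expectation value does not depend on φ at all, which is why treating GlobalPhase as having no differentiable parameters is consistent.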

@Tarun-Kumar07
Contributor

Hey @albi3ro and @KetpuntoG, I was looking at issues to contribute to. Can I tackle this one?

@albi3ro
Contributor

albi3ro commented Apr 19, 2024

Just setting GlobalPhase.grad_method = None seems to work 🤞. That would basically indicate "no differentiable parameters here". It should be a simple enough fix if you're interested.


@KetpuntoG
Contributor Author

In the operator I'm working on, I differentiate with respect to the global phase parameter. Would that stop working, @albi3ro?

@albi3ro
Contributor

albi3ro commented Apr 19, 2024

Using a controlled global phase or something?

@KetpuntoG
Contributor Author

Correct 👌
It seems strange, but it appears naturally in the formulation.

@albi3ro
Contributor

albi3ro commented Apr 19, 2024

We could potentially update Controlled.grad_method from:

    @property
    def grad_method(self):
        return self.base.grad_method

to:

    @property
    def grad_method(self):
        return "A" if self.base.name == "GlobalPhase" else self.base.grad_method
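A controlled GlobalPhase is no longer a global phase: it only multiplies the branch where the control is |1⟩, so it acts as a relative phase on the control qubit (equivalent to a phase shift by −φ, assuming the e^{-iφ} convention). The minimal pure-Python sketch below (not PennyLane code; helper names are hypothetical) assumes the control qubit starts in |+⟩. Then ⟨X⟩ on the control equals cos φ, and the two-term parameter-shift rule reproduces its derivative −sin φ exactly, which is consistent with marking the controlled case as analytically differentiable ("A").

```python
import cmath
import math


def controlled_phase_state(phi):
    """Control qubit in |+>; phase exp(-i*phi) is applied only to the |1> branch."""
    s = 1 / math.sqrt(2)
    return (complex(s), s * cmath.exp(-1j * phi))


def expval_x(state):
    """<psi|X|psi> = 2 Re(conj(a) * b) for a single-qubit state (a, b)."""
    a, b = state
    return 2 * (a.conjugate() * b).real


def f(phi):
    """<X> on the control qubit after the controlled global phase."""
    return expval_x(controlled_phase_state(phi))


def param_shift(f, phi, shift=math.pi / 2):
    """Two-term parameter-shift rule with shift pi/2."""
    return (f(phi + shift) - f(phi - shift)) / 2


phi = 0.4
print(f(phi), math.cos(phi))                # <X> = cos(phi)
print(param_shift(f, phi), -math.sin(phi))  # shift rule matches -sin(phi)
```

Unlike the uncontrolled case, the expectation value now depends on φ, so returning None from grad_method here would silently drop a nonzero gradient.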

@Tarun-Kumar07
Contributor

Hey @albi3ro and @KetpuntoG, I want to work on this. Can you please assign it to me?

@albi3ro
Contributor

albi3ro commented Apr 22, 2024

Assigned. Note that our next release is in two weeks on May 6th, so we may take it over next week to make sure we get the fix in.

@Tarun-Kumar07
Contributor

Hey @albi3ro, I am unsure how to add tests for this. The test below fails because the values are far apart:

import pytest
import pennylane as qml


@pytest.mark.jax
def test_jacobian_with_and_without_jit_has_same_output():
    import jax

    dev = qml.device("default.qubit", wires=4, shots=100)

    @qml.qnode(dev)
    def circuit(coeffs):
        qml.AmplitudeEmbedding(coeffs, normalize=True, wires=[0, 1])
        return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

    params = jax.numpy.array([0.4, 0.5, 0.1, 0.3])
    jac_fn = jax.jacobian(circuit)
    jac_jit_fn = jax.jit(jac_fn)

    # Observed: Array([ 1.67562983, -1.88046271, -0.48002058,  1.05993827], dtype=float64)
    jac = jac_fn(params)

    # Observed: Array([ 1.77878858, -2.02651428, -0.11825829,  1.04522512], dtype=float64)
    jac_jit = jac_jit_fn(params)

    assert qml.math.allclose(jac, jac_jit)

The only change I made was setting GlobalPhase.grad_method = None, as suggested in this comment.

@albi3ro
Contributor

albi3ro commented Apr 30, 2024

Sorry for not getting back to you sooner, @Tarun-Kumar07. I've been a bit busy the last few days.

This looks like a case of too few shots. When I increased it to shots=5000, the numbers started to converge.

So there are two options:

  1. Increase the number of shots and set a seed on the device to reduce flakiness, e.g. qml.device('default.qubit', shots=10000, seed=7890234).
  2. Use analytic mode (shots=None) and explicitly specify diff_method="parameter-shift".

We could potentially test both.
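The reason option 1 helps: an expectation value of a ±1-valued observable estimated from N shots has standard error sqrt((1 − μ²)/N), so at shots=100 each estimate is only good to about ±0.1, and the jitted and non-jitted runs draw independent samples. The minimal pure-Python sketch below (not PennyLane's sampler; μ = 0.2 and the helper names are illustrative assumptions, and the seed is passed to Python's random module rather than to a device) shows the error shrinking tenfold from 100 to 10000 shots and a fixed seed making the sampled estimate reproducible.

```python
import math
import random


def standard_error(mu, shots):
    """Standard error of the sample mean of a +/-1 observable with true mean mu."""
    return math.sqrt((1 - mu**2) / shots)


def sample_expval(mu, shots, seed):
    """Estimate <O> from `shots` +/-1 outcomes with P(+1) = (1 + mu) / 2."""
    rng = random.Random(seed)
    p_plus = (1 + mu) / 2
    total = sum(1 if rng.random() < p_plus else -1 for _ in range(shots))
    return total / shots


mu = 0.2
for shots in (100, 5000, 10000):
    print(shots, standard_error(mu, shots))

# Same seed gives the same samples, so a seeded test is deterministic.
print(sample_expval(mu, 10000, seed=7890234) == sample_expval(mu, 10000, seed=7890234))
```

A fixed seed removes run-to-run flakiness, but the tolerance in the assertion still has to be wide compared with the standard error; option 2 avoids sampling noise entirely.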

@Tarun-Kumar07
Contributor

Hey @albi3ro,

I tried writing tests for MottonenStatePreparation and StatePrep similar to the one in the comment above. Even with a high number of shots, or in analytic mode, these tests fail.

I have opened a draft PR, #5620. These are the only failing tests, and I'm not sure how to fix them.

@Alex-Preciado
Contributor

Fixed by #5620. Thank you so much @Tarun-Kumar07! 😎
