Make GlobalPhase not differentiable #5620

Merged

Conversation

@Tarun-Kumar07 (Contributor) commented May 2, 2024

Context:
When using the state preparation methods AmplitudeEmbedding, StatePrep, and MottonenStatePreparation together with jit and grad, the error ValueError: need at least one array to stack was raised.
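
For reference, a minimal sketch of the kind of call that triggered this (hypothetical circuit and parameter values; the exact reproducer lives in #5541):

import jax
import jax.numpy as jnp
import pennylane as qml

dev = qml.device("default.qubit", wires=2, shots=1000)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(state, weight):
    # AmplitudeEmbedding uses GlobalPhase under the hood (see Description below).
    qml.AmplitudeEmbedding(state, wires=[0, 1], normalize=True)
    qml.RX(weight, wires=0)
    return qml.expval(qml.PauliZ(0))

state = jnp.array([0.5, 0.5, 0.5, 0.5])
weight = jnp.array(0.3)

# Before this PR, combining jit and grad here raised
# "ValueError: need at least one array to stack".
jax.jit(jax.grad(circuit, argnums=1))(state, weight)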

Description of the Change:
All of these state preparation templates use GlobalPhase under the hood, which caused the error above. This PR makes GlobalPhase non-differentiable by setting its grad_method to None.
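
A quick way to see the effect of this change (sketch; assumes the class attribute is inspected directly):

import pennylane as qml

# With this PR, the phase parameter is marked as non-differentiable, so the
# parameter-shift rule will not try to generate shifted tapes for it.
print(qml.GlobalPhase.grad_method)  # expected to print None after this change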

Benefits:

Possible Drawbacks:

Related GitHub Issues:
Fixes #5541.

@albi3ro (Contributor) commented May 2, 2024

Thanks for this @Tarun-Kumar07

For the failures due to the error:

ValueError: Computing the gradient of circuits that return the state with the parameter-shift rule gradient transform is not supported, as it is a hardware-compatible method.

That would be expected, and we should shift the measurement to expectation values.
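
A sketch of the suggested test adjustment (illustrative circuit, not the actual test code):

import pennylane as qml

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, diff_method="parameter-shift")
def circuit(state):
    qml.StatePrep(state, wires=0)
    # Measure an expectation value instead of qml.state(), so that the
    # parameter-shift rule remains applicable.
    return qml.expval(qml.PauliZ(0))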

For the failures due to:

FAILED tests/templates/test_state_preparations/test_mottonen_state_prep.py::test_jacobian_with_and_without_jit_has_same_output_with_high_shots[StatePrep] - AssertionError: assert Array(False, dtype=bool)
 +  where Array(False, dtype=bool) = <function allclose at 0x7fdb77481940>(Array([-0.0003,  0.0153, -0.0153,  0.0003], dtype=float64), Array([ 1.0187, -0.9953, -1.0047,  0.9813], dtype=float64), atol=0.02)

Those are legitimately different results, so we can safely say we are getting wrong results in that case 😢 I'll investigate.

@dwierichs (Contributor) left a review

I left a couple of small comments and one major suggestion: Could we set GlobalPhase.grad_method = "F"? This will produce unnecessary shifted tapes for expectation values and probabilities, but it will avoid wrong results when differentiating qml.state with finite_diff and param_shift.
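
For context, a rough sketch of what the different grad_method values mean (hypothetical subclass, illustration only; the PR ultimately went with grad_method = None instead):

import pennylane as qml

class PhaseWithFiniteDiff(qml.GlobalPhase):  # hypothetical subclass for illustration
    # "A": analytic (parameter-shift) recipe, "F": fall back to finite
    # differences, None: treat the parameter as non-differentiable.
    grad_method = "F"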

Review comments were left on: doc/releases/changelog-dev.md, pennylane/ops/identity.py, pennylane/ops/op_math/controlled.py, and tests/templates/test_embeddings/test_amplitude.py.
@dwierichs (Contributor) commented, quoting the above:

Those are legitimately different results, so we can safely say we are getting wrong results in that case 😢 I'll investigate.

@Tarun-Kumar07 @albi3ro Not sure you got to this yet, but it seems that the decompositions of those state preparation methods handle special parameter values differently from generic ones. This makes the derivative wrong at those special values, because param_shift is handed a tape that does not contain the general decomposition, so it will not shift all operations that need shifting. Since JITting does not allow such special-casing, we only make this mistake without JITting, hence the difference in the results within those tests.

Basically, the decomposition does something like the following for RZ:

import pennylane as qml

def compute_decomposition(theta, wires):
    # Special case: skip the rotation entirely for a concrete angle of zero.
    if not qml.math.is_abstract(theta) and qml.math.isclose(theta, 0):
        return []
    return [qml.RZ(theta, wires)]

The decomposition itself is correct, but it does not yield the correct parameter-shift derivative at 0.

This looks like an independent bug to me, and one that could, in principle, be hiding elsewhere in the codebase for other ops as well.
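
A minimal sketch of the mechanism (hand-built tapes, not code from this PR): at the special value, the tape handed to param_shift simply has nothing left to shift.

import pennylane as qml

# Tape built from the general decomposition (theta != 0 branch), evaluated at 0.
general = qml.tape.QuantumScript(
    [qml.Hadamard(0), qml.RZ(0.0, wires=0)], [qml.expval(qml.PauliY(0))]
)
# Tape built from the special-cased decomposition at theta == 0 (RZ dropped).
special = qml.tape.QuantumScript([qml.Hadamard(0)], [qml.expval(qml.PauliY(0))])

shifted, _ = qml.gradients.param_shift(general)
print(len(shifted))  # 2 shifted tapes -> correct derivative d<Y>/dtheta = 1 at 0
# The special-cased tape has no trainable parameter left, so param_shift has
# nothing to shift and the derivative information at theta = 0 is lost.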

@dwierichs (Contributor) commented:
@Tarun-Kumar07 Sorry for taking so long with this! We decided to move ahead with this PR as you originally drafted it (with grad_method=None). Before this PR can be merged, though, the fix in #5774 will need to be merged first.
I took the liberty of using a test you wrote here for that PR; your work is much appreciated! 🙏

@Tarun-Kumar07 (Contributor, Author) commented:
Hey @dwierichs, once PR #5774 is merged I will revert the changes back to grad_method=None. Thanks for updating :)

dwierichs added a commit that referenced this pull request Jun 13, 2024
**Context:**
The decomposition of `MottonenStatePreparation` skips some gates for
special parameter values/input states.
See the linked issue for details.

**Description of the Change:**
This PR introduces a check for differentiability so that the gates are only skipped when no derivatives are being computed.
Note that this does *not* fix the non-differentiability at other special parameter points, which is also referenced in #5715 and which the docs already warn about.
Also, the linked issue covers multiple operations; here we only address `MottonenStatePreparation`.

**Benefits:**
Fixes parts of #5715. Unblocks #5620 .

**Possible Drawbacks:**

**Related GitHub Issues:**
#5715
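
For illustration, a sketch of the kind of differentiability guard described in that commit message (assumed form using qml.math helpers; the actual implementation in #5774 may differ):

import pennylane as qml

def _maybe_rz(theta, wire):
    # Only skip the gate when the angle is concrete (not JIT-traced),
    # not being differentiated, and numerically zero.
    skippable = (
        not qml.math.is_abstract(theta)
        and not qml.math.requires_grad(theta)
        and qml.math.isclose(theta, 0.0)
    )
    return [] if skippable else [qml.RZ(theta, wires=wire)]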
@dwierichs (Contributor) commented:
@Tarun-Kumar07 It is merged :)
I took the liberty of merging the master branch into your branch, updating the tests you wrote (which made it into #5774) accordingly.

@dwierichs (Contributor) commented:
Hi @Tarun-Kumar07 ,
Do you know when you might have time to get back to this? :) Or would you prefer us to finalize it?

@Tarun-Kumar07 (Contributor, Author) commented:
Hi @dwierichs,

I am currently tied up until August 17th. After that, I will be able to work on this.
If this is a priority, I suggest assigning someone else to take over in the meantime.

Thank you for understanding.

@dwierichs (Contributor) commented:
Thanks @Tarun-Kumar07.
That is very understandable, of course! I'll try to push your PR through, as it is close to finished anyway, I think.
Thank you for your contribution, looking forward to the next one 🤩

@dwierichs (Contributor) left a review

Approving the first part of this PR. I modified it slightly, which will need approval by someone else.

@dwierichs dwierichs marked this pull request as ready for review July 25, 2024 14:30
@mudit2812 mudit2812 self-requested a review July 25, 2024 15:36

codecov bot commented Jul 28, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.65%. Comparing base (20eed81) to head (a7d2aeb).
Report is 297 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff            @@
##           master    #5620    +/-   ##
========================================
  Coverage   99.65%   99.65%            
========================================
  Files         430      430            
  Lines       41505    41210   -295     
========================================
- Hits        41362    41069   -293     
+ Misses        143      141     -2     

☔ View full report in Codecov by Sentry.

@dwierichs dwierichs enabled auto-merge (squash) July 30, 2024 07:09
@dwierichs dwierichs merged commit 6e122ae into PennyLaneAI:master Jul 30, 2024
40 checks passed
@Alex-Preciado (Contributor) commented:

Thank you so much for this contribution @Tarun-Kumar07!! 🚀

Successfully merging this pull request may close these issues.

[BUG] jax.grad + jax.jit does not work with AmplitudeEmbedding and finite shots
6 participants