-
Notifications
You must be signed in to change notification settings - Fork 603
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Capture the adjoint transform into jaxpr #5966
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #5966 +/- ##
==========================================
- Coverage 99.66% 99.65% -0.01%
==========================================
Files 430 430
Lines 41510 41245 -265
==========================================
- Hits 41370 41104 -266
- Misses 140 141 +1 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR looks great. Thanks! I have studied it over the last few days and it was quite helpful.
I have several questions (not because I think there's something wrong, but rather for my understanding).
For example:
qml.capture.enable()
dev = qml.device("default.qubit", wires=2)
@qml.qnode(dev)
def circuit(op):
qml.RX(0.1, wires=0)
qml.RY(0.2, wires=1)
qml.adjoint(op)(0.1, wires=0)
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))
>>> circuit(qml.RX) # this raises an error
>>> circuit(op=qml.RX) # this doesn't
Why do we have this different behavior? Thanks!
Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com> Co-authored-by: Pietropaolo Frisoni <pietropaolo.frisoni@xanadu.ai>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work 👍 The big question to resolve is the return type of the operator, otherwise this is great :)
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
This behavior would be analogous to:
Doesn't work. Any positional arguments to a jitted function need to be tensorlike. This will just be a general constraint going forward. |
Ok. It's still unclear to me why this works though:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks again for the quite instructive material!
@albi3ro Do you see this conflicting with the idea to make the operators PyTrees? That always seemed strange to me for this reason, although I can see it being convenient in some instances (I think of PyTree operators as implicitly making the operator type a static argument to the function). Regarding the positional vs keywords arguments however I think we should be careful. At the user level JAX does not distinguish between the two as an expression of "dynamic" arguments and "static" arguments, both are by default fully dynamic tracer arguments. This type of positional vs keyword distinction only exists in the So the example Pietro points out seems to me like somewhere we have built in this distinction. |
Would this prevent the current pattern of being able to pass instantiated operators, like Hamiltonians, to jitted functions as positional arguments? E.g., This is a super important pattern in use in a lot of places, and I wouldn't want to introduce any differences in the UX between positional and keyword arguments of functions |
What's interesting is if this use-case is that important, then (departing from JAX and) using keyword arguments to indicate their staticness is actually a convenient way to do so. The alternative is for the user to declare them static via the |
@dime10 @josh146 @PietropaoloFrisoni So passing operators as positional arguments is perfectly fine:
They are just flattened and unflattened. So if the type/ metadata changed, we would have to recompile, but it works fine. As for the issue that @PietropaoloFrisoni brought up, I've now realized that has to do with the qnode primitive. For the primitive bind call in a qnode, we don't have any more information about how to call pennylane/pennylane/capture/capture_qnode.py Line 180 in ce1db61
Maybe we could iterate through the args and the signature, and turn things into keyword arguments if they aren't tensorlike. I can look into that if needed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly looks good to me, nice work @albi3ro , also with the explanation material!
I just had a question regarding closure variables necessarily being traced and about the order of operations in adjoint
.
Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, thanks @albi3ro! Nice work 😍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work figuring out the closure problem 💯
Ah that explains it, perfect 👌 |
**Context:** We have an ongoing experimental project to enable capture of quantum functions into a jaxpr representation. Following on from #5966 , this PR adds the ability to capture the `ctrl` transform into jaxpr. **Description of the Change:** **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** [sc-68090] --------- Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com> Co-authored-by: Pietropaolo Frisoni <pietropaolo.frisoni@xanadu.ai> Co-authored-by: David Ittah <dime10@users.noreply.github.com> Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Context:
We have an ongoing experimental project to enable capture of quantum functions into a jaxpr representation.
While we can now capture any operator or measurement process, we want to also be able to natively capture certain kinds of qfunc transforms; the first of which is
adjoint
.Description of the Change:
Adds information to
capture/explanations.md
about nested jaxpr.Adds a capture-compatible implementation of the adjoint qfunc transform.
Benefits:
Possible Drawbacks:
I've gone back and forth about where to place the definition of the adjoint primitive. Placing it in
capture
keeps it better isolated from normal pennylane development. Most developers wouldn't see the capture implementation details. But placing it incapture
would also cause circular dependency issues.We are getting closer to a place where more PL developers will need to be familiar with the basics of program capture, and so shouldn't be scared off by jax primitives. So placing it in the normal source code is good for getting everyone used to it.
Originally, I implemented this task with a
bind_nested_plxpr
transform-transform that abstracted away a lot of the details about creating primitives. But I now think that was premature optimization. While many details are similar, the different transforms have enough slight differences that we should have custom implementations for each right now. We can abstract away patterns later.Related GitHub Issues:
[sc-68084]