Capture the adjoint transform into jaxpr #5966

Merged: albi3ro merged 22 commits into master from nested-adjoint-plxpr on Jul 30, 2024

albi3ro (Contributor) commented Jul 8, 2024

Context:

We have an ongoing experimental project to enable capture of quantum functions into a jaxpr representation.

While we can now capture any operator or measurement process, we want to also be able to natively capture certain kinds of qfunc transforms; the first of which is adjoint.

Description of the Change:

Adds information to capture/explanations.md about nested jaxpr.

Adds a capture-compatible implementation of the adjoint qfunc transform.
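For readers unfamiliar with the term, here is a minimal sketch of what a nested jaxpr looks like in plain JAX. It is not taken from capture/explanations.md and does not use the adjoint primitive added in this PR: `inner` and `outer` are hypothetical names, and a jitted inner function merely stands in for a captured qfunc. The point is only that the equation for the inner call carries its own jaxpr in its parameters.

```python
import jax
import jax.numpy as jnp


def inner(x):
    # Stand-in for the function being transformed (hypothetical).
    return jnp.sin(x) * 2.0


def outer(x):
    # Calling a jitted function while tracing records one equation
    # whose parameters hold the jaxpr of `inner`.
    return jax.jit(inner)(x) + 1.0


closed = jax.make_jaxpr(outer)(0.5)

# Collect every equation parameter named "jaxpr": these are the nested programs.
nested = [eqn.params["jaxpr"] for eqn in closed.jaxpr.eqns if "jaxpr" in eqn.params]

# The nested object is expected to be a ClosedJaxpr wrapping the inner jaxpr;
# fall back to the object itself if it has no `.jaxpr` attribute.
inner_jaxpr = getattr(nested[0], "jaxpr", nested[0])
inner_prims = [eqn.primitive.name for eqn in inner_jaxpr.eqns]

print(closed)
print(inner_prims)
```

Printing `closed` shows the outer program with the inner one embedded inside a single equation, which is the general shape a captured transform would also take.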

Benefits:

Possible Drawbacks:

I've gone back and forth about where to place the definition of the adjoint primitive. Placing it in capture keeps it better isolated from normal PennyLane development, so most developers wouldn't see the capture implementation details. But placing it in capture would also cause circular dependency issues.

We are getting closer to a point where more PL developers will need to be familiar with the basics of program capture, and so shouldn't be scared off by JAX primitives. Placing it in the normal source code is therefore a good way to get everyone used to it.

Originally, I implemented this task with a bind_nested_plxpr transform-transform that abstracted away a lot of the details about creating primitives. But I now think that was premature optimization. While many details are similar, the different transforms have enough slight differences that we should have custom implementations for each right now. We can abstract away patterns later.

Related GitHub Issues:

[sc-68084]

codecov bot commented Jul 9, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.65%. Comparing base (0515647) to head (cde1d17).
Report is 297 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5966      +/-   ##
==========================================
- Coverage   99.66%   99.65%   -0.01%     
==========================================
  Files         430      430              
  Lines       41510    41245     -265     
==========================================
- Hits        41370    41104     -266     
- Misses        140      141       +1     

PietropaoloFrisoni (Contributor) left a comment

This PR looks great. Thanks! I have studied it over the last few days and it was quite helpful.

I have several questions (not because I think there's something wrong, but rather for my understanding).

For example:

import pennylane as qml

qml.capture.enable()

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit(op):
    qml.RX(0.1, wires=0)
    qml.RY(0.2, wires=1)
    qml.adjoint(op)(0.1, wires=0)
    return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

>>> circuit(qml.RX)  # this raises an error
>>> circuit(op=qml.RX)  # this doesn't

Why do we have this different behavior? Thanks!

albi3ro and others added 2 commits July 23, 2024 13:30
Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
Co-authored-by: Pietropaolo Frisoni <pietropaolo.frisoni@xanadu.ai>

dime10 (Contributor) left a comment

Nice work 👍 The big question to resolve is the return type of the operator, otherwise this is great :)

Co-authored-by: David Ittah <dime10@users.noreply.github.com>

albi3ro (Contributor Author) commented Jul 26, 2024

> This PR looks great. Thanks! I have studied it over the last few days and it was quite helpful.
>
> I have several questions (not because I think there's something wrong, but rather for my understanding).
>
> For example:
>
> qml.capture.enable()
>
> dev = qml.device("default.qubit", wires=2)
>
> @qml.qnode(dev)
> def circuit(op):
>     qml.RX(0.1, wires=0)
>     qml.RY(0.2, wires=1)
>     qml.adjoint(op)(0.1, wires=0)
>     return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))
>
> >>> circuit(qml.RX)  # this raises an error
> >>> circuit(op=qml.RX)  # this doesn't
>
> Why do we have this different behavior? Thanks!

This behavior would be analogous to:

import jax

def f(g):
    return g(jax.numpy.array(0.5))

jax.jit(f)(jax.numpy.sin)  # raises: a function is not a valid JAX argument type

Doesn't work. Any positional arguments to a jitted function need to be tensorlike. This will just be a general constraint going forward.
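As a runnable illustration of the analogy, the sketch below reproduces the failure and shows one common workaround in plain JAX: closing over the function with `functools.partial` so that it never appears as an argument of the jitted callable. This is only an illustration of JAX's behavior, not of how the qnode primitive handles its arguments.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(g):
    return g(jnp.array(0.5))


# Passing a function as an argument of a jitted callable fails, because
# jit tries to interpret every argument as an array-like value.
try:
    jax.jit(f)(jnp.sin)
    failed = False
except TypeError:
    failed = True

# Closing over the function works: it is baked into the callable,
# so the jitted callable takes no function-valued argument at all.
result = jax.jit(partial(f, jnp.sin))()

print(failed, result)
```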

PietropaoloFrisoni (Contributor)

> Any positional arguments to a jitted function need to be tensorlike. This will just be a general constraint going forward.

Ok. It's still unclear to me why this works though:

import pennylane as qml

qml.capture.enable()

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit(op):
    qml.RX(0.1, wires=0)
    qml.RY(0.2, wires=1)
    qml.adjoint(op)(0.1, wires=0)
    return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

>>> circuit(op=qml.RX)

PietropaoloFrisoni (Contributor) left a comment

Thanks again for the quite instructive material!

dime10 (Contributor) commented Jul 26, 2024

> This behavior would be analogous to:
>
> def f(g):
>     return g(jax.numpy.array(0.5))
>
> jax.jit(f)(jax.numpy.sin)
>
> Doesn't work. Any positional arguments to a jitted function need to be tensorlike. This will just be a general constraint going forward.

@albi3ro Do you see this conflicting with the idea to make the operators PyTrees? That always seemed strange to me for this reason, although I can see it being convenient in some instances (I think of PyTree operators as implicitly making the operator type a static argument to the function).

Regarding positional vs. keyword arguments, however, I think we should be careful. At the user level, JAX does not use the two to distinguish "dynamic" from "static" arguments: both are fully dynamic tracer arguments by default. This kind of positional vs. keyword distinction only exists in the bind functions of primitives (and I think some high-level API functions don't forward keyword arguments).

So the example Pietro points out suggests to me that we have built this distinction in somewhere.
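The user-level half of this point can be checked with a short sketch in plain JAX. The function names `scale` and `apply` are hypothetical. Numeric arguments are traced whether they are passed positionally or by keyword, and a function-valued argument is rejected in both cases. The sketch does not cover the `bind` side of the statement.

```python
import jax
import jax.numpy as jnp


def scale(x, factor):
    return x * factor


# Both call styles produce a jaxpr with two traced inputs.
jaxpr_pos = jax.make_jaxpr(scale)(1.0, 2.0)
jaxpr_kw = jax.make_jaxpr(scale)(1.0, factor=2.0)


def apply(g, x):
    return g(x)


def fails_under_jit(*args, **kwargs):
    """Return True if calling jitted `apply` with these arguments raises TypeError."""
    try:
        jax.jit(apply)(*args, **kwargs)
    except TypeError:
        return True
    return False


# A function argument is rejected regardless of how it is passed.
positional_fails = fails_under_jit(jnp.sin, 0.5)
keyword_fails = fails_under_jit(x=0.5, g=jnp.sin)

print(len(jaxpr_pos.jaxpr.invars), len(jaxpr_kw.jaxpr.invars))
print(positional_fails, keyword_fails)
```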

josh146 (Member) commented Jul 26, 2024

> Doesn't work. Any positional arguments to a jitted function need to be tensorlike. This will just be a general constraint going forward.

Would this prevent the current pattern of being able to pass instantiated operators, like Hamiltonians, to jitted functions as positional arguments? E.g., circuit(weights, H).

This is a super important pattern in use in a lot of places, and I wouldn't want to introduce any differences in the UX between positional and keyword arguments of functions.

dime10 (Contributor) commented Jul 26, 2024

> > Doesn't work. Any positional arguments to a jitted function need to be tensorlike. This will just be a general constraint going forward.
>
> Would this prevent the current pattern of being able to pass instantiated operators, like Hamiltonians, to jitted functions as positional arguments? E.g., circuit(weights, H).
>
> This is a super important pattern in use in a lot of places, and I wouldn't want to introduce any differences in the UX between positional and keyword arguments of functions

What's interesting is that, if this use case is that important, then departing from JAX and using keyword arguments to mark arguments as static is actually a convenient way to do so. The alternative is for the user to declare them static via the static_argnums/static_argnames parameters 🤔
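For reference, this is roughly what the `static_argnums`/`static_argnames` alternative looks like in plain JAX. The helper names (`apply`, `times_two`, `square`) are hypothetical, and plain Python functions are used as the static values since static arguments must be hashable.

```python
import jax


def apply(g, x):
    return g(x)


def times_two(x):
    return 2 * x


def square(x):
    return x * x


# Declare the function-valued argument static, by position or by name.
by_position = jax.jit(apply, static_argnums=0)
by_name = jax.jit(apply, static_argnames="g")

a = by_position(times_two, 0.5)
b = by_name(g=square, x=3.0)

print(a, b)
```

Each distinct static value triggers its own trace and compilation, which is the cost of treating an argument as part of the program rather than as data.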

albi3ro (Contributor Author) commented Jul 29, 2024

@dime10 @josh146 @PietropaoloFrisoni

So passing operators as positional arguments is perfectly fine:

def f(op):
    op + qml.X(0)

>>> jax.make_jaxpr(f)(qml.RX(0.5, 0))
{ lambda ; a:f32[]. let
    b:AbstractOperator() = RX[n_wires=1] a 0
    c:AbstractOperator() = PauliX[n_wires=1] 0
    _:AbstractOperator() = Sum[grouping_type=None id=None method=rlf] b c
  in () }

They are just flattened and unflattened. So if the type or metadata changed, we would have to recompile, but it works fine.
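The flatten/unflatten behavior described above can be illustrated with a plain JAX pytree. The sketch below is a stand-in under stated assumptions: `ToyOp` is a hypothetical class, not PennyLane's operator, with a numeric `angle` as its traced leaf and `wire` as static metadata. It shows that such an object can be passed positionally to a jitted function, and that changing only the metadata forces a retrace.

```python
import jax
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyOp:
    """Hypothetical operator-like container: `angle` is data, `wire` is metadata."""

    def __init__(self, angle, wire):
        self.angle = angle
        self.wire = wire

    def tree_flatten(self):
        # Leaves (traced by JAX) first, then hashable static metadata.
        return (self.angle,), (self.wire,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data[0])


traces = []  # records the wire seen each time the function is traced


@jax.jit
def double_angle(op):
    traces.append(op.wire)  # runs only while tracing, not on cached calls
    return ToyOp(2 * op.angle, op.wire)


out = double_angle(ToyOp(0.5, 0))   # first call: traced
double_angle(ToyOp(0.7, 0))         # same structure and metadata: cache hit
n_after_same = len(traces)
double_angle(ToyOp(0.5, 1))         # different metadata: traced again

print(out.angle, out.wire, traces)
```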

As for the issue that @PietropaoloFrisoni brought up, I've now realized that has to do with the qnode primitive.

For the primitive bind call in a qnode, we don't have any more information about how to call qnode_prim.bind other than how the user provided args and keyword args.

Maybe we could iterate through the args and the signature, and turn things into keyword arguments if they aren't tensorlike. I can look into that if needed.
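One way the idea of iterating through the args and the signature might look is sketched below. Everything here is hypothetical: `is_tensorlike` is a simplistic placeholder predicate, `split_call_args` is not an existing PennyLane function, and the sketch assumes the wrapped function only has ordinary positional-or-keyword parameters. Once a non-tensorlike positional argument is found, it and every later positional argument are turned into keyword arguments so the call stays valid.

```python
import inspect

import jax
import numpy as np


def is_tensorlike(value):
    # Placeholder predicate (assumption): scalars and arrays count as tensorlike.
    return isinstance(value, (bool, int, float, complex, np.ndarray, jax.Array))


def split_call_args(func, args, kwargs):
    """Move positional args to keywords starting at the first non-tensorlike one."""
    names = list(inspect.signature(func).parameters)
    new_args, new_kwargs = [], dict(kwargs)
    demote = False
    for name, value in zip(names, args):
        demote = demote or not is_tensorlike(value)
        if demote:
            new_kwargs[name] = value
        else:
            new_args.append(value)
    return tuple(new_args), new_kwargs


class FakeOp:
    """Stand-in for an operator class such as qml.RX."""


def circuit_a(x, op):
    pass


def circuit_b(op, x):
    pass


split_a = split_call_args(circuit_a, (0.1, FakeOp), {})
split_b = split_call_args(circuit_b, (FakeOp, 0.1), {})

print(split_a)
print(split_b)
```

Whether demoting trailing tensorlike arguments to keywords is acceptable depends on how the primitive treats keyword arguments, which this sketch does not address.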

dwierichs (Contributor) left a comment

Mostly looks good to me, nice work @albi3ro, also with the explanation material!
I just had a question regarding closure variables necessarily being traced and about the order of operations in adjoint.

albi3ro requested review from dwierichs and dime10 July 29, 2024 15:32

dwierichs (Contributor) left a comment

Looks good to me, thanks @albi3ro! Nice work 😍

albi3ro enabled auto-merge (squash) July 30, 2024 13:56
albi3ro disabled auto-merge July 30, 2024 14:18

dime10 (Contributor) left a comment

Nice work figuring out the closure problem 💯

dime10 (Contributor) commented Jul 30, 2024

> As for the issue that @PietropaoloFrisoni brought up, I've now realized that has to do with the qnode primitive.
>
> For the primitive bind call in a qnode, we don't have any more information about how to call qnode_prim.bind other than how the user provided args and keyword args.
>
> Maybe we could iterate through the args and the signature, and turn things into keyword arguments if they aren't tensorlike. I can look into that if needed.

Ah that explains it, perfect 👌

albi3ro enabled auto-merge (squash) July 30, 2024 14:23
albi3ro merged commit d74f3fd into master Jul 30, 2024
40 checks passed
albi3ro deleted the nested-adjoint-plxpr branch July 30, 2024 14:44
albi3ro added a commit that referenced this pull request Jul 30, 2024
**Context:**

We have an ongoing experimental project to enable capture of quantum
functions into a jaxpr representation.

Following on from #5966, this PR adds the ability to capture the `ctrl`
transform into jaxpr.

**Description of the Change:**

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-68090]

---------

Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
Co-authored-by: Pietropaolo Frisoni <pietropaolo.frisoni@xanadu.ai>
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
6 participants