Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Program Capture] Add pytree support to captured qml.grad and qml.jacobian #6134

Merged
merged 90 commits into from
Sep 15, 2024

Conversation

dwierichs
Copy link
Contributor

Context:
#6120 and #6127 add support to capture qml.grad and qml.jacobian in plxpr. Once captured, they dispatch to jax.grad and jax.jacobian.

Description of the Change:
This PR adds support for pytree inputs and outputs of the differentiated functions, similar to #6081.
For this, it extends the internal class FlatFn by the extra functionality to turn the wrapper into a *flat_args -> *flat_outputs function, instead of a *pytree_args -> *flat_outputs function.

Benefits:
Pytree support 🌳

Possible Drawbacks:

Related GitHub Issues:

[sc-70930]
[sc-71862]

@dwierichs dwierichs marked this pull request as ready for review September 9, 2024 19:39
Base automatically changed from capture-jacobian to master September 9, 2024 22:31
@dwierichs dwierichs added the review-ready 👌 PRs which are ready for review by someone from the core team. label Sep 10, 2024
Copy link
Contributor

@mudit2812 mudit2812 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @dwierichs . Looks great 🎉 🚀

doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
pennylane/_grad.py Outdated Show resolved Hide resolved
pennylane/capture/flatfn.py Outdated Show resolved Hide resolved
tests/capture/test_capture_diff.py Show resolved Hide resolved
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Copy link
Contributor

@albi3ro albi3ro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I missing something?

@qml.qnode(qml.device('default.qubit', wires=2))
def circuit(x, y):
    qml.RX(x,0)
    qml.RY(y['a'], 0)
    qml.RX(y['b'], 0)
    return qml.expval(qml.Z(0))

qml.grad(circuit)(jax.numpy.array(0.5), {"a":jax.numpy.array(1.0), "b": jax.numpy.array(2.0)})
TypeError: primal and tangent arguments to jax.jvp must have the same tree structure; primals have tree structure PyTreeDef((*, *, *)) whereas tangents have tree structure PyTreeDef((*, CustomNode(Zero[ShapedArray(float32[], weak_type=True)], []), CustomNode(Zero[ShapedArray(float32[], weak_type=True)], []))).

@dwierichs
Copy link
Contributor Author

dwierichs commented Sep 11, 2024

The issue seems to be that defining a primitive's JVP rule using jax.jvp again exposes the zero tracers sent through the program as tangents to the custom implementation, which breaks when looking at the PytreeDefs, because zeros are registered as CustomNode. This issue existed with a previous implementation in JAX, back in 2019. The old issue re-appears with up-to-date JAX code, even for a simple non-PennyLane example 🤔 Should I post an issue on the JAX repo for this?

It is worth noting that this issue is not introduced in this PR, but exists on master as well. It's only that we look into more complex signatures in this PR and thus found that partial differentiation (with argnum not pointing to all args) breaks.

MWE for the JAX issue:

f = lambda a, b: a + b

f_prim = jax.core.Primitive("f")

@f_prim.def_impl
def _(a, b):
    return a + b

@f_prim.def_abstract_eval
def _(a, b):
    return jax.core.ShapedArray(a.shape, a.dtype)

from jax.interpreters import ad

def custom_jvp(args, dargs):
  print("args", args)
  print("dargs", dargs)
  return jax.jvp(f, args, dargs)
    
ad.primitive_jvps[f_prim] = custom_jvp

def F(*args):
    return f_prim.bind(*args)

print("forward")
print(F(1., 2.)) # this works
print("grad wrt both args")
print(jax.grad(F, argnums=[0,1])(1., 2.)) # this works
print("grad wrt one arg")
print(jax.grad(F, argnums=[0])(1., 2.)) # error

@dwierichs
Copy link
Contributor Author

dwierichs commented Sep 11, 2024

Intercepting ad.Zero and manually producing zero-valued tangents like in this tutorial worked.

I added a test based on the bug above.

Copy link
Contributor

@mudit2812 mudit2812 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving again 😄

Copy link
Contributor

@albi3ro albi3ro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

@dwierichs dwierichs enabled auto-merge (squash) September 15, 2024 19:27
@dwierichs dwierichs merged commit d0344b0 into master Sep 15, 2024
37 checks passed
@dwierichs dwierichs deleted the capture-grad-pytrees branch September 15, 2024 19:47
mudit2812 added a commit that referenced this pull request Sep 16, 2024
…jacobian` (#6134)

**Context:**
#6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in
plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`.

**Description of the Change:**
This PR adds support for pytree inputs and outputs of the differentiated
functions, similar to #6081.
For this, it extends the internal class `FlatFn` by the extra
functionality to turn the wrapper into a `*flat_args -> *flat_outputs`
function, instead of a `*pytree_args -> *flat_outputs` function.

**Benefits:**
Pytree support 🌳 

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-70930]
[sc-71862]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
mudit2812 added a commit that referenced this pull request Sep 16, 2024
…jacobian` (#6134)

**Context:**
#6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in
plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`.

**Description of the Change:**
This PR adds support for pytree inputs and outputs of the differentiated
functions, similar to #6081.
For this, it extends the internal class `FlatFn` by the extra
functionality to turn the wrapper into a `*flat_args -> *flat_outputs`
function, instead of a `*pytree_args -> *flat_outputs` function.

**Benefits:**
Pytree support 🌳 

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-70930]
[sc-71862]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
mudit2812 added a commit that referenced this pull request Sep 18, 2024
…jacobian` (#6134)

**Context:**
#6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in
plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`.

**Description of the Change:**
This PR adds support for pytree inputs and outputs of the differentiated
functions, similar to #6081.
For this, it extends the internal class `FlatFn` by the extra
functionality to turn the wrapper into a `*flat_args -> *flat_outputs`
function, instead of a `*pytree_args -> *flat_outputs` function.

**Benefits:**
Pytree support 🌳 

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-70930]
[sc-71862]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
review-ready 👌 PRs which are ready for review by someone from the core team.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants