-
Notifications
You must be signed in to change notification settings - Fork 586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Program Capture] Add pytree support to captured qml.grad
and qml.jacobian
#6134
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @dwierichs . Looks great 🎉 🚀
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Am I missing something?
@qml.qnode(qml.device('default.qubit', wires=2))
def circuit(x, y):
qml.RX(x,0)
qml.RY(y['a'], 0)
qml.RX(y['b'], 0)
return qml.expval(qml.Z(0))
qml.grad(circuit)(jax.numpy.array(0.5), {"a":jax.numpy.array(1.0), "b": jax.numpy.array(2.0)})
TypeError: primal and tangent arguments to jax.jvp must have the same tree structure; primals have tree structure PyTreeDef((*, *, *)) whereas tangents have tree structure PyTreeDef((*, CustomNode(Zero[ShapedArray(float32[], weak_type=True)], []), CustomNode(Zero[ShapedArray(float32[], weak_type=True)], []))).
The issue seems to be that defining a primitive's JVP rule using It is worth noting that this issue is not introduced in this PR, but exists on master as well. It's only that we look into more complex signatures in this PR and thus found that partial differentiation (with MWE for the JAX issue: f = lambda a, b: a + b
f_prim = jax.core.Primitive("f")
@f_prim.def_impl
def _(a, b):
return a + b
@f_prim.def_abstract_eval
def _(a, b):
return jax.core.ShapedArray(a.shape, a.dtype)
from jax.interpreters import ad
def custom_jvp(args, dargs):
print("args", args)
print("dargs", dargs)
return jax.jvp(f, args, dargs)
ad.primitive_jvps[f_prim] = custom_jvp
def F(*args):
return f_prim.bind(*args)
print("forward")
print(F(1., 2.)) # this works
print("grad wrt both args")
print(jax.grad(F, argnums=[0,1])(1., 2.)) # this works
print("grad wrt one arg")
print(jax.grad(F, argnums=[0])(1., 2.)) # error |
Intercepting I added a test based on the bug above. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approving again 😄
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
…jacobian` (#6134) **Context:** #6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`. **Description of the Change:** This PR adds support for pytree inputs and outputs of the differentiated functions, similar to #6081. For this, it extends the internal class `FlatFn` by the extra functionality to turn the wrapper into a `*flat_args -> *flat_outputs` function, instead of a `*pytree_args -> *flat_outputs` function. **Benefits:** Pytree support 🌳 **Possible Drawbacks:** **Related GitHub Issues:** [sc-70930] [sc-71862] --------- Co-authored-by: Christina Lee <christina@xanadu.ai> Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
…jacobian` (#6134) **Context:** #6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`. **Description of the Change:** This PR adds support for pytree inputs and outputs of the differentiated functions, similar to #6081. For this, it extends the internal class `FlatFn` by the extra functionality to turn the wrapper into a `*flat_args -> *flat_outputs` function, instead of a `*pytree_args -> *flat_outputs` function. **Benefits:** Pytree support 🌳 **Possible Drawbacks:** **Related GitHub Issues:** [sc-70930] [sc-71862] --------- Co-authored-by: Christina Lee <christina@xanadu.ai> Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
…jacobian` (#6134) **Context:** #6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`. **Description of the Change:** This PR adds support for pytree inputs and outputs of the differentiated functions, similar to #6081. For this, it extends the internal class `FlatFn` by the extra functionality to turn the wrapper into a `*flat_args -> *flat_outputs` function, instead of a `*pytree_args -> *flat_outputs` function. **Benefits:** Pytree support 🌳 **Possible Drawbacks:** **Related GitHub Issues:** [sc-70930] [sc-71862] --------- Co-authored-by: Christina Lee <christina@xanadu.ai> Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Context:
#6120 and #6127 add support to capture
qml.grad
andqml.jacobian
in plxpr. Once captured, they dispatch tojax.grad
andjax.jacobian
.Description of the Change:
This PR adds support for pytree inputs and outputs of the differentiated functions, similar to #6081.
For this, it extends the internal class
FlatFn
by the extra functionality to turn the wrapper into a*flat_args -> *flat_outputs
function, instead of a*pytree_args -> *flat_outputs
function.Benefits:
Pytree support 🌳
Possible Drawbacks:
Related GitHub Issues:
[sc-70930]
[sc-71862]