[BUG] Can't differentiate dynamic_one_shot with autograd and jax #5709

Closed
1 task done
mudit2812 opened this issue May 17, 2024 · 2 comments
Labels
bug 🐛 Something isn't working

Comments

@mudit2812
Contributor

Expected behavior

I expect to compute gradients/Jacobians without errors when differentiating QNodes that use mid-circuit measurements on a device with finite shots, using either the autograd or the JAX interface.

Actual behavior

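Both interfaces raise a TypeError caused by integer-valued results: autograd reports that it can't differentiate w.r.t. type int, and JAX reports that the custom JVP rule produces integer (int64) primal outputs paired with floating-point (float64) tangent outputs.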

Additional information

No response

Source code

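import pennylane as qml
from pennylane import numpy as pnp
import jax
import jax.numpy as jnp

# Finite shots: the mid-circuit measurement is processed by the dynamic_one_shot transform
dev = qml.device("default.qubit", shots=10)

@qml.qnode(dev)
def f(x):
    qml.RX(x, 0)
    qml.measure(0)  # mid-circuit measurement on wire 0
    return qml.expval(qml.PauliX(0))

# Autograd interface: raises TypeError
x = pnp.array(0.4, requires_grad=True)
print(qml.grad(f)(x))

# JAX interface: raises TypeError
x = jnp.array(0.4)
print(jax.grad(f)(x))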

Tracebacks

WITH AUTOGRAD:

KeyError                                  Traceback (most recent call last)
File ~/.pyenv/versions/3.10.12/envs/pennylane/lib/python3.10/site-packages/autograd/tracer.py:118, in new_box(value, trace, node)
    117 try:
--> 118     return box_type_mappings[type(value)](value, trace, node)
    119 except KeyError:

KeyError: <class 'int'>

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[40], line 16
     14 x = np.array(0.4, requires_grad=True) # make input require grad
     15 print(f(x))
---> 16 qml.grad(f)(x) # compute the gradient

File ~/repos/pennylane/pennylane/_grad.py:166, in grad.__call__(self, *args, **kwargs)
    163     self._forward = self._fun(*args, **kwargs)
    164     return ()
--> 166 grad_value, ans = grad_fn(*args, **kwargs)  # pylint: disable=not-callable
    167 self._forward = ans
    169 return grad_value

File ~/.pyenv/versions/3.10.12/envs/pennylane/lib/python3.10/site-packages/autograd/wrap_util.py:20, in unary_to_nary.<locals>.nary_operator.<locals>.nary_f(*args, **kwargs)
     18 else:
     19     x = tuple(args[i] for i in argnum)
---> 20 return unary_operator(unary_f, x, *nary_op_args, **nary_op_kwargs)

File ~/repos/pennylane/pennylane/_grad.py:184, in grad._grad_with_forward(fun, x)
    178 @staticmethod
    179 @unary_to_nary
    180 def _grad_with_forward(fun, x):
    181     """This function is a replica of ``autograd.grad``, with the only
    182     difference being that it returns both the gradient *and* the forward pass
    183     value."""
--> 184     vjp, ans = _make_vjp(fun, x)  # pylint: disable=redefined-outer-name
    186     if vspace(ans).size != 1:
    187         raise TypeError(
    188             "Grad only applies to real scalar-output functions. "
    189             "Try jacobian, elementwise_grad or holomorphic_grad."
    190         )

File ~/.pyenv/versions/3.10.12/envs/pennylane/lib/python3.10/site-packages/autograd/core.py:10, in make_vjp(fun, x)
      8 def make_vjp(fun, x):
      9     start_node = VJPNode.new_root()
---> 10     end_value, end_node =  trace(start_node, fun, x)
     11     if end_node is None:
     12         def vjp(g): return vspace(x).zeros()

File ~/.pyenv/versions/3.10.12/envs/pennylane/lib/python3.10/site-packages/autograd/tracer.py:10, in trace(start_node, fun, x)
      8 with trace_stack.new_trace() as t:
      9     start_box = new_box(x, t, start_node)
---> 10     end_box = fun(start_box)
     11     if isbox(end_box) and end_box._trace == start_box._trace:
     12         return end_box._value, end_box._node

File ~/.pyenv/versions/3.10.12/envs/pennylane/lib/python3.10/site-packages/autograd/wrap_util.py:15, in unary_to_nary.<locals>.nary_operator.<locals>.nary_f.<locals>.unary_f(x)
     13 else:
     14     subargs = subvals(args, zip(argnum, x))
---> 15 return fun(*subargs, **kwargs)

File ~/repos/pennylane/pennylane/workflow/qnode.py:1098, in QNode.__call__(self, *args, **kwargs)
   1095 self._update_gradient_fn(shots=override_shots, tape=self._tape)
   1097 try:
-> 1098     res = self._execution_component(args, kwargs, override_shots=override_shots)
   1099 finally:
   1100     if old_interface == "auto":

File ~/repos/pennylane/pennylane/workflow/qnode.py:1052, in QNode._execution_component(self, args, kwargs, override_shots)
   1049 full_transform_program.prune_dynamic_transform()
   1051 # pylint: disable=unexpected-keyword-arg
-> 1052 res = qml.execute(
   1053     (self._tape,),
   1054     device=self.device,
   1055     gradient_fn=self.gradient_fn,
   1056     interface=self.interface,
   1057     transform_program=full_transform_program,
   1058     config=config,
   1059     gradient_kwargs=self.gradient_kwargs,
   1060     override_shots=override_shots,
   1061     **self.execute_kwargs,
   1062 )
   1063 res = res[0]
   1065 # convert result to the interface in case the qfunc has no parameters

File ~/repos/pennylane/pennylane/workflow/execution.py:790, in execute(tapes, device, gradient_fn, interface, transform_program, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform, device_vjp)
    785 else:
    786     results = ml_boundary_execute(
    787         tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff
    788     )
--> 790 return post_processing(results)

File ~/repos/pennylane/pennylane/transforms/core/transform_program.py:88, in _apply_postprocessing_stack(results, postprocessing_stack)
     65 """Applies the postprocessing and cotransform postprocessing functions in a Last-In-First-Out LIFO manner.
     66 
     67 Args:
   (...)
     85 
     86 """
     87 for postprocessing in reversed(postprocessing_stack):
---> 88     results = postprocessing(results)
     89 return results

File ~/repos/pennylane/pennylane/transforms/core/transform_program.py:58, in _batch_postprocessing(results, individual_fns, slices)
     32 def _batch_postprocessing(
     33     results: ResultBatch, individual_fns: List[PostProcessingFn], slices: List[slice]
     34 ) -> ResultBatch:
     35     """Broadcast individual post processing functions onto their respective tapes.
     36 
     37     Args:
   (...)
     56 
     57     """
---> 58     return tuple(fn(results[sl]) for fn, sl in zip(individual_fns, slices))

File ~/repos/pennylane/pennylane/transforms/core/transform_program.py:58, in <genexpr>(.0)
     32 def _batch_postprocessing(
     33     results: ResultBatch, individual_fns: List[PostProcessingFn], slices: List[slice]
     34 ) -> ResultBatch:
     35     """Broadcast individual post processing functions onto their respective tapes.
     36 
     37     Args:
   (...)
     56 
     57     """
---> 58     return tuple(fn(results[sl]) for fn, sl in zip(individual_fns, slices))

File ~/repos/pennylane/pennylane/transforms/dynamic_one_shot.py:143, in dynamic_one_shot.<locals>.processing_fn(results, has_partitioned_shots, batched_results)
    141 if not tape.shots.has_partitioned_shots:
    142     results = results[0]
--> 143 return parse_native_mid_circuit_measurements(tape, aux_tapes, results)

File ~/repos/pennylane/pennylane/transforms/dynamic_one_shot.py:244, in parse_native_mid_circuit_measurements(circuit, aux_tapes, results)
    235 post_process_tape = qml.tape.QuantumScript(
    236     aux_tapes[0].operations,
    237     aux_tapes[0].measurements[0:-n_mcms],
    238     shots=aux_tapes[0].shots,
    239     trainable_params=aux_tapes[0].trainable_params,
    240 )
    241 single_measurement = (
    242     len(post_process_tape.measurements) == 0 and len(aux_tapes[0].measurements) == 1
    243 )
--> 244 mcm_samples = qml.math.array(
    245     [[res] if single_measurement else res[-n_mcms::] for res in results],
    246     like=interface,
    247     dtype=float if interface == "autograd" else int,
    248 )
    249 # Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1
    250 has_postselect = qml.math.array(
    251     [[int(op.postselect is not None) for op in all_mcms]], like=interface
    252 )

File ~/repos/pennylane/pennylane/math/multi_dispatch.py:39, in array(like, *args, **kwargs)
     30 def array(*args, like=None, **kwargs):
     31     """Creates an array or tensor object of the target framework.
     32 
     33     If the PyTorch interface is specified, this method preserves the Torch device used.
   (...)
     37         tensor_like: the tensor_like object of the framework
     38     """
---> 39     res = np.array(*args, like=like, **kwargs)
     40     if like is not None and get_interface(like) == "torch":
     41         res = res.to(device=like.device)

File ~/.pyenv/versions/3.10.12/envs/pennylane/lib/python3.10/site-packages/autoray/autoray.py:80, in do(fn, like, *args, **kwargs)
     31 """Do function named ``fn`` on ``(*args, **kwargs)``, peforming single
     32 dispatch to retrieve ``fn`` based on whichever library defines the class of
     33 the ``args[0]``, or the ``like`` keyword argument if specified.
   (...)
     77     <tf.Tensor: id=91, shape=(3, 3), dtype=float32>
     78 """
     79 backend = choose_backend(fn, *args, like=like, **kwargs)
---> 80 return get_lib_fn(backend, fn)(*args, **kwargs)

File ~/repos/pennylane/pennylane/numpy/wrapper.py:117, in tensor_wrapper.<locals>._wrapped(*args, **kwargs)
    114         tensor_kwargs["requires_grad"] = _np.any([i.requires_grad for i in tensor_args])
    116 # evaluate the original object
--> 117 res = obj(*args, **kwargs)
    119 if isinstance(res, _np.ndarray):
    120     # only if the output of the object is a ndarray,
    121     # then convert to a PennyLane tensor
    122     res = tensor(res, **tensor_kwargs)

File ~/.pyenv/versions/3.10.12/envs/pennylane/lib/python3.10/site-packages/autograd/numpy/numpy_wrapper.py:58, in array(A, *args, **kwargs)
     56 t = builtins.type(A)
     57 if t in (list, tuple):
---> 58     return array_from_args(args, kwargs, *map(array, A))
     59 else:
     60     return _array_from_scalar_or_array(args, kwargs, A)

File ~/.pyenv/versions/3.10.12/envs/pennylane/lib/python3.10/site-packages/autograd/numpy/numpy_wrapper.py:58, in array(A, *args, **kwargs)
     56 t = builtins.type(A)
     57 if t in (list, tuple):
---> 58     return array_from_args(args, kwargs, *map(array, A))
     59 else:
     60     return _array_from_scalar_or_array(args, kwargs, A)

File ~/.pyenv/versions/3.10.12/envs/pennylane/lib/python3.10/site-packages/autograd/tracer.py:46, in primitive.<locals>.f_wrapped(*args, **kwargs)
     44     ans = f_wrapped(*argvals, **kwargs)
     45     node = node_constructor(ans, f_wrapped, argvals, kwargs, argnums, parents)
---> 46     return new_box(ans, trace, node)
     47 else:
     48     return f_raw(*args, **kwargs)

File ~/.pyenv/versions/3.10.12/envs/pennylane/lib/python3.10/site-packages/autograd/tracer.py:120, in new_box(value, trace, node)
    118     return box_type_mappings[type(value)](value, trace, node)
    119 except KeyError:
--> 120     raise TypeError("Can't differentiate w.r.t. type {}".format(type(value)))

TypeError: Can't differentiate w.r.t. type <class 'int'>


WITH JAX:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[8], line 15
     12     return qml.expval(qml.PauliX(0))
     14 x = jnp.array(0.4)
---> 15 print(jax.grad(f)(x))

    [... skipping hidden 10 frame]

File ~/repos/pennylane/pennylane/workflow/qnode.py:1098, in QNode.__call__(self, *args, **kwargs)
   1095 self._update_gradient_fn(shots=override_shots, tape=self._tape)
   1097 try:
-> 1098     res = self._execution_component(args, kwargs, override_shots=override_shots)
   1099 finally:
   1100     if old_interface == "auto":

File ~/repos/pennylane/pennylane/workflow/qnode.py:1052, in QNode._execution_component(self, args, kwargs, override_shots)
   1049 full_transform_program.prune_dynamic_transform()
   1051 # pylint: disable=unexpected-keyword-arg
-> 1052 res = qml.execute(
   1053     (self._tape,),
   1054     device=self.device,
   1055     gradient_fn=self.gradient_fn,
   1056     interface=self.interface,
   1057     transform_program=full_transform_program,
   1058     config=config,
   1059     gradient_kwargs=self.gradient_kwargs,
   1060     override_shots=override_shots,
   1061     **self.execute_kwargs,
   1062 )
   1063 res = res[0]
   1065 # convert result to the interface in case the qfunc has no parameters

File ~/repos/pennylane/pennylane/workflow/execution.py:784, in execute(tapes, device, gradient_fn, interface, transform_program, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform, device_vjp)
    776 ml_boundary_execute = _get_ml_boundary_execute(
    777     interface,
    778     _grad_on_execution,
    779     config.use_device_jacobian_product,
    780     differentiable=max_diff > 1,
    781 )
    783 if interface in jpc_interfaces:
--> 784     results = ml_boundary_execute(tapes, execute_fn, jpc, device=device)
    785 else:
    786     results = ml_boundary_execute(
    787         tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff
    788     )

File ~/repos/pennylane/pennylane/workflow/interfaces/jax.py:265, in jax_jvp_execute(tapes, execute_fn, jpc, device)
    261     logger.debug("Entry with (tapes=%s, execute_fn=%s, jpc=%s)", tapes, execute_fn, jpc)
    263 parameters = tuple(tuple(t.get_parameters()) for t in tapes)
--> 265 return _execute_jvp(parameters, _NonPytreeWrapper(tuple(tapes)), execute_fn, jpc)

    [... skipping hidden 5 frame]

File ~/.pyenv/versions/3.10.12/envs/pennylane/lib/python3.10/site-packages/jax/_src/custom_derivatives.py:347, in _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args)
    342     msg = ("Custom JVP rule must produce primal and tangent outputs with "
    343            "equal shapes and dtypes, but got:\n{}")
    344     disagreements = (
    345         f"  primal {av1.str_short()} for tangent {av2.str_short()}"
    346         for av1, av2 in zip(primal_avals_out, tangent_avals_out) if av1 != av2)
--> 347     raise TypeError(msg.format('\n'.join(disagreements)))
    348 yield primals_out + tangents_out, (out_tree, primal_avals)

TypeError: Custom JVP rule must produce primal and tangent outputs with equal shapes and dtypes, but got:
  primal int64[] for tangent float64[]
  primal int64[] for tangent float64[]
  primal int64[] for tangent float64[]
  primal int64[] for tangent float64[]
  primal int64[] for tangent float64[]
  primal int64[] for tangent float64[]
  primal int64[] for tangent float64[]
  primal int64[] for tangent float64[]
  primal int64[] for tangent float64[]
  primal int64[] for tangent float64[]
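JAX requires a custom JVP rule to return tangent outputs whose shapes and dtypes correspond to the primal outputs; integer primals have float0 tangents, so pairing an integer primal with a float tangent is rejected. A minimal sketch outside PennyLane, assuming the hypothetical helper _sample stands in for integer-valued mid-circuit measurement samples: bad returns an integer primal with a float tangent and fails with the same kind of TypeError (int32/float32 here, since 64-bit mode is not enabled), while good casts the sample to the input's float dtype first.

```python
import jax
import jax.numpy as jnp


def _sample(x):
    # Hypothetical stand-in for a mid-circuit measurement sample:
    # threshold the input and return an integer outcome (0 or 1).
    return (x > 0.5).astype(jnp.int32)


@jax.custom_jvp
def bad(x):
    return _sample(x)  # integer primal output


@bad.defjvp
def bad_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Float tangent paired with an integer primal: JAX rejects this.
    return bad(x), jnp.zeros_like(t)


@jax.custom_jvp
def good(x):
    # Cast the sample to the input's floating-point dtype so the primal
    # and tangent outputs agree (the sample is piecewise constant, so a
    # zero tangent is used).
    return _sample(x).astype(x.dtype)


@good.defjvp
def good_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    return good(x), jnp.zeros_like(t)


x = jnp.asarray(0.7, dtype=jnp.float32)
t = jnp.asarray(1.0, dtype=jnp.float32)

primal, tangent = jax.jvp(good, (x,), (t,))
print(primal, tangent)

try:
    jax.jvp(bad, (x,), (t,))
except TypeError as err:
    print("bad:", err)
```

Casting to float mirrors what the traceback shows dynamic_one_shot already doing for autograd (dtype=float if interface == "autograd" else int); whether the same cast is the right fix for the JAX path is an assumption of this sketch.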

System information

Dev

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
@mudit2812 added the bug 🐛 label on May 17, 2024
@albi3ro
Contributor

albi3ro commented May 17, 2024

Note that moving the device preprocessing inside the ML interface boundary does seem to solve this problem; see:

master...inner-transform-program

@mudit2812
Contributor Author

Duplicate of #5736. Closing.
