From 38ee38e6d072fd786bf04988aae65f375a4dbbe2 Mon Sep 17 00:00:00 2001 From: David Wierichs Date: Tue, 10 Sep 2024 00:31:07 +0200 Subject: [PATCH] [Program Capture] Capture & execute `qml.jacobian` in plxpr (#6127) **Context:** We're adding support for differentiation in plxpr, also see #6120. **Description of the Change:** This PR adds support for `qml.jacobian`, similar to the support for `qml.grad`. Note that Pytree support will be needed to allow for multi-argument derivatives. **Benefits:** Capture derivatives of non-scalar functions. **Possible Drawbacks:** See discussion around `qml.grad` in #6120. **Related GitHub Issues:** [sc-71860] --------- Co-authored-by: Christina Lee --- doc/releases/changelog-dev.md | 9 + pennylane/_grad.py | 7 +- pennylane/capture/capture_diff.py | 35 ++ pennylane/capture/primitives.py | 4 +- tests/capture/test_capture_diff.py | 619 +++++++++++++++++++---------- tests/test_compiler.py | 2 +- 6 files changed, 473 insertions(+), 203 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f4a07167948..a2809a619bf 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -20,6 +20,15 @@ which differs from the Autograd implementation of `qml.grad` itself. [(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120) +

Capturing and representing hybrid programs

+ +* Differentiation of hybrid programs via `qml.grad` and `qml.jacobian` can now be captured + into plxpr. When evaluating a captured `qml.grad` (`qml.jacobian`) instruction, it will + dispatch to `jax.grad` (`jax.jacobian`), which differs from the Autograd implementation + without capture. + [(#6120)](https://github.com/PennyLaneAI/pennylane/pull/6120) + [(#6127)](https://github.com/PennyLaneAI/pennylane/pull/6127) + * Improve unit testing for capturing of nested control flows. [(#6111)](https://github.com/PennyLaneAI/pennylane/pull/6111) diff --git a/pennylane/_grad.py b/pennylane/_grad.py index 8aeb11b02f5..859ae5d9fbb 100644 --- a/pennylane/_grad.py +++ b/pennylane/_grad.py @@ -24,7 +24,7 @@ from autograd.wrap_util import unary_to_nary from pennylane.capture import enabled -from pennylane.capture.capture_diff import _get_grad_prim +from pennylane.capture.capture_diff import _get_grad_prim, _get_jacobian_prim from pennylane.compiler import compiler from pennylane.compiler.compiler import CompileError @@ -434,8 +434,11 @@ def circuit(x): ops_loader = available_eps[active_jit]["ops"].load() return ops_loader.jacobian(func, method=method, h=h, argnums=argnum) + if enabled(): + return _capture_diff(func, argnum, _get_jacobian_prim(), method=method, h=h) + if method or h: - raise ValueError(f"Invalid values for 'method={method}' and 'h={h}' in interpreted mode") + raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.") def _get_argnum(args): """Inspect the arguments for differentiability and return the diff --git a/pennylane/capture/capture_diff.py b/pennylane/capture/capture_diff.py index cea9307d4da..92dde5a2956 100644 --- a/pennylane/capture/capture_diff.py +++ b/pennylane/capture/capture_diff.py @@ -82,3 +82,38 @@ def _(*args, argnum, jaxpr, n_consts, method, h): return tuple(jaxpr.invars[i].aval for i in argnum) return grad_prim + + +@lru_cache +def _get_jacobian_prim(): + """Create a primitive for Jacobian computations. + This primitive is used when capturing ``qml.jacobian``. + """ + jacobian_prim = create_non_jvp_primitive()("jacobian") + jacobian_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init + + # pylint: disable=too-many-arguments + @jacobian_prim.def_impl + def _(*args, argnum, jaxpr, n_consts, method, h): + if method or h: # pragma: no cover + raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.") + consts = args[:n_consts] + args = args[n_consts:] + + def func(*inner_args): + return jax.core.eval_jaxpr(jaxpr, consts, *inner_args) + + return jax.jacobian(func, argnums=argnum)(*args) + + # pylint: disable=unused-argument + @jacobian_prim.def_abstract_eval + def _(*args, argnum, jaxpr, n_consts, method, h): + in_avals = [jaxpr.invars[i].aval for i in argnum] + out_shapes = (outvar.aval.shape for outvar in jaxpr.outvars) + return [ + jax.core.ShapedArray(out_shape + in_aval.shape, in_aval.dtype) + for out_shape in out_shapes + for in_aval in in_avals + ] + + return jacobian_prim diff --git a/pennylane/capture/primitives.py b/pennylane/capture/primitives.py index 3ccff96d5af..3d578b82f7f 100644 --- a/pennylane/capture/primitives.py +++ b/pennylane/capture/primitives.py @@ -22,7 +22,7 @@ from pennylane.ops.op_math.condition import _get_cond_qfunc_prim from pennylane.ops.op_math.controlled import _get_ctrl_qfunc_prim -from .capture_diff import _get_grad_prim +from .capture_diff import _get_grad_prim, _get_jacobian_prim from .capture_measurements import _get_abstract_measurement from .capture_operators import _get_abstract_operator from .capture_qnode import _get_qnode_prim @@ -32,6 +32,7 @@ adjoint_transform_prim = _get_adjoint_qfunc_prim() ctrl_transform_prim = _get_ctrl_qfunc_prim() grad_prim = _get_grad_prim() +jacobian_prim = _get_jacobian_prim() qnode_prim = _get_qnode_prim() cond_prim = _get_cond_qfunc_prim() for_loop_prim = _get_for_loop_qfunc_prim() @@ -44,6 +45,7 @@ "adjoint_transform_prim", "ctrl_transform_prim", "grad_prim", + "jacobian_prim", "qnode_prim", "cond_prim", "for_loop_prim", diff --git a/tests/capture/test_capture_diff.py b/tests/capture/test_capture_diff.py index edca307932d..cf7834aafb8 100644 --- a/tests/capture/test_capture_diff.py +++ b/tests/capture/test_capture_diff.py @@ -23,7 +23,10 @@ jax = pytest.importorskip("jax") -from pennylane.capture.primitives import grad_prim # pylint: disable=wrong-import-position +from pennylane.capture.primitives import ( # pylint: disable=wrong-import-position + grad_prim, + jacobian_prim, +) jnp = jax.numpy @@ -35,31 +38,34 @@ def enable_disable_plxpr(): qml.capture.disable() -@pytest.mark.parametrize("kwargs", [{"method": "fd"}, {"h": 0.3}, {"h": 0.2, "method": "fd"}]) -def test_error_with_method_or_h(kwargs): - """Test that an error is raised if kwargs for QJIT's grad are passed to PLxPRs grad.""" +class TestExceptions: + """Test that expected exceptions are correctly raised.""" - def func(x): - return qml.grad(jnp.sin, **kwargs)(x) + @pytest.mark.parametrize("kwargs", [{"method": "fd"}, {"h": 0.3}, {"h": 0.2, "method": "fd"}]) + @pytest.mark.parametrize("diff", [qml.grad, qml.jacobian]) + def test_error_with_method_or_h(self, kwargs, diff): + """Test that an error is raised if kwargs for QJIT's grad are passed to PLxPRs grad.""" - method = kwargs.get("method", None) - h = kwargs.get("h", None) - jaxpr = jax.make_jaxpr(func)(0.6) - with pytest.raises(ValueError, match=f"'{method=}' and '{h=}' without QJIT"): - func(0.6) - with pytest.raises(ValueError, match=f"'{method=}' and '{h=}' without QJIT"): - jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.6) + def func(x): + return diff(jnp.sin, **kwargs)(x) + method = kwargs.get("method", None) + h = kwargs.get("h", None) + jaxpr = jax.make_jaxpr(func)(0.6) + with pytest.raises(ValueError, match=f"'{method=}' and '{h=}' without QJIT"): + func(0.6) + with pytest.raises(ValueError, match=f"'{method=}' and '{h=}' without QJIT"): + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.6) -def test_error_with_non_scalar_function(): - """Test that an error is raised if the differentiated function has non-scalar outputs.""" - with pytest.raises(TypeError, match="Grad only applies to scalar-output functions."): - jax.make_jaxpr(qml.grad(jnp.sin))(jnp.array([0.5, 0.2])) + def test_error_with_non_scalar_function(self): + """Test that an error is raised if the differentiated function has non-scalar outputs.""" + with pytest.raises(TypeError, match="Grad only applies to scalar-output functions."): + jax.make_jaxpr(qml.grad(jnp.sin))(jnp.array([0.5, 0.2])) -def grad_eqn_assertions(eqn, argnum=None, n_consts=0): +def diff_eqn_assertions(eqn, primitive, argnum=None, n_consts=0): argnum = [0] if argnum is None else argnum - assert eqn.primitive == grad_prim + assert eqn.primitive == primitive assert set(eqn.params.keys()) == {"argnum", "n_consts", "jaxpr", "method", "h"} assert eqn.params["argnum"] == argnum assert eqn.params["n_consts"] == n_consts @@ -68,187 +74,402 @@ def grad_eqn_assertions(eqn, argnum=None, n_consts=0): @pytest.mark.parametrize("x64_mode", (True, False)) -@pytest.mark.parametrize("argnum", ([0, 1], [0], [1], 0, 1)) -def test_classical_grad(x64_mode, argnum): - """Test that the qml.grad primitive can be captured with classical nodes.""" - - initial_mode = jax.config.jax_enable_x64 - jax.config.update("jax_enable_x64", x64_mode) - fdtype = jnp.float64 if x64_mode else jnp.float32 - - def inner_func(x, y): - return jnp.prod(jnp.sin(x) * jnp.cos(y) ** 2) - - def func_qml(x): - return qml.grad(inner_func, argnum=argnum)(x, 0.4 * jnp.sqrt(x)) - - def func_jax(x): - return jax.grad(inner_func, argnums=argnum)(x, 0.4 * jnp.sqrt(x)) - - x = 0.7 - jax_out = func_jax(x) - assert qml.math.allclose(func_qml(x), jax_out) - - # Check overall jaxpr properties - if isinstance(argnum, int): - argnum = [argnum] - jaxpr = jax.make_jaxpr(func_qml)(x) - assert jaxpr.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] - assert len(jaxpr.eqns) == 3 - assert jaxpr.out_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] * len(argnum) - - grad_eqn = jaxpr.eqns[2] - grad_eqn_assertions(grad_eqn, argnum=argnum) - assert [var.aval for var in grad_eqn.outvars] == jaxpr.out_avals - assert len(grad_eqn.params["jaxpr"].eqns) == 6 # 5 numeric eqns, 1 conversion eqn - - manual_eval = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) - assert qml.math.allclose(manual_eval, jax_out) - - jax.config.update("jax_enable_x64", initial_mode) - - -@pytest.mark.parametrize("x64_mode", (True, False)) -def test_nested_grad(x64_mode): - """Test that nested qml.grad primitives can be captured. - We use the function - f(x) = sin(x)^3 - f'(x) = 3 sin(x)^2 cos(x) - f''(x) = 6 sin(x) cos(x)^2 - 3 sin(x)^3 - f'''(x) = 6 cos(x)^3 - 12 sin(x)^2 cos(x) - 9 sin(x)^2 cos(x) - """ - initial_mode = jax.config.jax_enable_x64 - jax.config.update("jax_enable_x64", x64_mode) - fdtype = jnp.float64 if x64_mode else jnp.float32 - - def func(x): - return jnp.sin(x) ** 3 - - x = 0.7 - - # 1st order - qml_func_1 = qml.grad(func) - expected_1 = 3 * jnp.sin(x) ** 2 * jnp.cos(x) - assert qml.math.allclose(qml_func_1(x), expected_1) - - jaxpr_1 = jax.make_jaxpr(qml_func_1)(x) - assert jaxpr_1.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] - assert len(jaxpr_1.eqns) == 1 - assert jaxpr_1.out_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] - - grad_eqn = jaxpr_1.eqns[0] - assert [var.aval for var in grad_eqn.outvars] == jaxpr_1.out_avals - grad_eqn_assertions(grad_eqn) - assert len(grad_eqn.params["jaxpr"].eqns) == 2 - - manual_eval_1 = jax.core.eval_jaxpr(jaxpr_1.jaxpr, jaxpr_1.consts, x) - assert qml.math.allclose(manual_eval_1, expected_1) - - # 2nd order - qml_func_2 = qml.grad(qml_func_1) - expected_2 = 6 * jnp.sin(x) * jnp.cos(x) ** 2 - 3 * jnp.sin(x) ** 3 - assert qml.math.allclose(qml_func_2(x), expected_2) - - jaxpr_2 = jax.make_jaxpr(qml_func_2)(x) - assert jaxpr_2.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] - assert len(jaxpr_2.eqns) == 1 - assert jaxpr_2.out_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] - - grad_eqn = jaxpr_2.eqns[0] - assert [var.aval for var in grad_eqn.outvars] == jaxpr_2.out_avals - grad_eqn_assertions(grad_eqn) - assert len(grad_eqn.params["jaxpr"].eqns) == 1 # inner grad equation - assert grad_eqn.params["jaxpr"].eqns[0].primitive == grad_prim - - manual_eval_2 = jax.core.eval_jaxpr(jaxpr_2.jaxpr, jaxpr_2.consts, x) - assert qml.math.allclose(manual_eval_2, expected_2) - - # 3rd order - qml_func_3 = qml.grad(qml_func_2) - expected_3 = ( - 6 * jnp.cos(x) ** 3 - 12 * jnp.sin(x) ** 2 * jnp.cos(x) - 9 * jnp.sin(x) ** 2 * jnp.cos(x) +class TestGrad: + """Tests for capturing `qml.grad`.""" + + @pytest.mark.parametrize("argnum", ([0, 1], [0], [1], 0, 1)) + def test_classical_grad(self, x64_mode, argnum): + """Test that the qml.grad primitive can be captured with classical nodes.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + fdtype = jnp.float64 if x64_mode else jnp.float32 + + def inner_func(x, y): + return jnp.prod(jnp.sin(x) * jnp.cos(y) ** 2) + + def func_qml(x): + return qml.grad(inner_func, argnum=argnum)(x, 0.4 * jnp.sqrt(x)) + + def func_jax(x): + return jax.grad(inner_func, argnums=argnum)(x, 0.4 * jnp.sqrt(x)) + + x = 0.7 + jax_out = func_jax(x) + assert qml.math.allclose(func_qml(x), jax_out) + + # Check overall jaxpr properties + if isinstance(argnum, int): + argnum = [argnum] + jaxpr = jax.make_jaxpr(func_qml)(x) + assert jaxpr.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] + assert len(jaxpr.eqns) == 3 + assert jaxpr.out_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] * len(argnum) + + grad_eqn = jaxpr.eqns[2] + diff_eqn_assertions(grad_eqn, grad_prim, argnum=argnum) + assert [var.aval for var in grad_eqn.outvars] == jaxpr.out_avals + assert len(grad_eqn.params["jaxpr"].eqns) == 6 # 5 numeric eqns, 1 conversion eqn + + manual_eval = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + assert qml.math.allclose(manual_eval, jax_out) + + jax.config.update("jax_enable_x64", initial_mode) + + def test_nested_grad(self, x64_mode): + """Test that nested qml.grad primitives can be captured. + We use the function + f(x) = sin(x)^3 + f'(x) = 3 sin(x)^2 cos(x) + f''(x) = 6 sin(x) cos(x)^2 - 3 sin(x)^3 + f'''(x) = 6 cos(x)^3 - 12 sin(x)^2 cos(x) - 9 sin(x)^2 cos(x) + """ + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + fdtype = jnp.float64 if x64_mode else jnp.float32 + + def func(x): + return jnp.sin(x) ** 3 + + x = 0.7 + + # 1st order + qml_func_1 = qml.grad(func) + expected_1 = 3 * jnp.sin(x) ** 2 * jnp.cos(x) + assert qml.math.allclose(qml_func_1(x), expected_1) + + jaxpr_1 = jax.make_jaxpr(qml_func_1)(x) + assert jaxpr_1.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] + assert len(jaxpr_1.eqns) == 1 + assert jaxpr_1.out_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] + + grad_eqn = jaxpr_1.eqns[0] + assert [var.aval for var in grad_eqn.outvars] == jaxpr_1.out_avals + diff_eqn_assertions(grad_eqn, grad_prim) + assert len(grad_eqn.params["jaxpr"].eqns) == 2 + + manual_eval_1 = jax.core.eval_jaxpr(jaxpr_1.jaxpr, jaxpr_1.consts, x) + assert qml.math.allclose(manual_eval_1, expected_1) + + # 2nd order + qml_func_2 = qml.grad(qml_func_1) + expected_2 = 6 * jnp.sin(x) * jnp.cos(x) ** 2 - 3 * jnp.sin(x) ** 3 + assert qml.math.allclose(qml_func_2(x), expected_2) + + jaxpr_2 = jax.make_jaxpr(qml_func_2)(x) + assert jaxpr_2.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] + assert len(jaxpr_2.eqns) == 1 + assert jaxpr_2.out_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] + + grad_eqn = jaxpr_2.eqns[0] + assert [var.aval for var in grad_eqn.outvars] == jaxpr_2.out_avals + diff_eqn_assertions(grad_eqn, grad_prim) + assert len(grad_eqn.params["jaxpr"].eqns) == 1 # inner grad equation + assert grad_eqn.params["jaxpr"].eqns[0].primitive == grad_prim + + manual_eval_2 = jax.core.eval_jaxpr(jaxpr_2.jaxpr, jaxpr_2.consts, x) + assert qml.math.allclose(manual_eval_2, expected_2) + + # 3rd order + qml_func_3 = qml.grad(qml_func_2) + expected_3 = ( + 6 * jnp.cos(x) ** 3 + - 12 * jnp.sin(x) ** 2 * jnp.cos(x) + - 9 * jnp.sin(x) ** 2 * jnp.cos(x) + ) + + assert qml.math.allclose(qml_func_3(x), expected_3) + + jaxpr_3 = jax.make_jaxpr(qml_func_3)(x) + assert jaxpr_3.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] + assert len(jaxpr_3.eqns) == 1 + assert jaxpr_3.out_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] + + grad_eqn = jaxpr_3.eqns[0] + assert [var.aval for var in grad_eqn.outvars] == jaxpr_3.out_avals + diff_eqn_assertions(grad_eqn, grad_prim) + assert len(grad_eqn.params["jaxpr"].eqns) == 1 # inner grad equation + assert grad_eqn.params["jaxpr"].eqns[0].primitive == grad_prim + + manual_eval_3 = jax.core.eval_jaxpr(jaxpr_3.jaxpr, jaxpr_3.consts, x) + assert qml.math.allclose(manual_eval_3, expected_3) + + jax.config.update("jax_enable_x64", initial_mode) + + @pytest.mark.parametrize("diff_method", ("backprop", "parameter-shift")) + def test_grad_of_simple_qnode(self, x64_mode, diff_method, mocker): + """Test capturing the gradient of a simple qnode.""" + # pylint: disable=protected-access + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + fdtype = jax.numpy.float64 if x64_mode else jax.numpy.float32 + + dev = qml.device("default.qubit", wires=2) + + @qml.grad + @qml.qnode(dev, diff_method=diff_method) + def circuit(x): + qml.RX(x[0], wires=0) + qml.RY(x[1] ** 2, wires=0) + return qml.expval(qml.Z(0)) + + x = jnp.array([0.5, 0.9]) + res = circuit(x) + expected_res = ( + -jnp.sin(x[0]) * jnp.cos(x[1] ** 2), + -2 * x[1] * jnp.sin(x[1] ** 2) * jnp.cos(x[0]), + ) + assert qml.math.allclose(res, expected_res) + + jaxpr = jax.make_jaxpr(circuit)(x) + + assert len(jaxpr.eqns) == 1 # grad equation + assert jaxpr.in_avals == [jax.core.ShapedArray((2,), fdtype)] + assert jaxpr.out_avals == [jax.core.ShapedArray((2,), fdtype)] + + grad_eqn = jaxpr.eqns[0] + assert grad_eqn.invars[0].aval == jaxpr.in_avals[0] + diff_eqn_assertions(grad_eqn, grad_prim) + grad_jaxpr = grad_eqn.params["jaxpr"] + assert len(grad_jaxpr.eqns) == 1 # qnode equation + + qnode_eqn = grad_jaxpr.eqns[0] + assert qnode_eqn.primitive == qnode_prim + assert qnode_eqn.invars[0].aval == jaxpr.in_avals[0] + + qfunc_jaxpr = qnode_eqn.params["qfunc_jaxpr"] + # Skipping a few equations related to indexing and preprocessing + assert qfunc_jaxpr.eqns[2].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[6].primitive == qml.RY._primitive + assert qfunc_jaxpr.eqns[7].primitive == qml.Z._primitive + assert qfunc_jaxpr.eqns[8].primitive == qml.measurements.ExpectationMP._obs_primitive + + assert len(qnode_eqn.outvars) == 1 + assert qnode_eqn.outvars[0].aval == jax.core.ShapedArray((), fdtype) + + assert len(grad_eqn.outvars) == 1 + assert grad_eqn.outvars[0].aval == jax.core.ShapedArray((2,), fdtype) + + spy = mocker.spy(qml.gradients.parameter_shift, "expval_param_shift") + manual_res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + if diff_method == "parameter-shift": + spy.assert_called_once() + else: + spy.assert_not_called() + assert qml.math.allclose(manual_res, expected_res) + + jax.config.update("jax_enable_x64", initial_mode) + + +def _jac_allclose(jac1, jac2, num_axes, atol=1e-8): + """Test that two Jacobians, given as nested sequences of arrays, are equal.""" + if num_axes == 0: + return qml.math.allclose(jac1, jac2, atol=atol) + if len(jac1) != len(jac2): + return False + return all( + _jac_allclose(_jac1, _jac2, num_axes - 1, atol=atol) for _jac1, _jac2 in zip(jac1, jac2) ) - assert qml.math.allclose(qml_func_3(x), expected_3) - - jaxpr_3 = jax.make_jaxpr(qml_func_3)(x) - assert jaxpr_3.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] - assert len(jaxpr_3.eqns) == 1 - assert jaxpr_3.out_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] - - grad_eqn = jaxpr_3.eqns[0] - assert [var.aval for var in grad_eqn.outvars] == jaxpr_3.out_avals - grad_eqn_assertions(grad_eqn) - assert len(grad_eqn.params["jaxpr"].eqns) == 1 # inner grad equation - assert grad_eqn.params["jaxpr"].eqns[0].primitive == grad_prim - - manual_eval_3 = jax.core.eval_jaxpr(jaxpr_3.jaxpr, jaxpr_3.consts, x) - assert qml.math.allclose(manual_eval_3, expected_3) - - jax.config.update("jax_enable_x64", initial_mode) - @pytest.mark.parametrize("x64_mode", (True, False)) -@pytest.mark.parametrize("diff_method", ("backprop", "parameter-shift")) -def test_grad_of_simple_qnode(x64_mode, diff_method, mocker): - """Test capturing the gradient of a simple qnode.""" - # pylint: disable=protected-access - initial_mode = jax.config.jax_enable_x64 - jax.config.update("jax_enable_x64", x64_mode) - fdtype = jax.numpy.float64 if x64_mode else jax.numpy.float32 - - dev = qml.device("default.qubit", wires=4) - - @qml.grad - @qml.qnode(dev, diff_method=diff_method) - def circuit(x): - qml.RX(x[0], wires=0) - qml.RY(x[1] ** 2, wires=0) - return qml.expval(qml.Z(0)) - - x = jnp.array([0.5, 0.9]) - res = circuit(x) - expected_res = ( - -jnp.sin(x[0]) * jnp.cos(x[1] ** 2), - -2 * x[1] * jnp.sin(x[1] ** 2) * jnp.cos(x[0]), - ) - assert qml.math.allclose(res, expected_res) - - jaxpr = jax.make_jaxpr(circuit)(x) - - assert len(jaxpr.eqns) == 1 # grad equation - assert jaxpr.in_avals == [jax.core.ShapedArray((2,), fdtype)] - assert jaxpr.out_avals == [jax.core.ShapedArray((2,), fdtype)] - - grad_eqn = jaxpr.eqns[0] - assert grad_eqn.invars[0].aval == jaxpr.in_avals[0] - grad_eqn_assertions(grad_eqn) - grad_jaxpr = grad_eqn.params["jaxpr"] - assert len(grad_jaxpr.eqns) == 1 # qnode equation - - qnode_eqn = grad_jaxpr.eqns[0] - assert qnode_eqn.primitive == qnode_prim - assert qnode_eqn.invars[0].aval == jaxpr.in_avals[0] - - qfunc_jaxpr = qnode_eqn.params["qfunc_jaxpr"] - # Skipping a few equations related to indexing and preprocessing - assert qfunc_jaxpr.eqns[2].primitive == qml.RX._primitive - assert qfunc_jaxpr.eqns[6].primitive == qml.RY._primitive - assert qfunc_jaxpr.eqns[7].primitive == qml.Z._primitive - assert qfunc_jaxpr.eqns[8].primitive == qml.measurements.ExpectationMP._obs_primitive - - assert len(qnode_eqn.outvars) == 1 - assert qnode_eqn.outvars[0].aval == jax.core.ShapedArray((), fdtype) - - assert len(grad_eqn.outvars) == 1 - assert grad_eqn.outvars[0].aval == jax.core.ShapedArray((2,), fdtype) - - spy = mocker.spy(qml.gradients.parameter_shift, "expval_param_shift") - manual_res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) - if diff_method == "parameter-shift": - spy.assert_called_once() - else: - spy.assert_not_called() - assert qml.math.allclose(manual_res, expected_res) - - jax.config.update("jax_enable_x64", initial_mode) +class TestJacobian: + """Tests for capturing `qml.jacobian`.""" + + @pytest.mark.parametrize("argnum", ([0, 1], [0], [1], 0, 1)) + def test_classical_jacobian(self, x64_mode, argnum): + """Test that the qml.jacobian primitive can be captured with classical nodes.""" + if isinstance(argnum, list) and len(argnum) > 1: + # These cases will only be unlocked with Pytree support + pytest.xfail() + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + fdtype = jnp.float64 if x64_mode else jnp.float32 + + def shaped_array(shape): + """Make a ShapedArray with a given shape.""" + return jax.core.ShapedArray(shape, fdtype) + + def inner_func(x, y): + """A function with output signature + (4,), (2, 3) -> (2,), (4, 3), () + """ + return ( + x[0:2] * y[:, 1], + jnp.outer(x, y[0]).astype(jnp.float32), + jnp.prod(y) - jnp.sum(x), + ) + + x = jnp.array([0.3, 0.2, 0.1, 0.6]) + y = jnp.array([[0.4, -0.7, 0.2], [1.2, -7.2, 0.2]]) + func_qml = qml.jacobian(inner_func, argnum=argnum) + func_jax = jax.jacobian(inner_func, argnums=argnum) + + jax_out = func_jax(x, y) + num_axes = 1 if isinstance(argnum, int) else 2 + assert _jac_allclose(func_qml(x, y), jax_out, num_axes) + + # Check overall jaxpr properties + jaxpr = jax.make_jaxpr(func_jax)(x, y) + jaxpr = jax.make_jaxpr(func_qml)(x, y) + + if isinstance(argnum, int): + argnum = [argnum] + + exp_in_avals = [shaped_array(shape) for shape in [(4,), (2, 3)]] + # Expected Jacobian shapes for argnum=[0, 1] + exp_out_shapes = [[(2, 4), (2, 2, 3)], [(4, 3, 4), (4, 3, 2, 3)], [(4,), (2, 3)]] + # Slice out shapes corresponding to the actual argnum + exp_out_avals = [shaped_array(shapes[i]) for shapes in exp_out_shapes for i in argnum] + + assert jaxpr.in_avals == exp_in_avals + assert len(jaxpr.eqns) == 1 + assert jaxpr.out_avals == exp_out_avals + + jac_eqn = jaxpr.eqns[0] + diff_eqn_assertions(jac_eqn, jacobian_prim, argnum=argnum) + + manual_eval = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, y) + assert _jac_allclose(manual_eval, jax_out, num_axes) + + jax.config.update("jax_enable_x64", initial_mode) + + def test_nested_jacobian(self, x64_mode): + r"""Test that nested qml.jacobian primitives can be captured. + We use the function + f(x) = (prod(x) * sin(x), sum(x**2)) + f'(x) = (prod(x)/x_i * sin(x) + prod(x) cos(x) e_i, 2 x_i) + f''(x) = | (prod(x)/x_i x_j * sin(x) + prod(x)cos(x) (e_j/x_i + e_i/x_j) + | - prod(x) sin(x) e_i e_j, 0) for i != j + | + | (2 prod(x)/x_i * cos(x) e_i - prod(x) sin(x) e_i e_i, 2) for i = j + """ + # pylint: disable=too-many-statements + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + fdtype = jnp.float64 if x64_mode else jnp.float32 + + def func(x): + return jnp.prod(x) * jnp.sin(x), jnp.sum(x**2) + + x = jnp.array([0.7, -0.9, 0.6, 0.3]) + x = x[:1] + dim = len(x) + eye = jnp.eye(dim) + + # 1st order + qml_func_1 = qml.jacobian(func) + prod_sin = jnp.prod(x) * jnp.sin(x) + prod_cos_e_i = jnp.prod(x) * jnp.cos(x) * eye + expected_1 = (prod_sin[:, None] / x[None, :] + prod_cos_e_i, 2 * x) + assert _jac_allclose(qml_func_1(x), expected_1, 1) + + jaxpr_1 = jax.make_jaxpr(qml_func_1)(x) + assert jaxpr_1.in_avals == [jax.core.ShapedArray((dim,), fdtype)] + assert len(jaxpr_1.eqns) == 1 + assert jaxpr_1.out_avals == [ + jax.core.ShapedArray(sh, fdtype) for sh in [(dim, dim), (dim,)] + ] + + jac_eqn = jaxpr_1.eqns[0] + assert [var.aval for var in jac_eqn.outvars] == jaxpr_1.out_avals + diff_eqn_assertions(jac_eqn, jacobian_prim) + assert len(jac_eqn.params["jaxpr"].eqns) == 5 + + manual_eval_1 = jax.core.eval_jaxpr(jaxpr_1.jaxpr, jaxpr_1.consts, x) + assert _jac_allclose(manual_eval_1, expected_1, 1) + + # 2nd order + qml_func_2 = qml.jacobian(qml_func_1) + expected_2 = ( + prod_sin[:, None, None] / x[None, :, None] / x[None, None, :] + + prod_cos_e_i[:, :, None] / x[None, None, :] + + prod_cos_e_i[:, None, :] / x[None, :, None] + - jnp.tensordot(prod_sin, eye + eye / x**2, axes=0), + jnp.tensordot(jnp.ones(dim), eye * 2, axes=0), + ) + # Output only has one tuple axis + assert _jac_allclose(qml_func_2(x), expected_2, 1) + + jaxpr_2 = jax.make_jaxpr(qml_func_2)(x) + assert jaxpr_2.in_avals == [jax.core.ShapedArray((dim,), fdtype)] + assert len(jaxpr_2.eqns) == 1 + assert jaxpr_2.out_avals == [ + jax.core.ShapedArray(sh, fdtype) for sh in [(dim, dim, dim), (dim, dim)] + ] + + jac_eqn = jaxpr_2.eqns[0] + assert [var.aval for var in jac_eqn.outvars] == jaxpr_2.out_avals + diff_eqn_assertions(jac_eqn, jacobian_prim) + assert len(jac_eqn.params["jaxpr"].eqns) == 1 # inner jacobian equation + assert jac_eqn.params["jaxpr"].eqns[0].primitive == jacobian_prim + + manual_eval_2 = jax.core.eval_jaxpr(jaxpr_2.jaxpr, jaxpr_2.consts, x) + assert _jac_allclose(manual_eval_2, expected_2, 1) + + jax.config.update("jax_enable_x64", initial_mode) + + @pytest.mark.parametrize("diff_method", ("backprop", "parameter-shift")) + def test_jacobian_of_simple_qnode(self, x64_mode, diff_method, mocker): + """Test capturing the gradient of a simple qnode.""" + # pylint: disable=protected-access + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + fdtype = jax.numpy.float64 if x64_mode else jax.numpy.float32 + + dev = qml.device("default.qubit", wires=2) + + # Note the decorator + @qml.jacobian + @qml.qnode(dev, diff_method=diff_method) + def circuit(x): + qml.RX(x[0], wires=0) + qml.RY(x[1], wires=0) + return qml.expval(qml.Z(0)), qml.probs(0) + + x = jnp.array([0.5, 0.9]) + res = circuit(x) + expval_diff = -jnp.sin(x) * jnp.cos(x[::-1]) + expected_res = (expval_diff, jnp.stack([expval_diff / 2, -expval_diff / 2])) + + assert _jac_allclose(res, expected_res, 1) + + jaxpr = jax.make_jaxpr(circuit)(x) + + assert len(jaxpr.eqns) == 1 # Jacobian equation + assert jaxpr.in_avals == [jax.core.ShapedArray((2,), fdtype)] + assert jaxpr.out_avals == [jax.core.ShapedArray(sh, fdtype) for sh in [(2,), (2, 2)]] + + jac_eqn = jaxpr.eqns[0] + assert jac_eqn.invars[0].aval == jaxpr.in_avals[0] + diff_eqn_assertions(jac_eqn, jacobian_prim) + jac_jaxpr = jac_eqn.params["jaxpr"] + assert len(jac_jaxpr.eqns) == 1 # qnode equation + + qnode_eqn = jac_jaxpr.eqns[0] + assert qnode_eqn.primitive == qnode_prim + assert qnode_eqn.invars[0].aval == jaxpr.in_avals[0] + + qfunc_jaxpr = qnode_eqn.params["qfunc_jaxpr"] + # Skipping a few equations related to indexing + assert qfunc_jaxpr.eqns[2].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[5].primitive == qml.RY._primitive + assert qfunc_jaxpr.eqns[6].primitive == qml.Z._primitive + assert qfunc_jaxpr.eqns[7].primitive == qml.measurements.ExpectationMP._obs_primitive + + assert len(qnode_eqn.outvars) == 2 + assert qnode_eqn.outvars[0].aval == jax.core.ShapedArray((), fdtype) + assert qnode_eqn.outvars[1].aval == jax.core.ShapedArray((2,), fdtype) + + assert [outvar.aval for outvar in jac_eqn.outvars] == jaxpr.out_avals + + spy = mocker.spy(qml.gradients.parameter_shift, "expval_param_shift") + manual_res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + if diff_method == "parameter-shift": + spy.assert_called_once() + else: + spy.assert_not_called() + assert _jac_allclose(manual_res, expected_res, 1) + + jax.config.update("jax_enable_x64", initial_mode) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 5c02608286d..c167c59d407 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -759,7 +759,7 @@ def circuit(x): with pytest.raises( ValueError, - match="Invalid values for 'method=fd' and 'h=0.3' in interpreted mode", + match="Invalid values 'method='fd'' and 'h=0.3' without QJIT", ): workflow(np.array([2.0, 1.0]))