Skip to content

Commit

Permalink
Set grad_method=None for ControlledSequence, Reflection, `Ampli…
Browse files Browse the repository at this point in the history
…tudeAmplification`, and `Qubitization`. (#5806)

**Context:**
Templates that are not actually supported by `parameter_shift` should
have `grad_method=None` so that they are decomposed by
`_expand_transform_param_shift`

**Description of the Change:**
1. Adds the `data` of components of the templates to the `data` of the
templates such that trainable parameters are tracked
2. Adds `grad_method=None` for `ControlledSequence`, `Reflection`,
`AmplitudeAmplification`, and `Qubitization`.

**Related GitHub Issues:**
Fixes #5802
[sc-64967]
  • Loading branch information
astralcai committed Jun 6, 2024
1 parent 38b3e74 commit 22da9a0
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 38 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,9 @@
* `CNOT` and `Toffoli` now have an `arithmetic_depth` of `1`, as they are controlled operations.
[(#5797)](https://github.com/PennyLaneAI/pennylane/pull/5797)

* Fixes a bug where the gradient of `ControlledSequence`, `Reflection`, `AmplitudeAmplification`, and `Qubitization` is incorrect on `default.qubit.legacy` with `parameter_shift`.
[(#5806)](https://github.com/PennyLaneAI/pennylane/pull/5806)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):
Expand Down
6 changes: 4 additions & 2 deletions pennylane/templates/subroutines/amplitude_amplification.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def circuit():
[0.013, 0.013, 0.91, 0.013, 0.013, 0.013, 0.013, 0.013]
"""

grad_method = None

def _flatten(self):
data = (self.hyperparameters["U"], self.hyperparameters["O"])
metadata = tuple(item for item in self.hyperparameters.items() if item[0] not in ["O", "U"])
Expand Down Expand Up @@ -141,11 +143,11 @@ def __init__(
self.hyperparameters["p_min"] = p_min
self.hyperparameters["reflection_wires"] = qml.wires.Wires(reflection_wires)

super().__init__(wires=wires)
super().__init__(*U.data, *O.data, wires=wires)

# pylint:disable=arguments-differ
@staticmethod
def compute_decomposition(**kwargs):
def compute_decomposition(*_, **kwargs):
U = kwargs["U"]
O = kwargs["O"]
iters = kwargs["iters"]
Expand Down
2 changes: 2 additions & 0 deletions pennylane/templates/subroutines/controlled_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def circuit():
"""

grad_method = None

def _flatten(self):
return (self.base,), (self.control,)

Expand Down
4 changes: 3 additions & 1 deletion pennylane/templates/subroutines/qubitization.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def circuit():
eigenvalue: 0.7
"""

grad_method = None

@classmethod
def _primitive_bind_call(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)
Expand All @@ -104,7 +106,7 @@ def __init__(self, hamiltonian, control, id=None):
"control": qml.wires.Wires(control),
}

super().__init__(wires=wires, id=id)
super().__init__(*hamiltonian.data, wires=wires, id=id)

def _flatten(self):
data = (self.hyperparameters["hamiltonian"],)
Expand Down
4 changes: 3 additions & 1 deletion pennylane/templates/subroutines/reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def circuit():
"""

grad_method = None

@classmethod
def _primitive_bind_call(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)
Expand Down Expand Up @@ -136,7 +138,7 @@ def __init__(self, U, alpha=np.pi, reflection_wires=None, id=None):
"reflection_wires": tuple(reflection_wires),
}

super().__init__(alpha, wires=wires, id=id)
super().__init__(alpha, *U.data, wires=wires, id=id)

def map_wires(self, wire_map: dict):
# pylint: disable=protected-access
Expand Down
31 changes: 22 additions & 9 deletions tests/templates/test_subroutines/test_amplitude_amplification.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def circuit(params):
qml.RZ(params[1], wires=0),
iters=3,
fixed_point=True,
work_wire=3,
work_wire=2,
)

return qml.expval(qml.PauliZ(0))
Expand All @@ -156,28 +156,36 @@ def circuit(params):
params = np.array([0.9, 0.1])

@pytest.mark.autograd
def test_qnode_autograd(self):
@pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"])
@pytest.mark.parametrize("shots", [None, 50000])
def test_qnode_autograd(self, device, shots):
"""Test that the QNode executes with Autograd."""

dev = qml.device("default.qubit")
qnode = qml.QNode(self.circuit, dev, interface="autograd")
dev = qml.device(device, wires=3, shots=shots)
diff_method = "backprop" if shots is None else "parameter-shift"
qnode = qml.QNode(self.circuit, dev, interface="autograd", diff_method=diff_method)

params = qml.numpy.array(self.params, requires_grad=True)
res = qml.grad(qnode)(params)
assert qml.math.shape(res) == (2,)
assert np.allclose(res, self.exp_grad, atol=1e-5)
assert np.allclose(res, self.exp_grad, atol=0.01)

@pytest.mark.jax
@pytest.mark.parametrize("use_jit", [False, True])
@pytest.mark.parametrize("shots", [None, 50000])
def test_qnode_jax(self, shots, use_jit):
@pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"])
def test_qnode_jax(self, shots, use_jit, device):
"""Test that the QNode executes and is differentiable with JAX. The shots
argument controls whether autodiff or parameter-shift gradients are used."""
import jax

jax.config.update("jax_enable_x64", True)

dev = qml.device("default.qubit", shots=shots, seed=10)
if device == "default.qubit":
dev = qml.device("default.qubit", shots=shots, seed=10)
else:
dev = qml.device("default.qubit.legacy", shots=shots, wires=3)

diff_method = "backprop" if shots is None else "parameter-shift"
qnode = qml.QNode(self.circuit, dev, interface="jax", diff_method=diff_method)
if use_jit:
Expand All @@ -195,12 +203,17 @@ def test_qnode_jax(self, shots, use_jit):

@pytest.mark.torch
@pytest.mark.parametrize("shots", [None, 50000])
def test_qnode_torch(self, shots):
@pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"])
def test_qnode_torch(self, shots, device):
"""Test that the QNode executes and is differentiable with Torch. The shots
argument controls whether autodiff or parameter-shift gradients are used."""
import torch

dev = qml.device("default.qubit", shots=shots, seed=10)
if device == "default.qubit":
dev = qml.device("default.qubit", shots=shots, seed=10)
else:
dev = qml.device("default.qubit.legacy", shots=shots, wires=3)

diff_method = "backprop" if shots is None else "parameter-shift"
qnode = qml.QNode(self.circuit, dev, interface="torch", diff_method=diff_method)

Expand Down
45 changes: 34 additions & 11 deletions tests/templates/test_subroutines/test_controlled_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_standard_validity():


class TestInitialization:

def test_id(self):
"""Tests that the id attribute can be set."""
op = qml.ControlledSequence(qml.RX(0.25, wires=3), control=[0, 1, 2], id="a")
Expand All @@ -57,6 +58,7 @@ def test_name(self):


class TestProperties:

def test_hash(self):
"""Test that op.hash uniquely describes a ControlledSequence"""

Expand Down Expand Up @@ -97,6 +99,7 @@ def test_has_matrix(self):


class TestMethods:

def test_repr(self):
"""Test that the operator repr is as expected"""
op = qml.ControlledSequence(qml.RX(0.25, wires=3), control=[0, 1, 2])
Expand Down Expand Up @@ -189,28 +192,41 @@ def test_qnode_numpy(self):
assert np.allclose(res, self.exp_result, atol=0.002)

@pytest.mark.autograd
def test_qnode_autograd(self):
@pytest.mark.parametrize("shots", [None, 50000])
@pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"])
def test_qnode_autograd(self, shots, device):
"""Test that the QNode executes with Autograd."""

dev = qml.device("default.qubit")
qnode = qml.QNode(self.circuit, dev, interface="autograd")

dev = qml.device(device, wires=4, shots=shots)
diff_method = "backprop" if shots is None else "parameter-shift"
qnode = qml.QNode(self.circuit, dev, interface="autograd", diff_method=diff_method)
x = qml.numpy.array(self.x, requires_grad=True)

res = qnode(x)
assert qml.math.shape(res) == (16,)
assert np.allclose(res, self.exp_result, atol=0.002)

res = qml.jacobian(qnode)(x)
assert np.shape(res) == (16,)
assert np.allclose(res, self.exp_jac, atol=0.005)

@pytest.mark.jax
@pytest.mark.parametrize("use_jit", [False, True])
@pytest.mark.parametrize("shots", [None, 10000])
def test_qnode_jax(self, shots, use_jit):
@pytest.mark.parametrize("shots", [None, 50000])
@pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"])
def test_qnode_jax(self, shots, use_jit, device):
"""Test that the QNode executes and is differentiable with JAX. The shots
argument controls whether autodiff or parameter-shift gradients are used."""

import jax

jax.config.update("jax_enable_x64", True)

dev = qml.device("default.qubit", shots=shots, seed=10)
if device == "default.qubit":
dev = qml.device("default.qubit", shots=shots, seed=10)
else:
dev = qml.device("default.qubit.legacy", shots=shots, wires=4)

diff_method = "backprop" if shots is None else "parameter-shift"
qnode = qml.QNode(self.circuit, dev, interface="jax", diff_method=diff_method)
if use_jit:
Expand All @@ -230,13 +246,19 @@ def test_qnode_jax(self, shots, use_jit):
assert np.allclose(jac, self.exp_jac, atol=0.006)

@pytest.mark.torch
@pytest.mark.parametrize("shots", [None, 10000])
def test_qnode_torch(self, shots):
@pytest.mark.parametrize("shots", [None, 50000])
@pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"])
def test_qnode_torch(self, shots, device):
"""Test that the QNode executes and is differentiable with Torch. The shots
argument controls whether autodiff or parameter-shift gradients are used."""

import torch

dev = qml.device("default.qubit", shots=shots, seed=10)
if device == "default.qubit":
dev = qml.device("default.qubit", shots=shots, seed=10)
else:
dev = qml.device("default.qubit.legacy", shots=shots, wires=4)

diff_method = "backprop" if shots is None else "parameter-shift"
qnode = qml.QNode(self.circuit, dev, interface="torch", diff_method=diff_method)

Expand All @@ -247,14 +269,15 @@ def test_qnode_torch(self, shots):

jac = torch.autograd.functional.jacobian(qnode, x)
assert qml.math.shape(jac) == (16,)
assert qml.math.allclose(jac, self.exp_jac, atol=0.006)
assert qml.math.allclose(jac, self.exp_jac, atol=0.005)

@pytest.mark.tf
@pytest.mark.parametrize("shots", [None, 10000])
@pytest.mark.xfail(reason="tf gradient doesn't seem to be working, returns ()")
def test_qnode_tf(self, shots):
"""Test that the QNode executes and is differentiable with TensorFlow. The shots
argument controls whether autodiff or parameter-shift gradients are used."""

import tensorflow as tf

dev = qml.device("default.qubit", shots=shots, seed=10)
Expand Down
11 changes: 8 additions & 3 deletions tests/templates/test_subroutines/test_qubitization.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,19 @@ def test_qnode_autograd(self):
"use_jit , shots",
((False, None), (True, None), (False, 50000)),
) # TODO: (True, 50000) fails because jax.jit on jax.grad does not work with AmplitudeEmbedding
def test_qnode_jax(self, shots, use_jit):
@pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"])
def test_qnode_jax(self, shots, use_jit, device):
""" "Test that the QNode executes and is differentiable with JAX. The shots
argument controls whether autodiff or parameter-shift gradients are used."""
import jax

jax.config.update("jax_enable_x64", True)

dev = qml.device("default.qubit", shots=shots, seed=10)
if device == "default.qubit":
dev = qml.device("default.qubit", shots=shots, seed=10)
else:
dev = qml.device("default.qubit.legacy", shots=shots, wires=5)

diff_method = "backprop" if shots is None else "parameter-shift"
qnode = qml.QNode(self.circuit, dev, interface="jax", diff_method=diff_method)
if use_jit:
Expand All @@ -248,7 +253,7 @@ def test_qnode_jax(self, shots, use_jit):

jac = jac_fn(params)
assert jac.shape == (4,)
assert np.allclose(jac, self.exp_grad, atol=0.01)
assert np.allclose(jac, self.exp_grad, atol=0.05)

@pytest.mark.torch
@pytest.mark.parametrize(
Expand Down
40 changes: 29 additions & 11 deletions tests/templates/test_subroutines/test_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,28 +163,40 @@ def test_lightning_qubit(self):
assert np.allclose(res, self.exp_result, atol=0.002)

@pytest.mark.autograd
def test_qnode_autograd(self):
@pytest.mark.parametrize("shots", [None, 50000])
@pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"])
def test_qnode_autograd(self, shots, device):
"""Test that the QNode executes with Autograd."""

dev = qml.device("default.qubit")
qnode = qml.QNode(self.circuit, dev, interface="autograd")
dev = qml.device(device, shots=shots, wires=3)
diff_method = "backprop" if shots is None else "parameter-shift"
qnode = qml.QNode(self.circuit, dev, interface="autograd", diff_method=diff_method)

x = qml.numpy.array(self.x, requires_grad=True)
res = qnode(x)
assert qml.math.shape(res) == (8,)
assert np.allclose(res, self.exp_result, atol=0.002)
assert np.allclose(res, self.exp_result, atol=0.005)

res = qml.jacobian(qnode)(x)
assert np.shape(res) == (8,)
assert np.allclose(res, self.exp_jac, atol=0.005)

@pytest.mark.jax
@pytest.mark.parametrize("use_jit", [False, True])
@pytest.mark.parametrize("shots", [None, 50000])
def test_qnode_jax(self, shots, use_jit):
@pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"])
def test_qnode_jax(self, shots, use_jit, device):
"""Test that the QNode executes and is differentiable with JAX. The shots
argument controls whether autodiff or parameter-shift gradients are used."""
import jax

jax.config.update("jax_enable_x64", True)

dev = qml.device("default.qubit", shots=shots, seed=10)
if device == "default.qubit":
dev = qml.device("default.qubit", shots=shots, seed=10)
else:
dev = qml.device("default.qubit.legacy", shots=shots, wires=3)

diff_method = "backprop" if shots is None else "parameter-shift"
qnode = qml.QNode(self.circuit, dev, interface="jax", diff_method=diff_method)
if use_jit:
Expand All @@ -201,27 +213,33 @@ def test_qnode_jax(self, shots, use_jit):

jac = jac_fn(x)
assert jac.shape == (8,)
assert np.allclose(jac, self.exp_jac, atol=0.006)
assert np.allclose(jac, self.exp_jac, atol=0.005)

@pytest.mark.torch
@pytest.mark.parametrize("shots", [None, 50000])
def test_qnode_torch(self, shots):
@pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"])
def test_qnode_torch(self, shots, device):
"""Test that the QNode executes and is differentiable with Torch. The shots
argument controls whether autodiff or parameter-shift gradients are used."""

import torch

dev = qml.device("default.qubit", shots=shots, seed=10)
if device == "default.qubit":
dev = qml.device("default.qubit", shots=shots, seed=10)
else:
dev = qml.device("default.qubit.legacy", shots=shots, wires=3)

diff_method = "backprop" if shots is None else "parameter-shift"
qnode = qml.QNode(self.circuit, dev, interface="torch", diff_method=diff_method)

x = torch.tensor(self.x, requires_grad=True)
res = qnode(x)
assert qml.math.shape(res) == (8,)
assert qml.math.allclose(res, self.exp_result, atol=0.002)
assert qml.math.allclose(res, self.exp_result, atol=0.005)

jac = torch.autograd.functional.jacobian(qnode, x)
assert qml.math.shape(jac) == (8,)
assert qml.math.allclose(jac, self.exp_jac, atol=0.006)
assert qml.math.allclose(jac, self.exp_jac, atol=0.005)

@pytest.mark.tf
@pytest.mark.parametrize("shots", [None, 50000])
Expand Down

0 comments on commit 22da9a0

Please sign in to comment.