Fix the derivative of MottonenStatePreparation where possible (#5774)
**Context:**
The decomposition of `MottonenStatePreparation` skips some gates for
special parameter values/input states.
See the linked issue for details.

**Description of the Change:**
This PR introduces a check for differentiability so that gates are only
skipped when no derivatives are being computed.
Note that this does *not* fix the non-differentiability at other special
parameter points, which is also referenced in #5715 and is already
warned about in the docs.
Also, the linked issue covers multiple operations; here we only
address `MottonenStatePreparation`.
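
For illustration only (not part of the PR), here is a minimal sketch of how the fixed behaviour can be exercised. The uniform input state matches the parameter point used in the new JAX test; the device, interface, and `diff_method` choices below are assumptions:

```python
import pennylane as qml
from pennylane import numpy as pnp

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, diff_method="backprop")
def circuit(state):
    qml.MottonenStatePreparation(state, wires=[0, 1])
    return qml.probs(wires=[0, 1])

# At this state some decomposition angles are exactly zero, so the
# corresponding rotations used to be dropped and their derivative
# contributions were lost; with the trainability check they are kept.
state = pnp.array([0.5, 0.5, 0.5, 0.5], requires_grad=True)
jac = qml.jacobian(circuit)(state)
```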

**Benefits:**
Fixes parts of #5715. Unblocks #5620.

**Possible Drawbacks:**

**Related GitHub Issues:**
#5715
dwierichs authored Jun 13, 2024
1 parent 3771f24 commit eb5c192
Showing 4 changed files with 132 additions and 159 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
@@ -323,6 +323,9 @@

<h3>Bug fixes 🐛</h3>

* Fixes a bug where `MottonenStatePreparation` produces wrong derivatives at special parameter values.
[(#5774)](https://github.com/PennyLaneAI/pennylane/pull/5774)

* Fixes a bug where fractional powers and adjoints of operators were commuted, which is
not well-defined/correct in general. Adjoints of fractional powers can no longer be evaluated.
[(#5835)](https://github.com/PennyLaneAI/pennylane/pull/5835)
21 changes: 16 additions & 5 deletions pennylane/templates/state_preparations/mottonen.py
@@ -82,7 +82,6 @@ def compute_theta(alpha):
(tensor_like): rotation angles theta
"""
ln = alpha.shape[-1]
k = np.log2(ln)

M_trans = np.zeros(shape=(ln, ln))
for i in range(len(M_trans)):
@@ -91,7 +90,7 @@ def compute_theta(alpha):

theta = qml.math.transpose(qml.math.dot(M_trans, qml.math.transpose(alpha)))

return theta / 2**k
return theta / ln


def _apply_uniform_rotation_dagger(gate, alpha, control_wires, target_wire):
Expand Down Expand Up @@ -124,7 +123,11 @@ def _apply_uniform_rotation_dagger(gate, alpha, control_wires, target_wire):
gray_code_rank = len(control_wires)

if gray_code_rank == 0:
if qml.math.is_abstract(theta) or qml.math.all(theta[..., 0] != 0.0):
if (
qml.math.is_abstract(theta)
or qml.math.requires_grad(theta)
or qml.math.all(theta[..., 0] != 0.0)
):
op_list.append(gate(theta[..., 0], wires=[target_wire]))
return op_list

@@ -137,7 +140,11 @@ def _apply_uniform_rotation_dagger(gate, alpha, control_wires, target_wire):
]

for i, control_index in enumerate(control_indices):
if qml.math.is_abstract(theta) or qml.math.all(theta[..., i] != 0.0):
if (
qml.math.is_abstract(theta)
or qml.math.requires_grad(theta)
or qml.math.all(theta[..., i] != 0.0)
):
op_list.append(gate(theta[..., i], wires=[target_wire]))
op_list.append(qml.CNOT(wires=[control_wires[control_index], target_wire]))
return op_list
Expand Down Expand Up @@ -366,7 +373,11 @@ def compute_decomposition(state_vector, wires): # pylint: disable=arguments-dif
op_list.extend(_apply_uniform_rotation_dagger(qml.RY, alpha_y_k, control, target))

# If necessary, apply inverse z rotation cascade to prepare correct phases of amplitudes
if qml.math.is_abstract(omega) or not qml.math.allclose(omega, 0):
if (
qml.math.is_abstract(omega)
or qml.math.requires_grad(omega)
or not qml.math.allclose(omega, 0)
):
for k in range(len(wires_reverse), 0, -1):
alpha_z_k = _get_alpha_z(omega, len(wires_reverse), k)
control = wires_reverse[k:]
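
For context on the guard added in the hunks above, here is a minimal, hedged sketch (not part of the diff) of how `qml.math.requires_grad` and `qml.math.all` behave on trainable versus constant zero angles; the array values are illustrative assumptions:

```python
import numpy as np
import pennylane as qml
from pennylane import numpy as pnp

# Trainable zeros: the new guard keeps the rotation so its derivative survives.
theta_trainable = pnp.array([0.0, 0.0], requires_grad=True)
print(qml.math.requires_grad(theta_trainable))  # True  -> gate is kept

# Constant zeros: no gradient is requested, so the gate may still be skipped.
theta_constant = np.array([0.0, 0.0])
print(qml.math.requires_grad(theta_constant))   # False
print(qml.math.all(theta_constant != 0.0))      # False -> gate can be dropped
```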
245 changes: 98 additions & 147 deletions tests/templates/test_state_preparations/test_mottonen_state_prep.py
@@ -67,74 +67,60 @@ def test_get_alpha_y(self, current_qubit, expected, tol):
assert np.allclose(res, expected, atol=tol)


# fmt: off
fixed_states = (
[
-0.17133152 - 0.18777771j, 0.00240643 - 0.40704011j, 0.18684538 - 0.36315606j,
-0.07096948 + 0.104501j, 0.30357755 - 0.23831927j, -0.38735106 + 0.36075556j,
0.12351096 - 0.0539908j, 0.27942828 - 0.24810483j,
],
[
-0.29972867 + 0.04964242j, -0.28309418 + 0.09873227j, 0.00785743 - 0.37560696j,
-0.3825148 + 0.00674343j, -0.03008048 + 0.31119167j, 0.03666351 - 0.15935903j,
-0.25358831 + 0.35461265j, -0.32198531 + 0.33479292j,
],
[
-0.39340123 + 0.05705932j, 0.1980509 - 0.24234781j, 0.27265585 - 0.0604432j,
-0.42641249 + 0.25767258j, 0.40386614 - 0.39925987j, 0.03924761 + 0.13193724j,
-0.06059103 - 0.01753834j, 0.21707136 - 0.15887973j,
],
[
-1.33865287e-01 + 0.09802308j, 1.25060033e-01 + 0.16087698j, -4.14678130e-01 - 0.00774832j,
1.10121136e-01 + 0.37805482j, -3.21284864e-01 + 0.21521063j, -2.23121454e-04 + 0.28417422j,
5.64131205e-02 + 0.38135286j, 2.32694503e-01 + 0.41331133j,
],
)
# fmt: on
decomposition_test_cases = [
([1, 0], 0, np.eye(8)[0]),
([1, 0], [0], np.eye(8)[0]),
([1, 0], [1], np.eye(8)[0]),
([1, 0], [2], np.eye(8)[0]),
([0, 1], [0], np.eye(8)[4]),
([0, 1], [1], np.eye(8)[2]),
([0, 1], [2], np.eye(8)[1]),
([0, 1, 0, 0], [0, 1], np.eye(8)[2]),
([0, 0, 0, 1], [0, 2], np.eye(8)[5]),
([0, 0, 0, 1], [1, 2], np.eye(8)[3]),
(np.eye(8)[0], [0, 1, 2], np.eye(8)[0]),
(1j * np.eye(8)[4], [0, 1, 2], 1j * np.eye(8)[4]),
(x := np.array([1, 0, 0, 0, 1, 1j, -1, 0]) / 2, [0, 1, 2], x),
(x := np.array([1, 0, 0, 0, 2j, 2j, 0, 0]) / 3, [0, 1, 2], x),
(x := np.array([2, 0, 0, 0, 1, 0, 0, 2]) / 3, [0, 1, 2], x),
(x := np.array([1, 1j, 1, -1j, 1, 1, 1, 1j]) / np.sqrt(8), [0, 1, 2], x),
(fixed_states[0], [0, 1, 2], fixed_states[0]),
(fixed_states[1], [0, 1, 2], fixed_states[1]),
(fixed_states[2], [0, 1, 2], fixed_states[2]),
(fixed_states[3], [0, 1, 2], fixed_states[3]),
(x := np.array([1 / 2, 0, 0, 0, 1j / 2, 0, 1j / np.sqrt(2), 0]), [0, 1, 2], x),
(np.array([1 / 2, 0, 1j / 2, 1j / np.sqrt(2)]), [0, 1], x),
]


class TestDecomposition:
"""Tests that the template defines the correct decomposition."""

# fmt: off
@pytest.mark.parametrize("state_vector,wires,target_state", [
([1, 0], 0, [1, 0, 0, 0, 0, 0, 0, 0]),
([1, 0], [0], [1, 0, 0, 0, 0, 0, 0, 0]),
([1, 0], [1], [1, 0, 0, 0, 0, 0, 0, 0]),
([1, 0], [2], [1, 0, 0, 0, 0, 0, 0, 0]),
([0, 1], [0], [0, 0, 0, 0, 1, 0, 0, 0]),
([0, 1], [1], [0, 0, 1, 0, 0, 0, 0, 0]),
([0, 1], [2], [0, 1, 0, 0, 0, 0, 0, 0]),
([0, 1, 0, 0], [0, 1], [0, 0, 1, 0, 0, 0, 0, 0]),
([0, 0, 0, 1], [0, 2], [0, 0, 0, 0, 0, 1, 0, 0]),
([0, 0, 0, 1], [1, 2], [0, 0, 0, 1, 0, 0, 0, 0]),
([1, 0, 0, 0, 0, 0, 0, 0], [0, 1, 2], [1, 0, 0, 0, 0, 0, 0, 0]),
([0, 0, 0, 0, 1j, 0, 0, 0], [0, 1, 2], [0, 0, 0, 0, 1j, 0, 0, 0]),
([1 / 2, 0, 0, 0, 1 / 2, 1j / 2, -1 / 2, 0], [0, 1, 2], [1 / 2, 0, 0, 0, 1 / 2, 1j / 2, -1 / 2, 0]),
([1 / 3, 0, 0, 0, 2j / 3, 2j / 3, 0, 0], [0, 1, 2], [1 / 3, 0, 0, 0, 2j / 3, 2j / 3, 0, 0]),
([2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3], [0, 1, 2], [2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3]),
(
[1 / np.sqrt(8), 1j / np.sqrt(8), 1 / np.sqrt(8), -1j / np.sqrt(8), 1 / np.sqrt(8),
1 / np.sqrt(8), 1 / np.sqrt(8), 1j / np.sqrt(8)],
[0, 1, 2],
[1 / np.sqrt(8), 1j / np.sqrt(8), 1 / np.sqrt(8), -1j / np.sqrt(8), 1 / np.sqrt(8),
1 / np.sqrt(8), 1 / np.sqrt(8), 1j / np.sqrt(8)],
),
(
[-0.17133152 - 0.18777771j, 0.00240643 - 0.40704011j, 0.18684538 - 0.36315606j, -0.07096948 + 0.104501j,
0.30357755 - 0.23831927j, -0.38735106 + 0.36075556j, 0.12351096 - 0.0539908j,
0.27942828 - 0.24810483j],
[0, 1, 2],
[-0.17133152 - 0.18777771j, 0.00240643 - 0.40704011j, 0.18684538 - 0.36315606j, -0.07096948 + 0.104501j,
0.30357755 - 0.23831927j, -0.38735106 + 0.36075556j, 0.12351096 - 0.0539908j,
0.27942828 - 0.24810483j],
),
(
[-0.29972867 + 0.04964242j, -0.28309418 + 0.09873227j, 0.00785743 - 0.37560696j,
-0.3825148 + 0.00674343j, -0.03008048 + 0.31119167j, 0.03666351 - 0.15935903j,
-0.25358831 + 0.35461265j, -0.32198531 + 0.33479292j],
[0, 1, 2],
[-0.29972867 + 0.04964242j, -0.28309418 + 0.09873227j, 0.00785743 - 0.37560696j,
-0.3825148 + 0.00674343j, -0.03008048 + 0.31119167j, 0.03666351 - 0.15935903j,
-0.25358831 + 0.35461265j, -0.32198531 + 0.33479292j],
),
(
[-0.39340123 + 0.05705932j, 0.1980509 - 0.24234781j, 0.27265585 - 0.0604432j, -0.42641249 + 0.25767258j,
0.40386614 - 0.39925987j, 0.03924761 + 0.13193724j, -0.06059103 - 0.01753834j,
0.21707136 - 0.15887973j],
[0, 1, 2],
[-0.39340123 + 0.05705932j, 0.1980509 - 0.24234781j, 0.27265585 - 0.0604432j, -0.42641249 + 0.25767258j,
0.40386614 - 0.39925987j, 0.03924761 + 0.13193724j, -0.06059103 - 0.01753834j,
0.21707136 - 0.15887973j],
),
(
[-1.33865287e-01 + 0.09802308j, 1.25060033e-01 + 0.16087698j, -4.14678130e-01 - 0.00774832j,
1.10121136e-01 + 0.37805482j, -3.21284864e-01 + 0.21521063j, -2.23121454e-04 + 0.28417422j,
5.64131205e-02 + 0.38135286j, 2.32694503e-01 + 0.41331133j],
[0, 1, 2],
[-1.33865287e-01 + 0.09802308j, 1.25060033e-01 + 0.16087698j, -4.14678130e-01 - 0.00774832j,
1.10121136e-01 + 0.37805482j, -3.21284864e-01 + 0.21521063j, -2.23121454e-04 + 0.28417422j,
5.64131205e-02 + 0.38135286j, 2.32694503e-01 + 0.41331133j],
),
([1 / 2, 0, 0, 0, 1j / 2, 0, 1j / np.sqrt(2), 0], [0, 1, 2],
[1 / 2, 0, 0, 0, 1j / 2, 0, 1j / np.sqrt(2), 0]),
([1 / 2, 0, 1j / 2, 1j / np.sqrt(2)], [0, 1], [1 / 2, 0, 0, 0, 1j / 2, 0, 1j / np.sqrt(2), 0]),
])
# fmt: on
@pytest.mark.parametrize("state_vector,wires,target_state", decomposition_test_cases)
def test_state_preparation(self, state_vector, wires, target_state):
"""Tests that the template produces correct states."""

@@ -147,71 +133,7 @@ def circuit():

assert np.allclose(state, target_state)

# fmt: off
@pytest.mark.parametrize("state_vector,wires,target_state", [
([1, 0], 0, [1, 0, 0, 0, 0, 0, 0, 0]),
([1, 0], [0], [1, 0, 0, 0, 0, 0, 0, 0]),
([1, 0], [1], [1, 0, 0, 0, 0, 0, 0, 0]),
([1, 0], [2], [1, 0, 0, 0, 0, 0, 0, 0]),
([0, 1], [0], [0, 0, 0, 0, 1, 0, 0, 0]),
([0, 1], [1], [0, 0, 1, 0, 0, 0, 0, 0]),
([0, 1], [2], [0, 1, 0, 0, 0, 0, 0, 0]),
([0, 1, 0, 0], [0, 1], [0, 0, 1, 0, 0, 0, 0, 0]),
([0, 0, 0, 1], [0, 2], [0, 0, 0, 0, 0, 1, 0, 0]),
([0, 0, 0, 1], [1, 2], [0, 0, 0, 1, 0, 0, 0, 0]),
([1, 0, 0, 0, 0, 0, 0, 0], [0, 1, 2], [1, 0, 0, 0, 0, 0, 0, 0]),
([0, 0, 0, 0, 1j, 0, 0, 0], [0, 1, 2], [0, 0, 0, 0, 1j, 0, 0, 0]),
([1 / 2, 0, 0, 0, 1 / 2, 1j / 2, -1 / 2, 0], [0, 1, 2], [1 / 2, 0, 0, 0, 1 / 2, 1j / 2, -1 / 2, 0]),
([1 / 3, 0, 0, 0, 2j / 3, 2j / 3, 0, 0], [0, 1, 2], [1 / 3, 0, 0, 0, 2j / 3, 2j / 3, 0, 0]),
([2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3], [0, 1, 2], [2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3]),
(
[1 / np.sqrt(8), 1j / np.sqrt(8), 1 / np.sqrt(8), -1j / np.sqrt(8), 1 / np.sqrt(8),
1 / np.sqrt(8), 1 / np.sqrt(8), 1j / np.sqrt(8)],
[0, 1, 2],
[1 / np.sqrt(8), 1j / np.sqrt(8), 1 / np.sqrt(8), -1j / np.sqrt(8), 1 / np.sqrt(8),
1 / np.sqrt(8), 1 / np.sqrt(8), 1j / np.sqrt(8)],
),
(
[-0.17133152 - 0.18777771j, 0.00240643 - 0.40704011j, 0.18684538 - 0.36315606j, -0.07096948 + 0.104501j,
0.30357755 - 0.23831927j, -0.38735106 + 0.36075556j, 0.12351096 - 0.0539908j,
0.27942828 - 0.24810483j],
[0, 1, 2],
[-0.17133152 - 0.18777771j, 0.00240643 - 0.40704011j, 0.18684538 - 0.36315606j, -0.07096948 + 0.104501j,
0.30357755 - 0.23831927j, -0.38735106 + 0.36075556j, 0.12351096 - 0.0539908j,
0.27942828 - 0.24810483j],
),
(
[-0.29972867 + 0.04964242j, -0.28309418 + 0.09873227j, 0.00785743 - 0.37560696j,
-0.3825148 + 0.00674343j, -0.03008048 + 0.31119167j, 0.03666351 - 0.15935903j,
-0.25358831 + 0.35461265j, -0.32198531 + 0.33479292j],
[0, 1, 2],
[-0.29972867 + 0.04964242j, -0.28309418 + 0.09873227j, 0.00785743 - 0.37560696j,
-0.3825148 + 0.00674343j, -0.03008048 + 0.31119167j, 0.03666351 - 0.15935903j,
-0.25358831 + 0.35461265j, -0.32198531 + 0.33479292j],
),
(
[-0.39340123 + 0.05705932j, 0.1980509 - 0.24234781j, 0.27265585 - 0.0604432j, -0.42641249 + 0.25767258j,
0.40386614 - 0.39925987j, 0.03924761 + 0.13193724j, -0.06059103 - 0.01753834j,
0.21707136 - 0.15887973j],
[0, 1, 2],
[-0.39340123 + 0.05705932j, 0.1980509 - 0.24234781j, 0.27265585 - 0.0604432j, -0.42641249 + 0.25767258j,
0.40386614 - 0.39925987j, 0.03924761 + 0.13193724j, -0.06059103 - 0.01753834j,
0.21707136 - 0.15887973j],
),
(
[-1.33865287e-01 + 0.09802308j, 1.25060033e-01 + 0.16087698j, -4.14678130e-01 - 0.00774832j,
1.10121136e-01 + 0.37805482j, -3.21284864e-01 + 0.21521063j, -2.23121454e-04 + 0.28417422j,
5.64131205e-02 + 0.38135286j, 2.32694503e-01 + 0.41331133j],
[0, 1, 2],
[-1.33865287e-01 + 0.09802308j, 1.25060033e-01 + 0.16087698j, -4.14678130e-01 - 0.00774832j,
1.10121136e-01 + 0.37805482j, -3.21284864e-01 + 0.21521063j, -2.23121454e-04 + 0.28417422j,
5.64131205e-02 + 0.38135286j, 2.32694503e-01 + 0.41331133j],
),
([1 / 2, 0, 0, 0, 1j / 2, 0, 1j / np.sqrt(2), 0], [0, 1, 2],
[1 / 2, 0, 0, 0, 1j / 2, 0, 1j / np.sqrt(2), 0]),
([1 / 2, 0, 1j / 2, 1j / np.sqrt(2)], [0, 1], [1 / 2, 0, 0, 0, 1j / 2, 0, 1j / np.sqrt(2), 0]),
])
# fmt: on
@pytest.mark.parametrize("state_vector,wires,target_state", decomposition_test_cases)
def test_state_preparation_probability_distribution(
self, tol, state_vector, wires, target_state
):
@@ -224,32 +146,32 @@ def circuit():
qml.expval(qml.PauliZ(0)),
qml.expval(qml.PauliZ(1)),
qml.expval(qml.PauliZ(2)),
qml.state(),
qml.probs(),
)

results = circuit()

state = results[-1].ravel()
probabilities = results[-1].ravel()

probabilities = np.abs(state) ** 2
target_probabilities = np.abs(target_state) ** 2

assert np.allclose(probabilities, target_probabilities, atol=tol, rtol=0)

# fmt: off
@pytest.mark.parametrize("state_vector, n_wires", [
([1 / 2, 1 / 2, 1 / 2, 1 / 2], 2),
([1, 0, 0, 0], 2),
([0, 1, 0, 0], 2),
([0, 0, 0, 1], 2),
([0, 1, 0, 0, 0, 0, 0, 0], 3),
([0, 0, 0, 0, 1, 0, 0, 0], 3),
([2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3], 3),
([1 / 2, 0, 0, 0, 1 / 2, 1 / 2, 1 / 2, 0], 3),
([1 / 3, 0, 0, 0, 2 / 3, 2 / 3, 0, 0], 3),
([2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3], 3),
])
# fmt: on
@pytest.mark.parametrize(
"state_vector, n_wires",
[
([1 / 2, 1 / 2, 1 / 2, 1 / 2], 2),
([1, 0, 0, 0], 2),
([0, 1, 0, 0], 2),
([0, 0, 0, 1], 2),
([0, 1, 0, 0, 0, 0, 0, 0], 3),
([0, 0, 0, 0, 1, 0, 0, 0], 3),
([2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3], 3),
([1 / 2, 0, 0, 0, 1 / 2, 1 / 2, 1 / 2, 0], 3),
([1 / 3, 0, 0, 0, 2 / 3, 2 / 3, 0, 0], 3),
([2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3], 3),
],
)
def test_RZ_skipped(self, mocker, state_vector, n_wires):
"""Tests that the cascade of RZ gates is skipped for real-valued states."""

@@ -492,3 +414,32 @@ def circuit(state):
expected = np.zeros(8)
expected[0] = 1.0
assert qml.math.allclose(actual, expected)


@pytest.mark.jax
@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.05)])
def test_jacobians_with_and_without_jit_match(shots, atol):
"""Test that the Jacobian of the circuit is the same with and without jit."""
import jax

dev = qml.device("default.qubit", shots=shots, seed=7890234)
dev_no_shots = qml.device("default.qubit", shots=None)

def circuit(coeffs):
qml.MottonenStatePreparation(coeffs, wires=[0, 1])
return qml.probs(wires=[0, 1])

circuit_fd = qml.QNode(circuit, dev, diff_method="finite-diff", h=0.05)
circuit_exact = qml.QNode(circuit, dev_no_shots)

params = jax.numpy.array([0.5, 0.5, 0.5, 0.5])
jac_exact_fn = jax.jacobian(circuit_exact)
jac_fn = jax.jacobian(circuit_fd)
jac_jit_fn = jax.jit(jac_fn)

jac_exact = jac_exact_fn(params)
jac = jac_fn(params)
jac_jit = jac_jit_fn(params)

assert qml.math.allclose(jac_exact, jac_jit, atol=atol)
assert qml.math.allclose(jac, jac_jit, atol=atol)
22 changes: 15 additions & 7 deletions tests/templates/test_subroutines/test_qubitization.py
@@ -223,16 +223,22 @@ def test_qnode_autograd(self):
assert np.allclose(res, self.exp_grad, atol=1e-5)

@pytest.mark.jax
@pytest.mark.parametrize(
"use_jit , shots",
((False, None), (True, None), (False, 50000)),
) # TODO: (True, 50000) fails because jax.jit on jax.grad does not work with AmplitudeEmbedding
@pytest.mark.parametrize("use_jit", (False, True))
@pytest.mark.parametrize("shots", (None, 50000))
@pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"])
def test_qnode_jax(self, shots, use_jit, device):
""" "Test that the QNode executes and is differentiable with JAX. The shots
argument controls whether autodiff or parameter-shift gradients are used."""
import jax

# TODO: Allow the following cases once their underlying issues are fixed:
# (True, 50000): jax.jit on jax.grad does not work with AmplitudeEmbedding currently
# (False, 50000): Since #5774, the decomposition of AmplitudeEmbedding triggered by
# param-shift includes a GlobalPhase always. GlobalPhase will only be
# param-shift-compatible again once #5620 is merged in.
if shots is not None:
pytest.xfail()

jax.config.update("jax_enable_x64", True)

if device == "default.qubit":
@@ -256,14 +262,16 @@ def test_qnode_jax(self, shots, use_jit, device):
assert np.allclose(jac, self.exp_grad, atol=0.05)

@pytest.mark.torch
@pytest.mark.parametrize(
"shots", [None]
) # TODO: finite shots fails because Prod is not currently differentiable.
@pytest.mark.parametrize("shots", [None, 50000])
def test_qnode_torch(self, shots):
""" "Test that the QNode executes and is differentiable with Torch. The shots
argument controls whether autodiff or parameter-shift gradients are used."""
import torch

# TODO: finite shots fails because Prod is not currently differentiable.
if shots is not None:
pytest.xfail()

dev = qml.device("default.qubit", shots=shots, seed=10)
diff_method = "backprop" if shots is None else "parameter-shift"
qnode = qml.QNode(self.circuit, dev, interface="torch", diff_method=diff_method)
