Skip to content

Commit

Permalink
Add finite-diff test
Browse files Browse the repository at this point in the history
  • Loading branch information
mudit2812 committed Jun 20, 2024
1 parent 3b1bf4e commit 35235b5
Showing 1 changed file with 27 additions and 111 deletions.
138 changes: 27 additions & 111 deletions tests/devices/default_qubit/test_default_qubit_native_mcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for default qubit preprocessing."""
from typing import Sequence
from unittest.mock import patch

import mcm_utils
import numpy as np
Expand Down Expand Up @@ -424,6 +425,32 @@ def circuit(x):
_ = circuit([0.1, 0.2])


@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["torch", "tensorflow", "jax", "autograd"])
@pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"])
def test_finite_diff_in_transform_program(interface, mcm_method):
"""Test that finite diff is in the transform program of a qnode containing
mid-circuit measurements"""

dev = get_device(shots=10)

@qml.qnode(dev, mcm_method=mcm_method, diff_method="finite-diff")
def circuit(x):
qml.RX(x, 0)
qml.measure(0)
return qml.expval(qml.Z(0))

x = qml.math.array(1.5, like=interface)
with patch("pennylane.execute") as mock_execute:
circuit(x)
mock_execute.assert_called()
_, kwargs = mock_execute.call_args
transform_program = kwargs["transform_program"]

# pylint: disable=protected-access
assert transform_program[0]._transform == qml.gradients.finite_diff.expand_transform


# pylint: disable=import-outside-toplevel, not-an-iterable
@pytest.mark.jax
class TestJaxIntegration:
Expand Down Expand Up @@ -536,114 +563,3 @@ def func(x, y, z):
if measure_f == qml.sample:
r2 = r2[r2 != fill_in_value]
np.allclose(r1, r2)

def test_obs_grad(self):
"""Test that the gradient of a single observable expectation value is correct."""
assert True

def test_obs_jac(self):
"""Test that the jacobian for a circuit with multiple observable expectation values
is correct."""
assert True

def test_single_mcm_meas_grad(self):
"""Test that the gradient is correct when collecting statistics on a single
mid-circuit measurement."""
assert True

def test_multi_mcm_meas_jac(self):
"""Test that the jacobian is correct when collecting statistics on multiple
mid-circuit measurements."""
assert True


# pylint: disable=import-outside-toplevel, not-an-iterable
@pytest.mark.torch
@pytest.mark.parametrize("mcm_method", ["one-shot", "tree-traversal"])
class TestTorchIntegration:
"""Integration tests for dynamic_one_shot with Torch"""

def test_grad_cond_param(self):
"""Test that the gradient is correct for a circuit with classically controlled
operations with trainable parameters"""
import torch
shots = 3000
dev = qml.device("default.qubit", shots=shots)

@qml.qnode(dev, diff_method="finite-diff", mcm_method="one-shot")
def circuit(x, y):
qml.RX(x, 0)
m = qml.measure(0)
qml.cond(m, qml.RX)(y, 0)
return qml.expval(qml.PauliZ(0))

x = torch.tensor(2.1, requires_grad=False)
y = torch.tensor(1.5, requires_grad=True)

def test_obs_grad(self):
"""Test that the gradient of a single observable expectation value is correct."""
assert True

def test_obs_jac(self):
"""Test that the jacobian for a circuit with multiple observable expectation values
is correct."""
assert True

def test_single_mcm_meas_grad(self):
"""Test that the gradient is correct when collecting statistics on a single
mid-circuit measurement."""
assert True

def test_multi_mcm_meas_jac(self):
"""Test that the jacobian is correct when collecting statistics on multiple
mid-circuit measurements."""
assert True


@pytest.mark.autograd
class TestAutogradIntegration:
"""Integration tests for dynamic_one_shot with Autograd"""

def test_obs_grad(self):
"""Test that the gradient of a single observable expectation value is correct."""
assert True

def test_obs_jac(self):
"""Test that the jacobian for a circuit with multiple observable expectation values
is correct."""
assert True

def test_single_mcm_meas_grad(self):
"""Test that the gradient is correct when collecting statistics on a single
mid-circuit measurement."""
assert True

def test_multi_mcm_meas_jac(self):
"""Test that the jacobian is correct when collecting statistics on multiple
mid-circuit measurements."""
assert True


# pylint: disable=import-outside-toplevel, not-an-iterable
@pytest.mark.tf
class TestTensorflowIntegration:
"""Integration tests for dynamic_one_shot with Tensorflow"""

def test_obs_grad(self):
"""Test that the gradient of a single observable expectation value is correct."""
assert True

def test_obs_jac(self):
"""Test that the jacobian for a circuit with multiple observable expectation values
is correct."""
assert True

def test_single_mcm_meas_grad(self):
"""Test that the gradient is correct when collecting statistics on a single
mid-circuit measurement."""
assert True

def test_multi_mcm_meas_jac(self):
"""Test that the jacobian is correct when collecting statistics on multiple
mid-circuit measurements."""
assert True

0 comments on commit 35235b5

Please sign in to comment.