diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index 63b4aaa4261..75e82da4053 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -135,6 +135,13 @@ def _(*args, **kwargs): ) from .capture_qnode import qnode_call +# by defining this here, we avoid +# E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module) +# on use of from capture import AbstractOperator +AbstractOperator: type +AbstractMeasurement: type +qnode_prim: "jax.core.Primitive" + def __getattr__(key): if key == "AbstractOperator": @@ -153,3 +160,21 @@ def __getattr__(key): return _get_qnode_prim() raise AttributeError(f"module 'pennylane.capture' has no attribute '{key}'") + + +# pylint: disable=undefined-all-variable +__all__ = ( + "disable", + "enable", + "enabled", + "CaptureMeta", + "ABCCaptureMeta", + "create_operator_primitive", + "create_measurement_obs_primitive", + "create_measurement_wires_primitive", + "create_measurement_mcm_primitive", + "qnode_call", + "AbstractOperator", + "AbstractMeasurement", + "qnode_prim", +) diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index 56f9f4739f0..553c4485b73 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -21,9 +21,7 @@ import pytest import pennylane as qml -from pennylane.capture.capture_qnode import _get_qnode_prim - -qnode_prim = _get_qnode_prim() +from pennylane.capture import qnode_prim pytestmark = pytest.mark.jax diff --git a/tests/capture/test_measurements_capture.py b/tests/capture/test_measurements_capture.py index a653a3146f9..eb0ddc5bf6f 100644 --- a/tests/capture/test_measurements_capture.py +++ b/tests/capture/test_measurements_capture.py @@ -20,7 +20,7 @@ import pytest import pennylane as qml -from pennylane.capture.primitives import _get_abstract_measurement +from pennylane.capture import AbstractMeasurement from pennylane.measurements import ( ClassicalShadowMP, DensityMatrixMP, @@ -40,8 +40,6 @@ pytestmark = pytest.mark.jax -AbstractMeasurement = _get_abstract_measurement() - @pytest.fixture(autouse=True) def enable_disable_plxpr(): diff --git a/tests/capture/test_operators.py b/tests/capture/test_operators.py index 6388a3b6ea4..3caff3e42fb 100644 --- a/tests/capture/test_operators.py +++ b/tests/capture/test_operators.py @@ -18,14 +18,12 @@ import pytest import pennylane as qml -from pennylane.capture.primitives import _get_abstract_operator +from pennylane.capture import AbstractOperator jax = pytest.importorskip("jax") pytestmark = pytest.mark.jax -AbstractOperator = _get_abstract_operator() - @pytest.fixture(autouse=True) def enable_disable_plxpr(): diff --git a/tests/capture/test_templates.py b/tests/capture/test_templates.py index 7648a4c1614..be54b2e8cb8 100644 --- a/tests/capture/test_templates.py +++ b/tests/capture/test_templates.py @@ -23,14 +23,12 @@ import pytest import pennylane as qml -from pennylane.capture.primitives import _get_abstract_operator jax = pytest.importorskip("jax") jnp = jax.numpy pytestmark = pytest.mark.jax -AbstractOperator = _get_abstract_operator() original_op_bind_code = qml.operation.Operator._primitive_bind_call.__code__