Skip to content

Commit

Permalink
fix test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
albi3ro committed Jun 20, 2024
1 parent 7d811a6 commit b585f4a
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 11 deletions.
25 changes: 25 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,13 @@ def _(*args, **kwargs):
)
from .capture_qnode import qnode_call

# by defining this here, we avoid
# E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module)
# on use of from capture import AbstractOperator
AbstractOperator: type
AbstractMeasurement: type
qnode_prim: "jax.core.Primitive"


def __getattr__(key):
if key == "AbstractOperator":
Expand All @@ -153,3 +160,21 @@ def __getattr__(key):
return _get_qnode_prim()

raise AttributeError(f"module 'pennylane.capture' has no attribute '{key}'")


# pylint: disable=undefined-all-variable
__all__ = (
"disable",
"enable",
"enabled",
"CaptureMeta",
"ABCCaptureMeta",
"create_operator_primitive",
"create_measurement_obs_primitive",
"create_measurement_wires_primitive",
"create_measurement_mcm_primitive",
"qnode_call",
"AbstractOperator",
"AbstractMeasurement",
"qnode_prim",
)
4 changes: 1 addition & 3 deletions tests/capture/test_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
import pytest

import pennylane as qml
from pennylane.capture.capture_qnode import _get_qnode_prim

qnode_prim = _get_qnode_prim()
from pennylane.capture import qnode_prim

pytestmark = pytest.mark.jax

Expand Down
4 changes: 1 addition & 3 deletions tests/capture/test_measurements_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest

import pennylane as qml
from pennylane.capture.primitives import _get_abstract_measurement
from pennylane.capture import AbstractMeasurement
from pennylane.measurements import (
ClassicalShadowMP,
DensityMatrixMP,
Expand All @@ -40,8 +40,6 @@

pytestmark = pytest.mark.jax

AbstractMeasurement = _get_abstract_measurement()


@pytest.fixture(autouse=True)
def enable_disable_plxpr():
Expand Down
4 changes: 1 addition & 3 deletions tests/capture/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@
import pytest

import pennylane as qml
from pennylane.capture.primitives import _get_abstract_operator
from pennylane.capture import AbstractOperator

jax = pytest.importorskip("jax")

pytestmark = pytest.mark.jax

AbstractOperator = _get_abstract_operator()


@pytest.fixture(autouse=True)
def enable_disable_plxpr():
Expand Down
2 changes: 0 additions & 2 deletions tests/capture/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@
import pytest

import pennylane as qml
from pennylane.capture.primitives import _get_abstract_operator

jax = pytest.importorskip("jax")
jnp = jax.numpy

pytestmark = pytest.mark.jax

AbstractOperator = _get_abstract_operator()
original_op_bind_code = qml.operation.Operator._primitive_bind_call.__code__


Expand Down

0 comments on commit b585f4a

Please sign in to comment.