Skip to content

Commit

Permalink
adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
KetpuntoG committed Sep 19, 2024
1 parent fcbcb87 commit 000b88d
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pennylane/ops/op_math/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def wrapper(*args, **kwargs):
for arg in args:
remove_from_queue_args_and_kwargs(arg)

for key, value in kwargs.items():
for value in kwargs.values():
remove_from_queue_args_and_kwargs(value)

if lazy:
Expand Down
7 changes: 0 additions & 7 deletions pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from pennylane.compiler import compiler
from pennylane.measurements import MeasurementValue
from pennylane.operation import AnyWires, Operation, Operator
from pennylane.ops.op_math.controlled import remove_from_queue_args_and_kwargs
from pennylane.ops.op_math.symbolicop import SymbolicOp
from pennylane.tape import make_qscript

Expand Down Expand Up @@ -624,12 +623,6 @@ def wrapper(*args, **kwargs):
# 1. Apply true_fn conditionally
qscript = make_qscript(true_fn)(*args, **kwargs)

for arg in args:
remove_from_queue_args_and_kwargs(arg)

for key, value in kwargs.items():
remove_from_queue_args_and_kwargs(value)

if qscript.measurements:
raise ConditionalTransformError(with_meas_err)

Expand Down
4 changes: 2 additions & 2 deletions pennylane/ops/op_math/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def remove_from_queue_args_and_kwargs(item):
for elem in item:
remove_from_queue_args_and_kwargs(elem)
elif isinstance(item, dict):
for key, value in item.items():
for value in item.values():
remove_from_queue_args_and_kwargs(value)
elif isinstance(item, Operator):
qml.queuing.QueuingManager.remove(item)
Expand All @@ -219,7 +219,7 @@ def wrapper(*args, **kwargs):
for arg in args:
remove_from_queue_args_and_kwargs(arg)

for key, value in kwargs.items():
for value in kwargs.values():
remove_from_queue_args_and_kwargs(value)

# flip control_values == 0 wires here, so we don't have to do it for each individual op.
Expand Down
6 changes: 0 additions & 6 deletions pennylane/ops/op_math/prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,6 @@ def prod(*ops, id=None, lazy=True):
def wrapper(*args, **kwargs):
qs = qml.tape.make_qscript(fn)(*args, **kwargs)

for arg in args:
remove_from_queue_args_and_kwargs(arg)

for key, value in kwargs.items():
remove_from_queue_args_and_kwargs(value)

if len(qs.operations) == 1:
if qml.QueuingManager.recording():
qml.apply(qs[0])
Expand Down
15 changes: 15 additions & 0 deletions tests/ops/op_math/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,21 @@ def test_single_op_defined_outside_queue_eager(self):
assert len(q) == 1
assert q.queue[0] is out

def test_correct_queued_operators(self):
"""Test that args and kwargs do not add operators to the queue."""

dev = qml.device("default.qubit")

@qml.qnode(dev)
def circuit():
qml.adjoint(qml.QSVT)(qml.X(1), [qml.Z(1)])
qml.adjoint(qml.QSVT(qml.X(1), [qml.Z(1)]))
return qml.state()

circuit()
for op in circuit.tape.operations:
assert op.name == "Adjoint(QSVT)"

@pytest.mark.usefixtures("use_legacy_opmath")
def test_single_observable(self):
"""Test passing a single preconstructed observable in a queuing context."""
Expand Down
15 changes: 15 additions & 0 deletions tests/ops/op_math/test_controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -1899,6 +1899,21 @@ def test_nested_pauli_x_based_ctrl_ops(self):
expected = qml.MultiControlledX(wires=[3, 2, 1, 0], control_values=[1, 0, 1])
assert op == expected

def test_correct_queued_operators(self):
"""Test that args and kwargs do not add operators to the queue."""

dev = qml.device("default.qubit")

@qml.qnode(dev)
def circuit():
qml.ctrl(qml.QSVT, control=0)(qml.X(1), [qml.Z(1)])
qml.ctrl(qml.QSVT(qml.X(1), [qml.Z(1)]), control=0)
return qml.state()

circuit()
for op in circuit.tape.operations:
assert op.name == "C(QSVT)"


class _Rot(Operation):
"""A rotation operation that is not an instance of Rot
Expand Down
15 changes: 15 additions & 0 deletions tests/ops/op_math/test_prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,21 @@ def test_nonlazy_mode_queueing(self):
assert len(q) == 1
assert q.queue[0] is prod2

def test_correct_queued_operators(self):
"""Test that args and kwargs do not add operators to the queue."""

dev = qml.device("default.qubit")

@qml.qnode(dev)
def circuit():
qml.prod(qml.QSVT)(qml.X(1), [qml.Z(1)])
qml.prod(qml.QSVT(qml.X(1), [qml.Z(1)]))
return qml.state()

circuit()
for op in circuit.tape.operations:
assert op.name == "QSVT"


class TestIntegration:
"""Integration tests for the Prod class."""
Expand Down

0 comments on commit 000b88d

Please sign in to comment.