Skip to content

Commit

Permalink
update to support None and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lillian542 committed Sep 19, 2024
1 parent 4b0fd1f commit 9eca00e
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 63 deletions.
38 changes: 24 additions & 14 deletions pennylane/tape/qscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,22 +375,23 @@ def op_wires(self) -> Wires:

def update(
self,
operations: Optional[Iterable[Operator]] = None,
measurements: Optional[Iterable[MeasurementProcess]] = None,
shots: Optional[ShotsLike] = None,
trainable_params: Optional[Sequence[int]] = None,
**kwargs,
# operations: Optional[Iterable[Operator]] = "unset",
# measurements: Optional[Iterable[MeasurementProcess]] = "unset",
# shots: Optional[ShotsLike] = "unset",
# trainable_params: Optional[Sequence[int]] = "unset",
):
"""Update attirbutes on the tape.
r"""update(operations, measurements, shots, trainable_params)
Update attributes on the tape.
Keyword Args:
operations (Iterable[Operator]): An iterable of the operations to replace on the tape.
Defaults to None.
measurements (Iterable[MeasurementProcess]): All the measurements to replace on the
tape. Defaults to None.
measurements (Iterable[MeasurementProcess]): An iterable of all the measurements to
replace on the tape.
shots (None, int, Sequence[int], ~.Shots): Number and/or batches of shots to replace
on the tape. Defaults to None.
on the tape.
trainable_params (None, Sequence[int]): the indices for which parameters are trainable.
Defaults to None.
Returns: A new tape instance, initialized with any kwargs passed to update.
Anything not set in update will be the same as the initial tape.
Expand All @@ -417,11 +418,20 @@ def update(
>>> new_tape.shots
2000
"""
for k in kwargs:
if k not in ["operations", "measurements", "shots", "trainable_params"]:
raise TypeError(
f"{self.__class__}.update() got an unexpected keyword argument '{k}'"
)

ops = operations or self.operations
measurements = measurements or self.measurements
shots = shots or self.shots
trainable_params = trainable_params or self.trainable_params
ops = kwargs.get("operations") if "operations" in kwargs else self.operations
measurements = kwargs.get("measurements") if "measurements" in kwargs else self.measurements
shots = kwargs.get("shots") if "shots" in kwargs else self.shots
trainable_params = (
kwargs.get("trainable_params")
if "trainable_params" in kwargs
else self.trainable_params
)

return self.__class__(ops, measurements, shots, trainable_params)

Expand Down
133 changes: 84 additions & 49 deletions tests/tape/test_qscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,52 +1416,87 @@ def test_jax_pytree_integration(qscript_type):
assert qml.math.allclose(data[5], eye_mat)


@pytest.mark.parametrize(
"update_input",
[
{"measurements": [qml.expval(qml.X(0))]},
{"operations": [qml.X(7)]},
{"shots": 100},
{"shots": None},
{"shots": 50, "measurements": [qml.sample(wires=[2, 3])]},
{"operations": [qml.RX(1.2, 0), qml.RY(2.3, 1)], "trainable_params": [0, 1]},
],
)
@pytest.mark.parametrize("tape_class", [qml.tape.QuantumTape, QuantumScript])
def test_public_update(self, update_input, tape_class):
"""Test the public update method behaves as expected"""

initial_args = [[qml.X("b"), qml.RX(1.2, "a")]]
initial_kwargs = {
"measurements": [qml.counts()],
"shots": 2500,
"trainable_params": [1],
}

tape = tape_class(*initial_args, **initial_kwargs)

new_tape = tape.update(**update_input)

# original tape is unmodified
assert tape.operations == initial_args[0]
for kwarg, val in initial_kwargs.items():
if kwarg == "shots":
val = Shots(val)
assert getattr(tape, kwarg) == val

if "operations" in update_input:
assert new_tape.operations == update_input["operations"]
del update_input["operations"]
else:
assert new_tape.operations == tape.operations

for kwarg, val in update_input.items():
if kwarg == "shots":
val = Shots(val)
assert getattr(new_tape, kwarg) == val
del initial_kwargs[kwarg]

for kwarg, val in initial_kwargs.items():
if kwarg == "shots":
val = Shots(val)
assert getattr(new_tape, kwarg) == getattr(tape, kwarg)
class TestPublicUpdate:
"""Test the public update method"""

@pytest.mark.parametrize("shots", [50, (1000, 2000), None])
def test_public_update_shots(self, shots):
"""Test the public update method behaves as expected for setting shots"""

ops = [qml.X("b"), qml.RX(1.2, "a")]
tape = QuantumScript(ops, measurements=[qml.counts()], shots=2500, trainable_params=[1])

new_tape = tape.update(shots=shots)
assert tape.shots == Shots(2500)
assert new_tape.shots == Shots(shots)

assert new_tape.operations == tape.operations == ops
assert new_tape.measurements == tape.measurements == [qml.counts()]
assert new_tape.trainable_params == tape.trainable_params == [1]

def test_public_update_measurements(self):
"""Test the public update method behaves as expected for setting measurements"""

ops = [qml.X("b"), qml.RX(1.2, "a")]
tape = QuantumScript(ops, measurements=[qml.counts()], shots=2500, trainable_params=[1])

new_measurements = [qml.expval(qml.X(0)), qml.sample()]
new_tape = tape.update(measurements=new_measurements)

assert tape.measurements == [qml.counts()]
assert new_tape.measurements == new_measurements

assert new_tape.operations == tape.operations == ops
assert new_tape.shots == tape.shots == Shots(2500)
assert new_tape.trainable_params == tape.trainable_params == [1]

def test_public_update_operations(self):
"""Test the public update method behaves as expected for setting operations"""

ops = [qml.X("b"), qml.RX(1.2, "a")]
tape = QuantumScript(ops, measurements=[qml.counts()], shots=2500, trainable_params=[1])

new_ops = [qml.X(0)]
new_tape = tape.update(operations=new_ops)

assert tape.operations == ops
assert new_tape.operations == new_ops

assert new_tape.measurements == tape.measurements == [qml.counts()]
assert new_tape.shots == tape.shots == Shots(2500)
assert new_tape.trainable_params == tape.trainable_params == [1]

def test_public_update_trainable_params(self):
"""Test the public update method behaves as expected for setting trainable parameters"""

ops = [qml.RX(1.23, "b"), qml.RX(4.56, "a")]
tape = QuantumScript(ops, measurements=[qml.counts()], shots=2500, trainable_params=[1])

new_tape = tape.update(trainable_params=[0])

assert tape.trainable_params == [1]
assert tape.get_parameters() == [4.56]
assert new_tape.trainable_params == [0]
assert new_tape.get_parameters() == [1.23]

assert new_tape.operations == tape.operations == ops
assert new_tape.measurements == tape.measurements == [qml.counts()]
assert new_tape.shots == tape.shots == Shots(2500)

def test_public_update_bad_kwarg(self):
"""Test that an unrecognized keyword argument raises an error"""

tape = QuantumScript([qml.X(0)], [qml.counts()], shots=2500)

with pytest.raises(TypeError, match="unexpected keyword argument"):
_ = tape.update(bad_kwarg=3)

# pylint: disable = unidiomatic-typecheck
@pytest.mark.parametrize("qscript_type", (QuantumScript, qml.tape.QuantumTape))
def test_public_update_preserves_class(self, qscript_type):
"""Test that the type of the updated tape is unaltered"""

tape = qscript_type([qml.X(0)], [qml.counts()], shots=2500)
new_tape = tape.update(operations=[qml.Y(1)])

assert type(new_tape) == type(tape) == qscript_type

0 comments on commit 9eca00e

Please sign in to comment.