diff --git a/pennylane/tape/qscript.py b/pennylane/tape/qscript.py index dd2eecdc02e..6cf16c3d0fd 100644 --- a/pennylane/tape/qscript.py +++ b/pennylane/tape/qscript.py @@ -375,22 +375,23 @@ def op_wires(self) -> Wires: def update( self, - operations: Optional[Iterable[Operator]] = None, - measurements: Optional[Iterable[MeasurementProcess]] = None, - shots: Optional[ShotsLike] = None, - trainable_params: Optional[Sequence[int]] = None, + **kwargs, + # operations: Optional[Iterable[Operator]] = "unset", + # measurements: Optional[Iterable[MeasurementProcess]] = "unset", + # shots: Optional[ShotsLike] = "unset", + # trainable_params: Optional[Sequence[int]] = "unset", ): - """Update attirbutes on the tape. + r"""update(operations, measurements, shots, trainable_params) + + Update attributes on the tape. Keyword Args: operations (Iterable[Operator]): An iterable of the operations to replace on the tape. - Defaults to None. - measurements (Iterable[MeasurementProcess]): All the measurements to replace on the - tape. Defaults to None. + measurements (Iterable[MeasurementProcess]): An iterable of all the measurements to + replace on the tape. shots (None, int, Sequence[int], ~.Shots): Number and/or batches of shots to replace - on the tape. Defaults to None. + on the tape. trainable_params (None, Sequence[int]): the indices for which parameters are trainable. - Defaults to None. Returns: A new tape instance, initialized with any kwargs passed to update. Anything not set in update will be the same as the initial tape. @@ -417,11 +418,20 @@ def update( >>> new_tape.shots 2000 """ + for k in kwargs: + if k not in ["operations", "measurements", "shots", "trainable_params"]: + raise TypeError( + f"{self.__class__}.update() got an unexpected keyword argument '{k}'" + ) - ops = operations or self.operations - measurements = measurements or self.measurements - shots = shots or self.shots - trainable_params = trainable_params or self.trainable_params + ops = kwargs.get("operations") if "operations" in kwargs else self.operations + measurements = kwargs.get("measurements") if "measurements" in kwargs else self.measurements + shots = kwargs.get("shots") if "shots" in kwargs else self.shots + trainable_params = ( + kwargs.get("trainable_params") + if "trainable_params" in kwargs + else self.trainable_params + ) return self.__class__(ops, measurements, shots, trainable_params) diff --git a/tests/tape/test_qscript.py b/tests/tape/test_qscript.py index 3a88aedf80e..f4c0e1ac3b1 100644 --- a/tests/tape/test_qscript.py +++ b/tests/tape/test_qscript.py @@ -1416,52 +1416,87 @@ def test_jax_pytree_integration(qscript_type): assert qml.math.allclose(data[5], eye_mat) -@pytest.mark.parametrize( - "update_input", - [ - {"measurements": [qml.expval(qml.X(0))]}, - {"operations": [qml.X(7)]}, - {"shots": 100}, - {"shots": None}, - {"shots": 50, "measurements": [qml.sample(wires=[2, 3])]}, - {"operations": [qml.RX(1.2, 0), qml.RY(2.3, 1)], "trainable_params": [0, 1]}, - ], -) -@pytest.mark.parametrize("tape_class", [qml.tape.QuantumTape, QuantumScript]) -def test_public_update(self, update_input, tape_class): - """Test the public update method behaves as expected""" - - initial_args = [[qml.X("b"), qml.RX(1.2, "a")]] - initial_kwargs = { - "measurements": [qml.counts()], - "shots": 2500, - "trainable_params": [1], - } - - tape = tape_class(*initial_args, **initial_kwargs) - - new_tape = tape.update(**update_input) - - # original tape is unmodified - assert tape.operations == initial_args[0] - for kwarg, val in initial_kwargs.items(): - if kwarg == "shots": - val = Shots(val) - assert getattr(tape, kwarg) == val - - if "operations" in update_input: - assert new_tape.operations == update_input["operations"] - del update_input["operations"] - else: - assert new_tape.operations == tape.operations - - for kwarg, val in update_input.items(): - if kwarg == "shots": - val = Shots(val) - assert getattr(new_tape, kwarg) == val - del initial_kwargs[kwarg] - - for kwarg, val in initial_kwargs.items(): - if kwarg == "shots": - val = Shots(val) - assert getattr(new_tape, kwarg) == getattr(tape, kwarg) +class TestPublicUpdate: + """Test the public update method""" + + @pytest.mark.parametrize("shots", [50, (1000, 2000), None]) + def test_public_update_shots(self, shots): + """Test the public update method behaves as expected for setting shots""" + + ops = [qml.X("b"), qml.RX(1.2, "a")] + tape = QuantumScript(ops, measurements=[qml.counts()], shots=2500, trainable_params=[1]) + + new_tape = tape.update(shots=shots) + assert tape.shots == Shots(2500) + assert new_tape.shots == Shots(shots) + + assert new_tape.operations == tape.operations == ops + assert new_tape.measurements == tape.measurements == [qml.counts()] + assert new_tape.trainable_params == tape.trainable_params == [1] + + def test_public_update_measurements(self): + """Test the public update method behaves as expected for setting measurements""" + + ops = [qml.X("b"), qml.RX(1.2, "a")] + tape = QuantumScript(ops, measurements=[qml.counts()], shots=2500, trainable_params=[1]) + + new_measurements = [qml.expval(qml.X(0)), qml.sample()] + new_tape = tape.update(measurements=new_measurements) + + assert tape.measurements == [qml.counts()] + assert new_tape.measurements == new_measurements + + assert new_tape.operations == tape.operations == ops + assert new_tape.shots == tape.shots == Shots(2500) + assert new_tape.trainable_params == tape.trainable_params == [1] + + def test_public_update_operations(self): + """Test the public update method behaves as expected for setting operations""" + + ops = [qml.X("b"), qml.RX(1.2, "a")] + tape = QuantumScript(ops, measurements=[qml.counts()], shots=2500, trainable_params=[1]) + + new_ops = [qml.X(0)] + new_tape = tape.update(operations=new_ops) + + assert tape.operations == ops + assert new_tape.operations == new_ops + + assert new_tape.measurements == tape.measurements == [qml.counts()] + assert new_tape.shots == tape.shots == Shots(2500) + assert new_tape.trainable_params == tape.trainable_params == [1] + + def test_public_update_trainable_params(self): + """Test the public update method behaves as expected for setting trainable parameters""" + + ops = [qml.RX(1.23, "b"), qml.RX(4.56, "a")] + tape = QuantumScript(ops, measurements=[qml.counts()], shots=2500, trainable_params=[1]) + + new_tape = tape.update(trainable_params=[0]) + + assert tape.trainable_params == [1] + assert tape.get_parameters() == [4.56] + assert new_tape.trainable_params == [0] + assert new_tape.get_parameters() == [1.23] + + assert new_tape.operations == tape.operations == ops + assert new_tape.measurements == tape.measurements == [qml.counts()] + assert new_tape.shots == tape.shots == Shots(2500) + + def test_public_update_bad_kwarg(self): + """Test that an unrecognized keyword argument raises an error""" + + tape = QuantumScript([qml.X(0)], [qml.counts()], shots=2500) + + with pytest.raises(TypeError, match="unexpected keyword argument"): + _ = tape.update(bad_kwarg=3) + + # pylint: disable = unidiomatic-typecheck + @pytest.mark.parametrize("qscript_type", (QuantumScript, qml.tape.QuantumTape)) + def test_public_update_preserves_class(self, qscript_type): + """Test that the type of the updated tape is unaltered""" + + tape = qscript_type([qml.X(0)], [qml.counts()], shots=2500) + new_tape = tape.update(operations=[qml.Y(1)]) + + assert type(new_tape) == type(tape) == qscript_type