From 705a3ced42a3463f29672410576206d18a6b0e53 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Wed, 5 Jun 2024 13:39:59 -0400 Subject: [PATCH 01/36] Mark controlled sequence as having no grad method --- .../subroutines/controlled_sequence.py | 2 + .../test_controlled_sequence.py | 45 ++++++++++++++----- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/pennylane/templates/subroutines/controlled_sequence.py b/pennylane/templates/subroutines/controlled_sequence.py index e9efad65281..517129556cc 100644 --- a/pennylane/templates/subroutines/controlled_sequence.py +++ b/pennylane/templates/subroutines/controlled_sequence.py @@ -69,6 +69,8 @@ def circuit(): """ + grad_method = None + def _flatten(self): return (self.base,), (self.control,) diff --git a/tests/templates/test_subroutines/test_controlled_sequence.py b/tests/templates/test_subroutines/test_controlled_sequence.py index 6286af4e577..f7dc9c8a233 100644 --- a/tests/templates/test_subroutines/test_controlled_sequence.py +++ b/tests/templates/test_subroutines/test_controlled_sequence.py @@ -31,6 +31,7 @@ def test_standard_validity(): class TestInitialization: + def test_id(self): """Tests that the id attribute can be set.""" op = qml.ControlledSequence(qml.RX(0.25, wires=3), control=[0, 1, 2], id="a") @@ -57,6 +58,7 @@ def test_name(self): class TestProperties: + def test_hash(self): """Test that op.hash uniquely describes a ControlledSequence""" @@ -97,6 +99,7 @@ def test_has_matrix(self): class TestMethods: + def test_repr(self): """Test that the operator repr is as expected""" op = qml.ControlledSequence(qml.RX(0.25, wires=3), control=[0, 1, 2]) @@ -189,28 +192,41 @@ def test_qnode_numpy(self): assert np.allclose(res, self.exp_result, atol=0.002) @pytest.mark.autograd - def test_qnode_autograd(self): + @pytest.mark.parametrize("shots", [None, 50000]) + @pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"]) + def test_qnode_autograd(self, shots, device): """Test that the QNode executes with Autograd.""" - dev = qml.device("default.qubit") - qnode = qml.QNode(self.circuit, dev, interface="autograd") - + dev = qml.device(device, wires=4, shots=shots) + diff_method = "backprop" if shots is None else "parameter-shift" + qnode = qml.QNode(self.circuit, dev, interface="autograd", diff_method=diff_method) x = qml.numpy.array(self.x, requires_grad=True) + res = qnode(x) assert qml.math.shape(res) == (16,) assert np.allclose(res, self.exp_result, atol=0.002) + res = qml.jacobian(qnode)(x) + assert np.shape(res) == (16,) + assert np.allclose(res, self.exp_jac, atol=0.005) + @pytest.mark.jax @pytest.mark.parametrize("use_jit", [False, True]) - @pytest.mark.parametrize("shots", [None, 10000]) - def test_qnode_jax(self, shots, use_jit): + @pytest.mark.parametrize("shots", [None, 50000]) + @pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"]) + def test_qnode_jax(self, shots, use_jit, device): """Test that the QNode executes and is differentiable with JAX. The shots argument controls whether autodiff or parameter-shift gradients are used.""" + import jax jax.config.update("jax_enable_x64", True) - dev = qml.device("default.qubit", shots=shots, seed=10) + if device == "default.qubit": + dev = qml.device("default.qubit", shots=shots, seed=10) + else: + dev = qml.device("default.qubit.legacy", shots=shots, wires=4) + diff_method = "backprop" if shots is None else "parameter-shift" qnode = qml.QNode(self.circuit, dev, interface="jax", diff_method=diff_method) if use_jit: @@ -230,13 +246,19 @@ def test_qnode_jax(self, shots, use_jit): assert np.allclose(jac, self.exp_jac, atol=0.006) @pytest.mark.torch - @pytest.mark.parametrize("shots", [None, 10000]) - def test_qnode_torch(self, shots): + @pytest.mark.parametrize("shots", [None, 50000]) + @pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"]) + def test_qnode_torch(self, shots, device): """Test that the QNode executes and is differentiable with Torch. The shots argument controls whether autodiff or parameter-shift gradients are used.""" + import torch - dev = qml.device("default.qubit", shots=shots, seed=10) + if device == "default.qubit": + dev = qml.device("default.qubit", shots=shots, seed=10) + else: + dev = qml.device("default.qubit.legacy", shots=shots, wires=4) + diff_method = "backprop" if shots is None else "parameter-shift" qnode = qml.QNode(self.circuit, dev, interface="torch", diff_method=diff_method) @@ -247,7 +269,7 @@ def test_qnode_torch(self, shots): jac = torch.autograd.functional.jacobian(qnode, x) assert qml.math.shape(jac) == (16,) - assert qml.math.allclose(jac, self.exp_jac, atol=0.006) + assert qml.math.allclose(jac, self.exp_jac, atol=0.005) @pytest.mark.tf @pytest.mark.parametrize("shots", [None, 10000]) @@ -255,6 +277,7 @@ def test_qnode_torch(self, shots): def test_qnode_tf(self, shots): """Test that the QNode executes and is differentiable with TensorFlow. The shots argument controls whether autodiff or parameter-shift gradients are used.""" + import tensorflow as tf dev = qml.device("default.qubit", shots=shots, seed=10) From b7537d562f93882d5da0f32e6bd063ec009938fa Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Wed, 5 Jun 2024 14:31:01 -0400 Subject: [PATCH 02/36] Mark `Reflection` as having no grad method --- pennylane/templates/subroutines/reflection.py | 2 + .../test_subroutines/test_reflection.py | 40 ++++++++++++++----- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/pennylane/templates/subroutines/reflection.py b/pennylane/templates/subroutines/reflection.py index a39cf7cb3af..f67a0c435b6 100644 --- a/pennylane/templates/subroutines/reflection.py +++ b/pennylane/templates/subroutines/reflection.py @@ -104,6 +104,8 @@ def circuit(): """ + grad_method = None + @classmethod def _primitive_bind_call(cls, *args, **kwargs): return cls._primitive.bind(*args, **kwargs) diff --git a/tests/templates/test_subroutines/test_reflection.py b/tests/templates/test_subroutines/test_reflection.py index 417c3d715d0..5d7cefb729c 100644 --- a/tests/templates/test_subroutines/test_reflection.py +++ b/tests/templates/test_subroutines/test_reflection.py @@ -163,28 +163,40 @@ def test_lightning_qubit(self): assert np.allclose(res, self.exp_result, atol=0.002) @pytest.mark.autograd - def test_qnode_autograd(self): + @pytest.mark.parametrize("shots", [None, 50000]) + @pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"]) + def test_qnode_autograd(self, shots, device): """Test that the QNode executes with Autograd.""" - dev = qml.device("default.qubit") - qnode = qml.QNode(self.circuit, dev, interface="autograd") + dev = qml.device(device, shots=shots, wires=3) + diff_method = "backprop" if shots is None else "parameter-shift" + qnode = qml.QNode(self.circuit, dev, interface="autograd", diff_method=diff_method) x = qml.numpy.array(self.x, requires_grad=True) res = qnode(x) assert qml.math.shape(res) == (8,) - assert np.allclose(res, self.exp_result, atol=0.002) + assert np.allclose(res, self.exp_result, atol=0.005) + + res = qml.jacobian(qnode)(x) + assert np.shape(res) == (8,) + assert np.allclose(res, self.exp_jac, atol=0.005) @pytest.mark.jax @pytest.mark.parametrize("use_jit", [False, True]) @pytest.mark.parametrize("shots", [None, 50000]) - def test_qnode_jax(self, shots, use_jit): + @pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"]) + def test_qnode_jax(self, shots, use_jit, device): """Test that the QNode executes and is differentiable with JAX. The shots argument controls whether autodiff or parameter-shift gradients are used.""" import jax jax.config.update("jax_enable_x64", True) - dev = qml.device("default.qubit", shots=shots, seed=10) + if device == "default.qubit": + dev = qml.device("default.qubit", shots=shots, seed=10) + else: + dev = qml.device("default.qubit.legacy", shots=shots, wires=3) + diff_method = "backprop" if shots is None else "parameter-shift" qnode = qml.QNode(self.circuit, dev, interface="jax", diff_method=diff_method) if use_jit: @@ -201,27 +213,33 @@ def test_qnode_jax(self, shots, use_jit): jac = jac_fn(x) assert jac.shape == (8,) - assert np.allclose(jac, self.exp_jac, atol=0.006) + assert np.allclose(jac, self.exp_jac, atol=0.005) @pytest.mark.torch @pytest.mark.parametrize("shots", [None, 50000]) - def test_qnode_torch(self, shots): + @pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"]) + def test_qnode_torch(self, shots, device): """Test that the QNode executes and is differentiable with Torch. The shots argument controls whether autodiff or parameter-shift gradients are used.""" + import torch - dev = qml.device("default.qubit", shots=shots, seed=10) + if device == "default.qubit": + dev = qml.device("default.qubit", shots=shots, seed=10) + else: + dev = qml.device("default.qubit.legacy", shots=shots, wires=3) + diff_method = "backprop" if shots is None else "parameter-shift" qnode = qml.QNode(self.circuit, dev, interface="torch", diff_method=diff_method) x = torch.tensor(self.x, requires_grad=True) res = qnode(x) assert qml.math.shape(res) == (8,) - assert qml.math.allclose(res, self.exp_result, atol=0.002) + assert qml.math.allclose(res, self.exp_result, atol=0.005) jac = torch.autograd.functional.jacobian(qnode, x) assert qml.math.shape(jac) == (8,) - assert qml.math.allclose(jac, self.exp_jac, atol=0.006) + assert qml.math.allclose(jac, self.exp_jac, atol=0.005) @pytest.mark.tf @pytest.mark.parametrize("shots", [None, 50000]) From 579f9118666e1a302a926af2f27ce1f92687fb8d Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Wed, 5 Jun 2024 15:13:32 -0400 Subject: [PATCH 03/36] Mark amplitude amplification as having no grad method --- .../subroutines/amplitude_amplification.py | 6 ++-- pennylane/templates/subroutines/reflection.py | 2 +- .../test_amplitude_amplification.py | 31 +++++++++++++------ 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/pennylane/templates/subroutines/amplitude_amplification.py b/pennylane/templates/subroutines/amplitude_amplification.py index 08e050ff73f..fd89816b984 100644 --- a/pennylane/templates/subroutines/amplitude_amplification.py +++ b/pennylane/templates/subroutines/amplitude_amplification.py @@ -102,6 +102,8 @@ def circuit(): [0.013, 0.013, 0.91, 0.013, 0.013, 0.013, 0.013, 0.013] """ + grad_method = None + def _flatten(self): data = (self.hyperparameters["U"], self.hyperparameters["O"]) metadata = tuple(item for item in self.hyperparameters.items() if item[0] not in ["O", "U"]) @@ -141,11 +143,11 @@ def __init__( self.hyperparameters["p_min"] = p_min self.hyperparameters["reflection_wires"] = qml.wires.Wires(reflection_wires) - super().__init__(wires=wires) + super().__init__(*U.data, *O.data, wires=wires) # pylint:disable=arguments-differ @staticmethod - def compute_decomposition(**kwargs): + def compute_decomposition(*_, **kwargs): U = kwargs["U"] O = kwargs["O"] iters = kwargs["iters"] diff --git a/pennylane/templates/subroutines/reflection.py b/pennylane/templates/subroutines/reflection.py index f67a0c435b6..bad3316b3fe 100644 --- a/pennylane/templates/subroutines/reflection.py +++ b/pennylane/templates/subroutines/reflection.py @@ -138,7 +138,7 @@ def __init__(self, U, alpha=np.pi, reflection_wires=None, id=None): "reflection_wires": tuple(reflection_wires), } - super().__init__(alpha, wires=wires, id=id) + super().__init__(alpha, *U.data, wires=wires, id=id) def map_wires(self, wire_map: dict): # pylint: disable=protected-access diff --git a/tests/templates/test_subroutines/test_amplitude_amplification.py b/tests/templates/test_subroutines/test_amplitude_amplification.py index 487d0f3dbe4..9dff939ee6d 100644 --- a/tests/templates/test_subroutines/test_amplitude_amplification.py +++ b/tests/templates/test_subroutines/test_amplitude_amplification.py @@ -145,7 +145,7 @@ def circuit(params): qml.RZ(params[1], wires=0), iters=3, fixed_point=True, - work_wire=3, + work_wire=2, ) return qml.expval(qml.PauliZ(0)) @@ -156,28 +156,36 @@ def circuit(params): params = np.array([0.9, 0.1]) @pytest.mark.autograd - def test_qnode_autograd(self): + @pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"]) + @pytest.mark.parametrize("shots", [None, 50000]) + def test_qnode_autograd(self, device, shots): """Test that the QNode executes with Autograd.""" - dev = qml.device("default.qubit") - qnode = qml.QNode(self.circuit, dev, interface="autograd") + dev = qml.device(device, wires=3, shots=shots) + diff_method = "backprop" if shots is None else "parameter-shift" + qnode = qml.QNode(self.circuit, dev, interface="autograd", diff_method=diff_method) params = qml.numpy.array(self.params, requires_grad=True) res = qml.grad(qnode)(params) assert qml.math.shape(res) == (2,) - assert np.allclose(res, self.exp_grad, atol=1e-5) + assert np.allclose(res, self.exp_grad, atol=0.01) @pytest.mark.jax @pytest.mark.parametrize("use_jit", [False, True]) @pytest.mark.parametrize("shots", [None, 50000]) - def test_qnode_jax(self, shots, use_jit): + @pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"]) + def test_qnode_jax(self, shots, use_jit, device): """Test that the QNode executes and is differentiable with JAX. The shots argument controls whether autodiff or parameter-shift gradients are used.""" import jax jax.config.update("jax_enable_x64", True) - dev = qml.device("default.qubit", shots=shots, seed=10) + if device == "default.qubit": + dev = qml.device("default.qubit", shots=shots, seed=10) + else: + dev = qml.device("default.qubit.legacy", shots=shots, wires=3) + diff_method = "backprop" if shots is None else "parameter-shift" qnode = qml.QNode(self.circuit, dev, interface="jax", diff_method=diff_method) if use_jit: @@ -195,12 +203,17 @@ def test_qnode_jax(self, shots, use_jit): @pytest.mark.torch @pytest.mark.parametrize("shots", [None, 50000]) - def test_qnode_torch(self, shots): + @pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"]) + def test_qnode_torch(self, shots, device): """Test that the QNode executes and is differentiable with Torch. The shots argument controls whether autodiff or parameter-shift gradients are used.""" import torch - dev = qml.device("default.qubit", shots=shots, seed=10) + if device == "default.qubit": + dev = qml.device("default.qubit", shots=shots, seed=10) + else: + dev = qml.device("default.qubit.legacy", shots=shots, wires=3) + diff_method = "backprop" if shots is None else "parameter-shift" qnode = qml.QNode(self.circuit, dev, interface="torch", diff_method=diff_method) From 237226e39f6c055957681bcb33b2380fd5ba05d1 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Wed, 5 Jun 2024 15:46:14 -0400 Subject: [PATCH 04/36] Mark Qubitization as having no grad_method --- pennylane/templates/subroutines/qubitization.py | 4 +++- tests/templates/test_subroutines/test_qubitization.py | 11 ++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/pennylane/templates/subroutines/qubitization.py b/pennylane/templates/subroutines/qubitization.py index ae45487404a..5847d64c652 100644 --- a/pennylane/templates/subroutines/qubitization.py +++ b/pennylane/templates/subroutines/qubitization.py @@ -92,6 +92,8 @@ def circuit(): eigenvalue: 0.7 """ + grad_method = None + @classmethod def _primitive_bind_call(cls, *args, **kwargs): return cls._primitive.bind(*args, **kwargs) @@ -104,7 +106,7 @@ def __init__(self, hamiltonian, control, id=None): "control": qml.wires.Wires(control), } - super().__init__(wires=wires, id=id) + super().__init__(*hamiltonian.data, wires=wires, id=id) def _flatten(self): data = (self.hyperparameters["hamiltonian"],) diff --git a/tests/templates/test_subroutines/test_qubitization.py b/tests/templates/test_subroutines/test_qubitization.py index 430382f011e..124445f0ee8 100644 --- a/tests/templates/test_subroutines/test_qubitization.py +++ b/tests/templates/test_subroutines/test_qubitization.py @@ -227,14 +227,19 @@ def test_qnode_autograd(self): "use_jit , shots", ((False, None), (True, None), (False, 50000)), ) # TODO: (True, 50000) fails because jax.jit on jax.grad does not work with AmplitudeEmbedding - def test_qnode_jax(self, shots, use_jit): + @pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"]) + def test_qnode_jax(self, shots, use_jit, device): """ "Test that the QNode executes and is differentiable with JAX. The shots argument controls whether autodiff or parameter-shift gradients are used.""" import jax jax.config.update("jax_enable_x64", True) - dev = qml.device("default.qubit", shots=shots, seed=10) + if device == "default.qubit": + dev = qml.device("default.qubit", shots=shots, seed=10) + else: + dev = qml.device("default.qubit.legacy", shots=shots, wires=5) + diff_method = "backprop" if shots is None else "parameter-shift" qnode = qml.QNode(self.circuit, dev, interface="jax", diff_method=diff_method) if use_jit: @@ -248,7 +253,7 @@ def test_qnode_jax(self, shots, use_jit): jac = jac_fn(params) assert jac.shape == (4,) - assert np.allclose(jac, self.exp_grad, atol=0.01) + assert np.allclose(jac, self.exp_grad, atol=0.05) @pytest.mark.torch @pytest.mark.parametrize( From 7cd3c02977b4c6abc8923563618347a741382d56 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Mon, 3 Jun 2024 17:04:34 -0400 Subject: [PATCH 05/36] Add device preprocessing to inner transform --- pennylane/workflow/execution.py | 46 ++++++++++++++++++++++++--------- pennylane/workflow/qnode.py | 17 ++++++++---- 2 files changed, 46 insertions(+), 17 deletions(-) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index f27d100e752..2bb78270c9e 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -253,7 +253,13 @@ def device_expansion_function(tape): # pylint: disable=function-redefined def _make_inner_execute( - device, override_shots, cache, expand_fn=None, execution_config=None, numpy_only=True + device, + override_shots, + cache, + inner_transform, + expand_fn=None, + execution_config=None, + numpy_only=True, ) -> Callable: """Construct the function that will execute the tapes inside the ml framework registration for the 1st order derivatives. @@ -282,28 +288,29 @@ def _make_inner_execute( else: device_execution = partial(device.execute, execution_config=execution_config) + transform_program = qml.transforms.core.TransformProgram(inner_transform) + def inner_execute(tapes: Sequence[QuantumTape], **_) -> ResultBatch: """Execution that occurs within a machine learning framework boundary. Closure Variables: expand_fn (Callable[[QuantumTape], QuantumTape]): A device preprocessing step - numpy_only (bool): whether or not to convert the data to numpy or leave as is + numpy_only (bool): whether to convert the data to numpy or leave as is device_execution (Callable[[Sequence[QuantumTape]], ResultBatch]) cache (None | MutableMapping): The cache to use. If ``None``, caching will not occur. """ - transform_program = qml.transforms.core.TransformProgram() if numpy_only: - transform_program.add_transform(qml.transforms.convert_to_numpy_parameters) + transform_program.insert_front_transform(qml.transforms.convert_to_numpy_parameters) if cache is not None: transform_program.add_transform(_cache_transform, cache=cache) + transformed_tapes, transform_post_processing = transform_program(tapes) + # TODO: Apply expand_fn() as transform. if expand_fn: - tapes = tuple(expand_fn(t) for t in tapes) - - transformed_tapes, transform_post_processing = transform_program(tapes) + transformed_tapes = tuple(expand_fn(t) for t in transformed_tapes) if transformed_tapes: results = device_execution(transformed_tapes) @@ -407,6 +414,7 @@ def execute( gradient_fn: Optional[Union[Callable, str]] = None, interface="auto", transform_program=None, + inner_transform=None, config=None, grad_on_execution="best", gradient_kwargs=None, @@ -435,6 +443,7 @@ def execute( This affects the types of parameters that can exist on the input tapes. Available options include ``autograd``, ``torch``, ``tf``, ``jax`` and ``auto``. transform_program(.TransformProgram): A transform program to be applied to the initial tape. + inner_transform (.TransformProgram): A transform program to be applied to the tapes in the inner_execute. config (qml.devices.ExecutionConfig): A datastructure describing the parameters needed to fully describe the execution. grad_on_execution (bool, str): Whether the gradients should be computed on the execution or not. Only applies if the device is queried for the gradient; gradient transform @@ -584,11 +593,21 @@ def cost_fn(params, x): raise ValueError("Using postselect_mode='hw-like' is not supported with jax-jit.") config.mcm_config.postselect_mode = "fill-shots" - if transform_program is None: - if isinstance(device, qml.devices.Device): - transform_program = device.preprocess(config)[0] + if isinstance(device, qml.devices.Device): + + # If gradient_fn is a gradient transform, device preprocessing should happen in + # inner execute (inside the ml boundary). + if isinstance(gradient_fn, qml.transforms.core.TransformDispatcher): + inner_transform = inner_transform or device.preprocess(config)[0] + transform_program = transform_program or qml.transforms.core.TransformProgram() else: - transform_program = qml.transforms.core.TransformProgram() + inner_transform = inner_transform or qml.transforms.core.TransformProgram() + transform_program = transform_program or device.preprocess(config)[0] + + else: + + transform_program = transform_program or qml.transforms.core.TransformProgram() + inner_transform = inner_transform or qml.transforms.core.TransformProgram() # If caching is desired but an explicit cache is not provided, use an ``LRUCache``. if cache is True: @@ -614,6 +633,7 @@ def cost_fn(params, x): device, override_shots, cache, + inner_transform, expand_fn, config, numpy_only=not device_supports_interface_data, @@ -751,7 +771,9 @@ def device_execute_and_gradients(internal_tapes, **gradient_kwargs): else: # need to override to have no cache - inner_execute = _make_inner_execute(device, override_shots, cache=None) + inner_execute = _make_inner_execute( + device, override_shots, cache=None, inner_transform=inner_transform + ) def inner_execute_with_empty_jac(tapes, **_): return (inner_execute(tapes), []) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index e0fd75458c4..ca44d873d5f 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -1046,14 +1046,19 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml "Cannot use the 'one-shot' method for mid-circuit measurements with analytic mode." ) - # Add the device program to the QNode program + full_transform_program = qml.transforms.core.TransformProgram(self.transform_program) + inner_transform_program = qml.transforms.core.TransformProgram() + config = None + if isinstance(self.device, qml.devices.Device): + config = _make_execution_config(self, self.gradient_fn) device_transform_program, config = self.device.preprocess(execution_config=config) - full_transform_program = self.transform_program + device_transform_program - else: - config = None - full_transform_program = qml.transforms.core.TransformProgram(self.transform_program) + + if config.use_device_gradient: + full_transform_program += device_transform_program + else: + inner_transform_program += device_transform_program has_mcm_support = ( any(isinstance(op, MidMeasureMP) for op in self._tape) @@ -1079,6 +1084,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml qml.transform(self.gradient_fn.expand_transform), **self.gradient_kwargs, ) + # Calculate the classical jacobians if necessary full_transform_program.set_classical_component(self, args, kwargs) full_transform_program.prune_dynamic_transform() @@ -1090,6 +1096,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml gradient_fn=self.gradient_fn, interface=self.interface, transform_program=full_transform_program, + inner_transform=inner_transform_program, config=config, gradient_kwargs=self.gradient_kwargs, override_shots=override_shots, From c10ca0b3a6744bb7d37fd854fb0b5860c785dfcd Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 4 Jun 2024 09:59:20 -0400 Subject: [PATCH 06/36] fix bug with mcm --- .../transforms/core/transform_program.py | 57 ++++++++++++------- pennylane/workflow/qnode.py | 3 +- 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/pennylane/transforms/core/transform_program.py b/pennylane/transforms/core/transform_program.py index 31eac21e420..3dd66b76197 100644 --- a/pennylane/transforms/core/transform_program.py +++ b/pennylane/transforms/core/transform_program.py @@ -354,25 +354,6 @@ def set_classical_component(self, qnode, args, kwargs): self._set_all_classical_jacobians(qnode, args, kwargs, argnums) self._set_all_argnums(qnode, args, kwargs, argnums) - def prune_dynamic_transform(self): - """Ensure a single ``dynamic_one_shot`` transform is applied.""" - trans_type = np.zeros(len(self._transform_program), dtype=np.int32) - for i, t in enumerate(self._transform_program): - if "dynamic_one_shot" in str(t): - trans_type[i] = 1 - if "mid_circuit_measurements" in str(t): - trans_type[i] = 2 - if sum(trans_type) < 2: - return - keep = 2 if 2 in trans_type else 1 - found = False - for i, ttype in enumerate(reversed(trans_type)): - if not found and ttype == keep: - found = True - continue - if found and ttype in [1, 2]: - self._transform_program.pop(len(self._transform_program) - 1 - i) - def _set_all_classical_jacobians( self, qnode, args, kwargs, argnums ): # pylint: disable=too-many-statements @@ -545,3 +526,41 @@ def __call__(self, tapes: Tuple[QuantumTape]) -> Tuple[ResultBatch, BatchPostPro # Reset classical jacobians self._classical_jacobians = [] return tuple(tapes), postprocessing_fn + + +# pylint: disable=protected-access +def prune_dynamic_transform(outer_transform, inner_transform): + """Ensure a single ``dynamic_one_shot`` transform is applied.""" + + trans_type_inner = [0 for _ in inner_transform] + trans_type_outer = [0 for _ in outer_transform] + + for i, t in enumerate(outer_transform): + if "dynamic_one_shot" in str(t): + trans_type_outer[i] = 1 + if "mid_circuit_measurements" in str(t): + trans_type_outer[i] = 2 + + for i, t in enumerate(inner_transform): + if "dynamic_one_shot" in str(t): + trans_type_inner[i] = 1 + if "mid_circuit_measurements" in str(t): + trans_type_inner[i] = 2 + + if sum(trans_type_inner) + sum(trans_type_outer) < 2: + return + + keep = 2 if 2 in trans_type_inner + trans_type_outer else 1 + found = False + for i, ttype in enumerate(reversed(trans_type_inner)): + if not found and ttype == keep: + found = True + continue + if found and ttype in [1, 2]: + inner_transform._transform_program.pop(len(inner_transform) - 1 - i) + for i, ttype in enumerate(reversed(trans_type_outer)): + if not found and ttype == keep: + found = True + continue + if found and ttype in [1, 2]: + outer_transform._transform_program.pop(len(outer_transform) - 1 - i) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index ca44d873d5f..3273da8a114 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -28,6 +28,7 @@ from pennylane.logging import debug_logger from pennylane.measurements import CountsMP, MidMeasureMP, Shots from pennylane.tape import QuantumScript, QuantumTape +from pennylane.transforms.core.transform_program import prune_dynamic_transform from .execution import INTERFACE_MAP, SUPPORTED_INTERFACES @@ -1087,7 +1088,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml # Calculate the classical jacobians if necessary full_transform_program.set_classical_component(self, args, kwargs) - full_transform_program.prune_dynamic_transform() + prune_dynamic_transform(full_transform_program, inner_transform_program) # pylint: disable=unexpected-keyword-arg res = qml.execute( From d931d53ffbaa765e598c6cb35e975226592c5a61 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 4 Jun 2024 10:38:03 -0400 Subject: [PATCH 07/36] make pylint happy --- .../transforms/core/transform_program.py | 71 ++++++++++--------- pennylane/workflow/qnode.py | 4 +- 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/pennylane/transforms/core/transform_program.py b/pennylane/transforms/core/transform_program.py index 3dd66b76197..71b1f98cf17 100644 --- a/pennylane/transforms/core/transform_program.py +++ b/pennylane/transforms/core/transform_program.py @@ -17,8 +17,6 @@ from functools import partial from typing import Callable, List, Optional, Sequence, Tuple, Union -import numpy as np - import pennylane as qml from pennylane.tape import QuantumTape from pennylane.typing import Result, ResultBatch @@ -354,6 +352,33 @@ def set_classical_component(self, qnode, args, kwargs): self._set_all_classical_jacobians(qnode, args, kwargs, argnums) self._set_all_argnums(qnode, args, kwargs, argnums) + def _prune_dynamic_transform(self, type_to_keep=1): + """Ensures that only one or none ``dynamic_one_shot`` is applied. + + Args: + type_to_keep (int): The type of the dynamic transform to keep. 0: keep none, + 1: dynamic_one_shot or mid_circuit_measurements, 2: only mid_circuit_measurements. + + Returns: + bool: ``True`` if a dynamic transform was found, ``False`` otherwise. + + """ + + i = len(self._transform_program) - 1 + found = False + while i >= 0: + t = self._transform_program[i] + if "mid_circuit_measurements" in str(t) and type_to_keep > 0: + type_to_keep = 0 # keep this and do not keep the rest + found = True + elif "dynamic_one_shot" in str(t) and type_to_keep == 1: + type_to_keep = 0 # keep this and do not keep the rest + found = True + elif "dynamic_one_shot" in str(t) or "mid_circuit_measurements" in str(t): + self._transform_program.pop(i) + i -= 1 + return found + def _set_all_classical_jacobians( self, qnode, args, kwargs, argnums ): # pylint: disable=too-many-statements @@ -529,38 +554,20 @@ def __call__(self, tapes: Tuple[QuantumTape]) -> Tuple[ResultBatch, BatchPostPro # pylint: disable=protected-access -def prune_dynamic_transform(outer_transform, inner_transform): +def _prune_dynamic_transform(outer_transform, inner_transform): """Ensure a single ``dynamic_one_shot`` transform is applied.""" - trans_type_inner = [0 for _ in inner_transform] - trans_type_outer = [0 for _ in outer_transform] - - for i, t in enumerate(outer_transform): - if "dynamic_one_shot" in str(t): - trans_type_outer[i] = 1 - if "mid_circuit_measurements" in str(t): - trans_type_outer[i] = 2 - - for i, t in enumerate(inner_transform): - if "dynamic_one_shot" in str(t): - trans_type_inner[i] = 1 - if "mid_circuit_measurements" in str(t): - trans_type_inner[i] = 2 + all_transforms = outer_transform._transform_program + inner_transform._transform_program + type_to_keep = 0 + if any("mid_circuit_measurements" in str(t) for t in all_transforms): + type_to_keep = 2 + elif any("dynamic_one_shot" in str(t) for t in all_transforms): + type_to_keep = 1 - if sum(trans_type_inner) + sum(trans_type_outer) < 2: + if type_to_keep == 0: return - keep = 2 if 2 in trans_type_inner + trans_type_outer else 1 - found = False - for i, ttype in enumerate(reversed(trans_type_inner)): - if not found and ttype == keep: - found = True - continue - if found and ttype in [1, 2]: - inner_transform._transform_program.pop(len(inner_transform) - 1 - i) - for i, ttype in enumerate(reversed(trans_type_outer)): - if not found and ttype == keep: - found = True - continue - if found and ttype in [1, 2]: - outer_transform._transform_program.pop(len(outer_transform) - 1 - i) + dynamic_transform_found = inner_transform._prune_dynamic_transform(type_to_keep) + if dynamic_transform_found: + type_to_keep = 0 + outer_transform._prune_dynamic_transform(type_to_keep) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 3273da8a114..2bb363a8b70 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -28,7 +28,7 @@ from pennylane.logging import debug_logger from pennylane.measurements import CountsMP, MidMeasureMP, Shots from pennylane.tape import QuantumScript, QuantumTape -from pennylane.transforms.core.transform_program import prune_dynamic_transform +from pennylane.transforms.core.transform_program import _prune_dynamic_transform from .execution import INTERFACE_MAP, SUPPORTED_INTERFACES @@ -1088,7 +1088,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml # Calculate the classical jacobians if necessary full_transform_program.set_classical_component(self, args, kwargs) - prune_dynamic_transform(full_transform_program, inner_transform_program) + _prune_dynamic_transform(full_transform_program, inner_transform_program) # pylint: disable=unexpected-keyword-arg res = qml.execute( From 7451e3618b4e3a9036018717a6d4991550214090 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 4 Jun 2024 10:55:54 -0400 Subject: [PATCH 08/36] fix bug with cache --- pennylane/workflow/execution.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 2bb78270c9e..4bc6151211b 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -288,8 +288,6 @@ def _make_inner_execute( else: device_execution = partial(device.execute, execution_config=execution_config) - transform_program = qml.transforms.core.TransformProgram(inner_transform) - def inner_execute(tapes: Sequence[QuantumTape], **_) -> ResultBatch: """Execution that occurs within a machine learning framework boundary. @@ -300,12 +298,16 @@ def inner_execute(tapes: Sequence[QuantumTape], **_) -> ResultBatch: cache (None | MutableMapping): The cache to use. If ``None``, caching will not occur. """ + transform_program = qml.transforms.core.TransformProgram() + if numpy_only: - transform_program.insert_front_transform(qml.transforms.convert_to_numpy_parameters) + transform_program.add_transform(qml.transforms.convert_to_numpy_parameters) if cache is not None: transform_program.add_transform(_cache_transform, cache=cache) + transform_program += inner_transform + transformed_tapes, transform_post_processing = transform_program(tapes) # TODO: Apply expand_fn() as transform. From 26417cdff1650fc1cf62a5464692374791b59889 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 4 Jun 2024 13:49:22 -0400 Subject: [PATCH 09/36] fix testcase --- tests/docs/test_supported_confs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/docs/test_supported_confs.py b/tests/docs/test_supported_confs.py index bebba067858..12a9a4206a8 100644 --- a/tests/docs/test_supported_confs.py +++ b/tests/docs/test_supported_confs.py @@ -412,7 +412,7 @@ def test_all_paramshift_state(self, interface, return_type, shots, wire_specs): # with pytest.raises(ValueError, match=msg): circuit = get_qnode(interface, "parameter-shift", return_type, shots, wire_specs) x = get_variable(interface, wire_specs, complex=complex) - if shots is not None: + if shots is not None and interface != "jax": with pytest.raises(qml.DeviceError, match="not accepted with finite shots"): compute_gradient(x, interface, circuit, return_type, complex=complex) else: From 7155a4eb269f73c71406e156d5f65a738207f2aa Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 4 Jun 2024 14:08:24 -0400 Subject: [PATCH 10/36] mcm for legacy device --- pennylane/workflow/qnode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 2bb363a8b70..dd770ce4042 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -1067,7 +1067,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml and self.device.capabilities().get("supports_mid_measure", False) ) if has_mcm_support: - full_transform_program.add_transform( + inner_transform_program.add_transform( qml.devices.preprocess.mid_circuit_measurements, device=self.device, mcm_config=self.execute_kwargs["mcm_config"], From 792a11069e670f8f18a14c85869b4731ac98392b Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 4 Jun 2024 15:09:07 -0400 Subject: [PATCH 11/36] update test case --- tests/docs/test_supported_confs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/docs/test_supported_confs.py b/tests/docs/test_supported_confs.py index 12a9a4206a8..2859ffea0b6 100644 --- a/tests/docs/test_supported_confs.py +++ b/tests/docs/test_supported_confs.py @@ -540,7 +540,7 @@ def test_all_hadamard_nonstate_non_var( circuit = get_qnode(interface, diff_method, return_type, shots, wire_specs) x = get_variable(interface, wire_specs) if return_type in (VnEntropy, MutualInfo): - if shots: + if shots and interface != "jax": err_cls = qml.DeviceError msg = "not accepted with finite shots" else: From 7b87b2ec07476998ec526435c7b15e620c02f953 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 4 Jun 2024 15:17:13 -0400 Subject: [PATCH 12/36] fix bug where preprocess is called twice --- pennylane/workflow/execution.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 4bc6151211b..bb8e284d7cb 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -600,16 +600,21 @@ def cost_fn(params, x): # If gradient_fn is a gradient transform, device preprocessing should happen in # inner execute (inside the ml boundary). if isinstance(gradient_fn, qml.transforms.core.TransformDispatcher): - inner_transform = inner_transform or device.preprocess(config)[0] - transform_program = transform_program or qml.transforms.core.TransformProgram() + if inner_transform is None: + inner_transform = device.preprocess(config)[0] + if transform_program is None: + transform_program = qml.transforms.core.TransformProgram() else: - inner_transform = inner_transform or qml.transforms.core.TransformProgram() - transform_program = transform_program or device.preprocess(config)[0] + if inner_transform is None: + inner_transform = qml.transforms.core.TransformProgram() + if transform_program is None: + transform_program = device.preprocess(config)[0] else: - - transform_program = transform_program or qml.transforms.core.TransformProgram() - inner_transform = inner_transform or qml.transforms.core.TransformProgram() + if transform_program is None: + transform_program = qml.transforms.core.TransformProgram() + if inner_transform is None: + inner_transform = qml.transforms.core.TransformProgram() # If caching is desired but an explicit cache is not provided, use an ``LRUCache``. if cache is True: From 8d7594a4da93f19a41321dc751c463d015c05da7 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 4 Jun 2024 15:35:48 -0400 Subject: [PATCH 13/36] fix bug with jax --- pennylane/workflow/execution.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index bb8e284d7cb..6af0d3f5e32 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -298,7 +298,7 @@ def inner_execute(tapes: Sequence[QuantumTape], **_) -> ResultBatch: cache (None | MutableMapping): The cache to use. If ``None``, caching will not occur. """ - transform_program = qml.transforms.core.TransformProgram() + transform_program = qml.transforms.core.TransformProgram(inner_transform) if numpy_only: transform_program.add_transform(qml.transforms.convert_to_numpy_parameters) @@ -306,8 +306,6 @@ def inner_execute(tapes: Sequence[QuantumTape], **_) -> ResultBatch: if cache is not None: transform_program.add_transform(_cache_transform, cache=cache) - transform_program += inner_transform - transformed_tapes, transform_post_processing = transform_program(tapes) # TODO: Apply expand_fn() as transform. From d1a6001f955313b385d4fca436cd5f248a3df2ac Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Wed, 5 Jun 2024 15:31:04 -0400 Subject: [PATCH 14/36] trigger ci From 7ffba2dfdc5c86ad51aa445db4eb5d1a83ee7f54 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Wed, 5 Jun 2024 15:51:57 -0400 Subject: [PATCH 15/36] retrigger ci From 5075499c3452977972d0277dcd4558783516eeb8 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Wed, 5 Jun 2024 16:43:32 -0400 Subject: [PATCH 16/36] add changelog entry --- doc/releases/changelog-dev.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 16d0483df16..f0fafbfaf8b 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -291,6 +291,9 @@ * `CNOT` and `Toffoli` now have an `arithmetic_depth` of `1`, as they are controlled operations. [(#5797)](https://github.com/PennyLaneAI/pennylane/pull/5797) +* Fixes a bug where the gradient of `ControlledSequence`, `Reflection`, `AmplitudeAmplification`, and `Qubitization` is incorrect on `default.qubit.legacy` with `parameter_shift`. + [(#5806)](https://github.com/PennyLaneAI/pennylane/pull/5806) +

Contributors ✍️

This release contains contributions from (in alphabetical order): From a6018c0458e0d6c92fe7827014eb552aab9d9570 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 6 Jun 2024 09:17:14 -0400 Subject: [PATCH 17/36] Pass interface to dynamic one shot --- pennylane/devices/default_qubit.py | 5 ++++- pennylane/devices/preprocess.py | 7 +++++-- pennylane/transforms/dynamic_one_shot.py | 11 ++++++++--- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index c0a8a99f48d..6b1388802e9 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -502,7 +502,10 @@ def preprocess( transform_program.add_transform(validate_device_wires, self.wires, name=self.name) transform_program.add_transform( - mid_circuit_measurements, device=self, mcm_config=config.mcm_config + mid_circuit_measurements, + device=self, + mcm_config=config.mcm_config, + interface=config.interface, ) transform_program.add_transform( decompose, diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index 23fed7614e6..eb465af72af 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -147,7 +147,10 @@ def validate_device_wires( @transform def mid_circuit_measurements( - tape: qml.tape.QuantumTape, device, mcm_config=MCMConfig() + tape: qml.tape.QuantumTape, + device, + mcm_config=MCMConfig(), + interface=None, ) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """Provide the transform to handle mid-circuit measurements. @@ -162,7 +165,7 @@ def mid_circuit_measurements( mcm_method = "one-shot" if tape.shots else "deferred" if mcm_method == "one-shot": - return qml.dynamic_one_shot(tape) + return qml.dynamic_one_shot(tape, interface=interface) return qml.defer_measurements(tape, device=device) diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 831e30031af..4e96ee035a7 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -118,6 +118,8 @@ def func(x, y): aux_tapes = [init_auxiliary_tape(t) for t in tapes] + interface = kwargs.get("interface", None) + def processing_fn(results, has_partitioned_shots=None, batched_results=None): if batched_results is None and batch_size is not None: # If broadcasting, recursively process the results for each batch. For each batch @@ -141,7 +143,7 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None): return tuple(final_results) if not tape.shots.has_partitioned_shots: results = results[0] - return parse_native_mid_circuit_measurements(tape, aux_tapes, results) + return parse_native_mid_circuit_measurements(tape, aux_tapes, results, interface=interface) return aux_tapes, processing_fn @@ -208,7 +210,10 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): def parse_native_mid_circuit_measurements( - circuit: qml.tape.QuantumScript, aux_tapes: qml.tape.QuantumScript, results: TensorLike + circuit: qml.tape.QuantumScript, + aux_tapes: qml.tape.QuantumScript, + results: TensorLike, + interface=None, ): """Combines, gathers and normalizes the results of native mid-circuit measurement runs. @@ -228,7 +233,7 @@ def measurement_with_no_shots(measurement): else np.nan ) - interface = qml.math.get_deep_interface(circuit.data) + interface = interface or qml.math.get_deep_interface(circuit.data) interface = "numpy" if interface == "builtins" else interface all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)] From 494b22449ff978a1f33fc20aee185f1388938032 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 6 Jun 2024 09:28:58 -0400 Subject: [PATCH 18/36] update changelog --- doc/releases/changelog-dev.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 699a9d89576..207e0cee91f 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -172,6 +172,9 @@ [(#5758)](https://github.com/PennyLaneAI/pennylane/pull/5758/) [(#5638)](https://github.com/PennyLaneAI/pennylane/pull/5638/) +* Device preprocess transforms now happen inside the ml boundary. + [(#5791)](https://github.com/PennyLaneAI/pennylane/pull/5791) +

Community contributions 🥳

* Implemented kwargs (`check_interface`, `check_trainability`, `rtol` and `atol`) support in `qml.equal` for the operators `Pow`, `Adjoint`, `Exp`, and `SProd`. From de02371f9c5ae5b9655e55fd41b8f60b8038add9 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 6 Jun 2024 10:14:10 -0400 Subject: [PATCH 19/36] update test case --- tests/workflow/test_construct_batch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/workflow/test_construct_batch.py b/tests/workflow/test_construct_batch.py index e80aaee331b..9af2d8c8d28 100644 --- a/tests/workflow/test_construct_batch.py +++ b/tests/workflow/test_construct_batch.py @@ -142,7 +142,9 @@ def circuit(x): assert len(full_prog) == 13 config = qml.devices.ExecutionConfig( - gradient_method="adjoint", use_device_jacobian_product=False + interface=getattr(circuit, "interface", None), + gradient_method="adjoint", + use_device_jacobian_product=False, ) dev_program = dev.preprocess(config)[0] @@ -194,7 +196,8 @@ def circuit(): assert grad_program[2].transform == qml.gradients.param_shift.expand_transform dev_program = get_transform_program(circuit, level="device") - assert len(dev_program) == 3 + len(circuit.device.preprocess()[0]) # currently 8 + config = qml.devices.ExecutionConfig(interface=getattr(circuit, "interface", None)) + assert len(dev_program) == 3 + len(circuit.device.preprocess(config)[0]) # currently 8 assert qml.metric_tensor not in dev_program full = get_transform_program(circuit) From edc106997381d54f0581ab78ddec0e46715069e0 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 6 Jun 2024 10:56:09 -0400 Subject: [PATCH 20/36] update test with config --- tests/workflow/test_construct_batch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/workflow/test_construct_batch.py b/tests/workflow/test_construct_batch.py index 9af2d8c8d28..44e345e552a 100644 --- a/tests/workflow/test_construct_batch.py +++ b/tests/workflow/test_construct_batch.py @@ -106,7 +106,8 @@ def circuit(): assert p_dev == p_default assert p_none == p_dev assert len(p_dev) == 9 - assert p_dev == p_grad + dev.preprocess()[0] + config = qml.devices.ExecutionConfig(interface=getattr(circuit, "interface", None)) + assert p_dev == p_grad + dev.preprocess(config)[0] # slicing p_sliced = get_transform_program(circuit, slice(2, 7, 2)) From 6e3e861bdb97c3184d2d641d613a2c7e0ed2b90d Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 6 Jun 2024 11:05:33 -0400 Subject: [PATCH 21/36] fix bug for conditional --- pennylane/ops/op_math/condition.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 70cdf6fca28..15ecd128265 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -55,6 +55,11 @@ def __init__(self, expr, then_op: Type[Operation], id=None): self._name = f"Conditional({then_op.name})" super().__init__(then_op, id=id) + @property + def grad_method(self): + """Gradient computation method.""" + return "F" if self.num_params > 0 else None + def label(self, decimals=None, base_label=None, cache=None): return self.base.label(decimals=decimals, base_label=base_label, cache=cache) From 5ec0d7d758745aa706649f49cbd5aae9791721e5 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 6 Jun 2024 11:12:52 -0400 Subject: [PATCH 22/36] fix typos --- pennylane/ops/op_math/condition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 15ecd128265..3e24021ed6f 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -35,7 +35,7 @@ class Conditional(SymbolicOp): instead, use the :func:`qml.cond <.cond>` function. The ``Conditional`` class is a container class that defines an operation - that should by applied relative to a single measurement value. + that should be applied relative to a single measurement value. Support for executing ``Conditional`` operations is device-dependent. If a device doesn't support mid-circuit measurements natively, then the QNode @@ -65,7 +65,7 @@ def label(self, decimals=None, base_label=None, cache=None): @property def meas_val(self): - "the measurement outcome value to consider from `expr` argument" + """the measurement outcome value to consider from `expr` argument""" return self.hyperparameters["meas_val"] @property From e15221921fcf6d6d5beea52b5eee900e0ba10d0c Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 6 Jun 2024 11:14:25 -0400 Subject: [PATCH 23/36] make conditional an operation --- pennylane/ops/op_math/condition.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 3e24021ed6f..1764b554ed3 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -28,7 +28,7 @@ class ConditionalTransformError(ValueError): """Error for using qml.cond incorrectly""" -class Conditional(SymbolicOp): +class Conditional(SymbolicOp, Operation): """A Conditional Operation. Unless you are a Pennylane plugin developer, **you should NOT directly use this class**, @@ -54,11 +54,8 @@ def __init__(self, expr, then_op: Type[Operation], id=None): self.hyperparameters["meas_val"] = expr self._name = f"Conditional({then_op.name})" super().__init__(then_op, id=id) - - @property - def grad_method(self): - """Gradient computation method.""" - return "F" if self.num_params > 0 else None + if self.grad_recipe is None: + self.grad_recipe = [None] * self.num_params def label(self, decimals=None, base_label=None, cache=None): return self.base.label(decimals=decimals, base_label=base_label, cache=cache) From 82fa69cc03277aa28fb196f61492aa13afb6922b Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 6 Jun 2024 11:33:38 -0400 Subject: [PATCH 24/36] small bug fix --- pennylane/workflow/qnode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index dd770ce4042..38d14df2d49 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -1074,7 +1074,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml ) override_shots = 1 elif hasattr(self.device, "capabilities"): - full_transform_program.add_transform( + inner_transform_program.add_transform( qml.defer_measurements, device=self.device, ) From ed18e4151b1b58e765d762003f4ab3b73da25c79 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 6 Jun 2024 11:49:50 -0400 Subject: [PATCH 25/36] add missing line --- pennylane/workflow/qnode.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 38d14df2d49..ae7901d5932 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -1071,6 +1071,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml qml.devices.preprocess.mid_circuit_measurements, device=self.device, mcm_config=self.execute_kwargs["mcm_config"], + interface=getattr(self, "interface", None), ) override_shots = 1 elif hasattr(self.device, "capabilities"): From ce3dbb8dddc135784030ba3b2a4a27c7404a2266 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 6 Jun 2024 11:51:14 -0400 Subject: [PATCH 26/36] minor update --- pennylane/workflow/qnode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index ae7901d5932..032d5bac9aa 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -1071,7 +1071,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml qml.devices.preprocess.mid_circuit_measurements, device=self.device, mcm_config=self.execute_kwargs["mcm_config"], - interface=getattr(self, "interface", None), + interface=self.interface, ) override_shots = 1 elif hasattr(self.device, "capabilities"): From 11b49cc11dfeca79574419dc3d0e299ecbd0347b Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 6 Jun 2024 15:10:21 -0400 Subject: [PATCH 27/36] fix bug in tests --- .../test_jax_jit_qnode_default_qubit_2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py b/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py index cb1306d7b59..065403f9ff3 100644 --- a/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py +++ b/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Integration tests for using the JAX-JIT interface with a QNode""" +import copy + # pylint: disable=too-many-arguments,too-few-public-methods from functools import partial @@ -1374,6 +1376,7 @@ def test_state(self, dev, diff_method, grad_on_execution, device_vjp, interface, x = jax.numpy.array(0.543) y = jax.numpy.array(-0.654) if not dev.wires: + dev = copy.copy(dev) dev._wires = qml.wires.Wires([0, 1]) # pylint:disable=protected-access @qnode( From a30315a653e861e5f0d67c9642d2204e5d420698 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 6 Jun 2024 15:18:07 -0400 Subject: [PATCH 28/36] make code factor happy --- pennylane/workflow/execution.py | 53 ++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 6af0d3f5e32..a7fbaf26e94 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -593,26 +593,10 @@ def cost_fn(params, x): raise ValueError("Using postselect_mode='hw-like' is not supported with jax-jit.") config.mcm_config.postselect_mode = "fill-shots" - if isinstance(device, qml.devices.Device): - - # If gradient_fn is a gradient transform, device preprocessing should happen in - # inner execute (inside the ml boundary). - if isinstance(gradient_fn, qml.transforms.core.TransformDispatcher): - if inner_transform is None: - inner_transform = device.preprocess(config)[0] - if transform_program is None: - transform_program = qml.transforms.core.TransformProgram() - else: - if inner_transform is None: - inner_transform = qml.transforms.core.TransformProgram() - if transform_program is None: - transform_program = device.preprocess(config)[0] - - else: - if transform_program is None: - transform_program = qml.transforms.core.TransformProgram() - if inner_transform is None: - inner_transform = qml.transforms.core.TransformProgram() + is_gradient_transform = isinstance(gradient_fn, qml.transforms.core.TransformDispatcher) + transform_program, inner_transform = _make_transform_programs( + device, config, inner_transform, transform_program, is_gradient_transform + ) # If caching is desired but an explicit cache is not provided, use an ``LRUCache``. if cache is True: @@ -854,6 +838,35 @@ def device_gradient_fn(inner_tapes, **gradient_kwargs): return post_processing(results) +def _make_transform_programs( + device, config, inner_transform, transform_program, is_gradient_transform +): + """helper function to make the transform programs.""" + + if isinstance(device, qml.devices.Device): + + # If gradient_fn is a gradient transform, device preprocessing should happen in + # inner execute (inside the ml boundary). + if is_gradient_transform: + if inner_transform is None: + inner_transform = device.preprocess(config)[0] + if transform_program is None: + transform_program = qml.transforms.core.TransformProgram() + else: + if inner_transform is None: + inner_transform = qml.transforms.core.TransformProgram() + if transform_program is None: + transform_program = device.preprocess(config)[0] + + else: + if transform_program is None: + transform_program = qml.transforms.core.TransformProgram() + if inner_transform is None: + inner_transform = qml.transforms.core.TransformProgram() + + return transform_program, inner_transform + + def _get_execution_config( gradient_fn, grad_on_execution, interface, device, device_vjp, mcm_config ): From 3924a0f8dd4d2d68ac7b8feda1ce72f8f4500bc8 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 6 Jun 2024 16:11:41 -0400 Subject: [PATCH 29/36] add missing test coverage --- .../test_transform_program.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/transforms/test_experimental/test_transform_program.py b/tests/transforms/test_experimental/test_transform_program.py index d3c9747d122..f015a2fe3a6 100644 --- a/tests/transforms/test_experimental/test_transform_program.py +++ b/tests/transforms/test_experimental/test_transform_program.py @@ -29,6 +29,7 @@ _apply_postprocessing_stack, _batch_postprocessing, null_postprocessing, + _prune_dynamic_transform, ) from pennylane.typing import Result, ResultBatch @@ -107,6 +108,27 @@ def postprocessing2(results): out2 = _apply_postprocessing_stack(results, [postprocessing2, postprocessing1]) assert out2 == (4.0, 9.0) + def test_prune_dynamic_transform(self): + """Tests that prune dynamic transform works.""" + + program1 = TransformProgram( + [ + qml.transforms.dynamic_one_shot, + qml.transforms.sum_expand, + qml.transforms.dynamic_one_shot, + ] + ) + program2 = TransformProgram( + [ + qml.transforms.dynamic_one_shot, + qml.transforms.sum_expand, + ] + ) + + _prune_dynamic_transform(program1, program2) + assert len(program1) == 1 + assert len(program2) == 2 + class TestTransformProgramDunders: """Test the dunder methods.""" @@ -356,6 +378,13 @@ def test_empty_program(self): ): program.get_last() + def test_get_last(self): + """Tests the get_last method""" + program = TransformProgram() + program.add_transform(transform(first_valid_transform)) + program.add_transform(transform(second_valid_transform)) + assert program.get_last() == TransformContainer(transform=second_valid_transform) + def test_push_back(self): """Test to push back multiple transforms into a program and also the different methods of a program.""" transform_program = TransformProgram() From 804f7bd40b83f7c3ef7f660c9560b090e269439b Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 6 Jun 2024 16:20:41 -0400 Subject: [PATCH 30/36] make isort happy --- tests/transforms/test_experimental/test_transform_program.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/transforms/test_experimental/test_transform_program.py b/tests/transforms/test_experimental/test_transform_program.py index f015a2fe3a6..fcf716eff2a 100644 --- a/tests/transforms/test_experimental/test_transform_program.py +++ b/tests/transforms/test_experimental/test_transform_program.py @@ -28,8 +28,8 @@ from pennylane.transforms.core.transform_program import ( _apply_postprocessing_stack, _batch_postprocessing, - null_postprocessing, _prune_dynamic_transform, + null_postprocessing, ) from pennylane.typing import Result, ResultBatch From 3a67fa0c1e6cb892e8de6abb17f6c103e04ef0a3 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Fri, 7 Jun 2024 09:56:43 -0400 Subject: [PATCH 31/36] move helper function location --- .../transforms/core/transform_program.py | 20 --------- pennylane/workflow/qnode.py | 21 ++++++++- tests/test_qnode.py | 45 +++++++++++++++++++ .../test_transform_program.py | 22 --------- 4 files changed, 65 insertions(+), 43 deletions(-) diff --git a/pennylane/transforms/core/transform_program.py b/pennylane/transforms/core/transform_program.py index 71b1f98cf17..19ce9be17f8 100644 --- a/pennylane/transforms/core/transform_program.py +++ b/pennylane/transforms/core/transform_program.py @@ -551,23 +551,3 @@ def __call__(self, tapes: Tuple[QuantumTape]) -> Tuple[ResultBatch, BatchPostPro # Reset classical jacobians self._classical_jacobians = [] return tuple(tapes), postprocessing_fn - - -# pylint: disable=protected-access -def _prune_dynamic_transform(outer_transform, inner_transform): - """Ensure a single ``dynamic_one_shot`` transform is applied.""" - - all_transforms = outer_transform._transform_program + inner_transform._transform_program - type_to_keep = 0 - if any("mid_circuit_measurements" in str(t) for t in all_transforms): - type_to_keep = 2 - elif any("dynamic_one_shot" in str(t) for t in all_transforms): - type_to_keep = 1 - - if type_to_keep == 0: - return - - dynamic_transform_found = inner_transform._prune_dynamic_transform(type_to_keep) - if dynamic_transform_found: - type_to_keep = 0 - outer_transform._prune_dynamic_transform(type_to_keep) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 032d5bac9aa..fabf3e1275c 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -28,7 +28,6 @@ from pennylane.logging import debug_logger from pennylane.measurements import CountsMP, MidMeasureMP, Shots from pennylane.tape import QuantumScript, QuantumTape -from pennylane.transforms.core.transform_program import _prune_dynamic_transform from .execution import INTERFACE_MAP, SUPPORTED_INTERFACES @@ -1159,3 +1158,23 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result: qnode = lambda device, **kwargs: functools.partial(QNode, device=device, **kwargs) qnode.__doc__ = QNode.__doc__ qnode.__signature__ = inspect.signature(QNode) + + +# pylint: disable=protected-access +def _prune_dynamic_transform(outer_transform, inner_transform): + """Ensure a single ``dynamic_one_shot`` transform is applied.""" + + all_transforms = outer_transform._transform_program + inner_transform._transform_program + type_to_keep = 0 + if any("mid_circuit_measurements" in str(t) for t in all_transforms): + type_to_keep = 2 + elif any("dynamic_one_shot" in str(t) for t in all_transforms): + type_to_keep = 1 + + if type_to_keep == 0: + return + + dynamic_transform_found = inner_transform._prune_dynamic_transform(type_to_keep) + if dynamic_transform_found: + type_to_keep = 0 + outer_transform._prune_dynamic_transform(type_to_keep) diff --git a/tests/test_qnode.py b/tests/test_qnode.py index cf840c47766..f9495e3da60 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -29,6 +29,7 @@ from pennylane import numpy as pnp from pennylane import qnode from pennylane.tape import QuantumScript +from pennylane.workflow.qnode import _prune_dynamic_transform def dummyfunc(): @@ -2011,3 +2012,47 @@ def circuit(x): circuit(qml.numpy.array(0.1)) assert circuit.interface == "auto" + + +def test_prune_dynamic_transform(): + """Tests that the helper function prune dynamic transform works.""" + + program1 = qml.transforms.core.TransformProgram( + [ + qml.transforms.dynamic_one_shot, + qml.transforms.sum_expand, + qml.transforms.dynamic_one_shot, + ] + ) + program2 = qml.transforms.core.TransformProgram( + [ + qml.transforms.dynamic_one_shot, + qml.transforms.sum_expand, + ] + ) + + _prune_dynamic_transform(program1, program2) + assert len(program1) == 1 + assert len(program2) == 2 + + +def test_prune_dynamic_transform_with_mcm(): + """Tests that the helper function prune dynamic transform works with mcm""" + + program1 = qml.transforms.core.TransformProgram( + [ + qml.transforms.dynamic_one_shot, + qml.transforms.sum_expand, + qml.devices.preprocess.mid_circuit_measurements, + ] + ) + program2 = qml.transforms.core.TransformProgram( + [ + qml.transforms.dynamic_one_shot, + qml.transforms.sum_expand, + ] + ) + + _prune_dynamic_transform(program1, program2) + assert len(program1) == 2 + assert len(program2) == 1 diff --git a/tests/transforms/test_experimental/test_transform_program.py b/tests/transforms/test_experimental/test_transform_program.py index fcf716eff2a..30fdc6266cf 100644 --- a/tests/transforms/test_experimental/test_transform_program.py +++ b/tests/transforms/test_experimental/test_transform_program.py @@ -28,7 +28,6 @@ from pennylane.transforms.core.transform_program import ( _apply_postprocessing_stack, _batch_postprocessing, - _prune_dynamic_transform, null_postprocessing, ) from pennylane.typing import Result, ResultBatch @@ -108,27 +107,6 @@ def postprocessing2(results): out2 = _apply_postprocessing_stack(results, [postprocessing2, postprocessing1]) assert out2 == (4.0, 9.0) - def test_prune_dynamic_transform(self): - """Tests that prune dynamic transform works.""" - - program1 = TransformProgram( - [ - qml.transforms.dynamic_one_shot, - qml.transforms.sum_expand, - qml.transforms.dynamic_one_shot, - ] - ) - program2 = TransformProgram( - [ - qml.transforms.dynamic_one_shot, - qml.transforms.sum_expand, - ] - ) - - _prune_dynamic_transform(program1, program2) - assert len(program1) == 1 - assert len(program2) == 2 - class TestTransformProgramDunders: """Test the dunder methods.""" From 3f9868e438dec20664bc53554107a1b7cfc23978 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Fri, 7 Jun 2024 10:03:36 -0400 Subject: [PATCH 32/36] update documentation --- pennylane/workflow/execution.py | 4 ++-- pennylane/workflow/qnode.py | 10 +++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index a7fbaf26e94..7b8040b21d0 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -265,7 +265,7 @@ def _make_inner_execute( for the 1st order derivatives. Steps in between the ml framework execution and the device are: - - device expansion (old device) + - device expansion (old device) or device preprocessing (new device) - conversion to numpy - caching @@ -443,7 +443,7 @@ def execute( This affects the types of parameters that can exist on the input tapes. Available options include ``autograd``, ``torch``, ``tf``, ``jax`` and ``auto``. transform_program(.TransformProgram): A transform program to be applied to the initial tape. - inner_transform (.TransformProgram): A transform program to be applied to the tapes in the inner_execute. + inner_transform (.TransformProgram): A transform program to be applied to the tapes in inner execution, inside the ml interface. config (qml.devices.ExecutionConfig): A datastructure describing the parameters needed to fully describe the execution. grad_on_execution (bool, str): Whether the gradients should be computed on the execution or not. Only applies if the device is queried for the gradient; gradient transform diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index fabf3e1275c..0bd6b4e003e 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -1162,7 +1162,15 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result: # pylint: disable=protected-access def _prune_dynamic_transform(outer_transform, inner_transform): - """Ensure a single ``dynamic_one_shot`` transform is applied.""" + """Ensure a single ``dynamic_one_shot`` transform is applied. + + Sometimes device preprocess contains a ``mid_circuit_measurements`` transform, which will + be added to the inner transform program. If the user then applies a ``dynamic_one_shot`` + manually, it will duplicate the ``mid_circuit_measurements`` transform. This function ensures + that there is only one ``dynamic_one_shot`` transform in the outer and inner transform + programs combined. + + """ all_transforms = outer_transform._transform_program + inner_transform._transform_program type_to_keep = 0 From 4470f27fcb9ac1157600e39bb504c0789f3981df Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Fri, 7 Jun 2024 10:05:31 -0400 Subject: [PATCH 33/36] make function public again --- pennylane/transforms/core/transform_program.py | 2 +- pennylane/workflow/qnode.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pennylane/transforms/core/transform_program.py b/pennylane/transforms/core/transform_program.py index 19ce9be17f8..47d53b7f53e 100644 --- a/pennylane/transforms/core/transform_program.py +++ b/pennylane/transforms/core/transform_program.py @@ -352,7 +352,7 @@ def set_classical_component(self, qnode, args, kwargs): self._set_all_classical_jacobians(qnode, args, kwargs, argnums) self._set_all_argnums(qnode, args, kwargs, argnums) - def _prune_dynamic_transform(self, type_to_keep=1): + def prune_dynamic_transform(self, type_to_keep=1): """Ensures that only one or none ``dynamic_one_shot`` is applied. Args: diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 0bd6b4e003e..acec03a364b 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -1182,7 +1182,7 @@ def _prune_dynamic_transform(outer_transform, inner_transform): if type_to_keep == 0: return - dynamic_transform_found = inner_transform._prune_dynamic_transform(type_to_keep) + dynamic_transform_found = inner_transform.prune_dynamic_transform(type_to_keep) if dynamic_transform_found: type_to_keep = 0 - outer_transform._prune_dynamic_transform(type_to_keep) + outer_transform.prune_dynamic_transform(type_to_keep) From c92231723414e56d1297d771e06f8081a9539781 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 11 Jun 2024 13:04:46 -0400 Subject: [PATCH 34/36] Update tests/test_qnode.py Co-authored-by: David Wierichs --- tests/test_qnode.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_qnode.py b/tests/test_qnode.py index f9495e3da60..fffeb455abd 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -2055,4 +2055,5 @@ def test_prune_dynamic_transform_with_mcm(): _prune_dynamic_transform(program1, program2) assert len(program1) == 2 + assert qml.transforms.mid_circuit_measurements in program1 assert len(program2) == 1 From 33fbcca0c1b734bbf48c98eec80872d4a4ea847c Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 11 Jun 2024 13:32:57 -0400 Subject: [PATCH 35/36] fix bug in test --- tests/test_qnode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_qnode.py b/tests/test_qnode.py index fffeb455abd..2bea3676217 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -2055,5 +2055,5 @@ def test_prune_dynamic_transform_with_mcm(): _prune_dynamic_transform(program1, program2) assert len(program1) == 2 - assert qml.transforms.mid_circuit_measurements in program1 + assert qml.devices.preprocess.mid_circuit_measurements in program1 assert len(program2) == 1 From 7ebb4366322034864764afa5cb08445f9ff3ac81 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Wed, 12 Jun 2024 11:13:11 -0400 Subject: [PATCH 36/36] Apply suggestions from code review Co-authored-by: Christina Lee --- pennylane/workflow/qnode.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 07291de58d7..f07dbdb0810 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -1167,7 +1167,6 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result: qnode.__signature__ = inspect.signature(QNode) -# pylint: disable=protected-access def _prune_dynamic_transform(outer_transform, inner_transform): """Ensure a single ``dynamic_one_shot`` transform is applied. @@ -1179,7 +1178,7 @@ def _prune_dynamic_transform(outer_transform, inner_transform): """ - all_transforms = outer_transform._transform_program + inner_transform._transform_program + all_transforms = outer_transform + inner_transform type_to_keep = 0 if any("mid_circuit_measurements" in str(t) for t in all_transforms): type_to_keep = 2