Skip to content

Commit

Permalink
pass along postselect mode
Browse files Browse the repository at this point in the history
  • Loading branch information
albi3ro committed Jul 24, 2024
1 parent cf46f57 commit f654249
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
10 changes: 8 additions & 2 deletions pennylane/devices/legacy_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,11 +391,17 @@ def execute(self, circuits, execution_config=DefaultExecutionConfig):
else self._device
)

kwargs = {}
if dev.capabilities().get("supports_mid_measure", False):
kwargs["postselect_mode"] = execution_config.mcm_config.postselect_mode

first_shot = circuits[0].shots
if all(t.shots == first_shot for t in circuits):
results = _set_shots(dev, first_shot)(dev.batch_execute)(circuits)
results = _set_shots(dev, first_shot)(dev.batch_execute)(circuits, **kwargs)
else:
results = tuple(_set_shots(dev, t.shots)(dev.batch_execute)((t,))[0] for t in circuits)
results = tuple(
_set_shots(dev, t.shots)(dev.batch_execute)((t,), **kwargs)[0] for t in circuits
)

if dev is not self._device:
self._update_original_device(dev)
Expand Down
22 changes: 22 additions & 0 deletions tests/devices/test_legacy_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,28 @@ class MidMeasureDev(DummyDevice):
assert qml.defer_measurements not in program


@pytest.mark.parametrize("t_postselect_mode", ("hw-like", "fill-shots"))
def test_pass_postselect_mode_to_dev(t_postselect_mode):
"""test that postselect mode is passed to the target if it supports mid measure."""

class MidMeasureDev(DummyDevice):
"""A dummy device that supports mid circuit measurements."""

_capabilities = {"supports_mid_measure": True}

def batch_execute(self, circuits, postselect_mode):
assert postselect_mode == t_postselect_mode
return tuple(0 for _ in circuits)

target = MidMeasureDev()
dev = LegacyDeviceFacade(target)

mcm_config = qml.devices.MCMConfig(postselect_mode=t_postselect_mode)
config = qml.devices.ExecutionConfig(mcm_config=mcm_config)

dev.execute(qml.tape.QuantumScript(), config)


class TestGradientSupport:
"""Test integration with various kinds of device derivatives."""

Expand Down

0 comments on commit f654249

Please sign in to comment.