Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize PhaseShift, T, S gates #5876

Merged
merged 25 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f9cfdfa
Register apply_phaseshift and special cases.
vincentmr Jun 18, 2024
3c43b52
Add apply_controlled_operation
vincentmr Jun 18, 2024
c873db2
Remove apply_controlled_operation
vincentmr Jun 19, 2024
d4600b7
Merge branch 'master' into optim_apply_operations
vincentmr Jun 19, 2024
2d9bfe4
Deal with PhaseShift batched.
vincentmr Jun 20, 2024
ea4e256
Fix autograd bug.
vincentmr Jun 20, 2024
ee5542b
Fix TF tests.
vincentmr Jun 20, 2024
d8ac0a6
Merge remote-tracking branch 'origin/master' into optim_apply_operations
vincentmr Jun 20, 2024
80505c6
Do not reuse PhaseShift.
vincentmr Jun 20, 2024
6358bba
Put back S.
vincentmr Jun 20, 2024
aaa32ff
Fix complex issue with TF
vincentmr Jun 20, 2024
0e3700e
Use PhaseShift in T, S, Z.
vincentmr Jun 20, 2024
dfdc39b
Update changelog
vincentmr Jun 20, 2024
1ebb499
Revert to full implementations.
vincentmr Jun 20, 2024
cb76118
Merge remote-tracking branch 'origin/master' into optim_apply_operations
vincentmr Jun 25, 2024
60ac35b
pragma: no cover
vincentmr Jun 25, 2024
b6719b9
Merge branch 'master' into optim_apply_operations
vincentmr Jul 5, 2024
18c48ea
Merge remote-tracking branch 'origin/master' into optim_apply_operations
vincentmr Jul 8, 2024
fa0cc28
Broadcast params in phaseshift.
vincentmr Jul 8, 2024
e4e7898
Fix broadcasting for all interfaces.
vincentmr Jul 8, 2024
1287256
Merge branch 'master' into optim_apply_operations
vincentmr Jul 8, 2024
1689722
Fix jax issue with math.array.
vincentmr Jul 8, 2024
b309faf
Add tests and remove obsolete pragma: no cover
vincentmr Jul 8, 2024
b518eb8
Merge branch 'master' into optim_apply_operations
vincentmr Jul 9, 2024
75af4b6
Update pennylane/devices/qubit/apply_operation.py
vincentmr Jul 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

<h3>Improvements 🛠</h3>

* Port the fast `apply_operation` implementation of `PauliZ` to `PhaseShift`, `S` and `T`.
[(#5876)](https://github.com/PennyLaneAI/pennylane/pull/5876)

* `qml.UCCSD` now accepts an additional optional argument, `n_repeats`, which defines the number of
times the UCCSD template is repeated. This can improve the accuracy of the template by reducing
the Trotter error but would result in deeper circuits.
Expand All @@ -31,8 +34,8 @@
<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):

Yushao Chen,
Christina Lee,
William Maxwell,
Vincent Michaud-Rioux,
Erik Schultheis.
72 changes: 71 additions & 1 deletion pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,14 +407,84 @@ def apply_pauliz(op: qml.Z, state, is_state_batched: bool = False, debugger=None
return math.stack([state[sl_0], state1], axis=axis)


@apply_operation.register
def apply_phaseshift(op: qml.PhaseShift, state, is_state_batched: bool = False, debugger=None, **_):
"""Apply PhaseShift to state."""

params = op.parameters[0]
if is_state_batched or (op.batch_size is not None and len(params) > 1):
params = math.atleast_1d(params)
slices = []
for i, p in enumerate(params):
slices.append(
apply_phaseshift(
qml.PhaseShift(p, wires=op.wires),
state[i] if is_state_batched else state,
is_state_batched=False,
)
)
return math.stack(slices, axis=0)
vincentmr marked this conversation as resolved.
Show resolved Hide resolved

axis = op.wires[0] + is_state_batched
n_dim = math.ndim(state)

if n_dim >= 9 and math.get_interface(state) == "tensorflow": # pragma: no cover
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)
vincentmr marked this conversation as resolved.
Show resolved Hide resolved

sl_0 = _get_slice(0, axis, n_dim)
sl_1 = _get_slice(1, axis, n_dim)

state1 = math.multiply(
math.cast(state[sl_1], dtype=complex), math.exp(1j * math.cast(params, dtype=complex))
)
state = math.stack([state[sl_0], state1], axis=axis)
if op.batch_size == 1:
state = math.stack([state], axis=0)
return state


@apply_operation.register
def apply_T(op: qml.T, state, is_state_batched: bool = False, debugger=None, **_):
"""Apply T to state."""

axis = op.wires[0] + is_state_batched
n_dim = math.ndim(state)

if n_dim >= 9 and math.get_interface(state) == "tensorflow": # pragma: no cover
return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

sl_0 = _get_slice(0, axis, n_dim)
sl_1 = _get_slice(1, axis, n_dim)

state1 = math.multiply(math.cast(state[sl_1], dtype=complex), math.exp(0.25j * np.pi))
return math.stack([state[sl_0], state1], axis=axis)


@apply_operation.register
def apply_S(op: qml.S, state, is_state_batched: bool = False, debugger=None, **_):
"""Apply S to state."""

axis = op.wires[0] + is_state_batched
n_dim = math.ndim(state)

if n_dim >= 9 and math.get_interface(state) == "tensorflow": # pragma: no cover
return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

sl_0 = _get_slice(0, axis, n_dim)
sl_1 = _get_slice(1, axis, n_dim)

state1 = math.multiply(math.cast(state[sl_1], dtype=complex), 1j)
return math.stack([state[sl_0], state1], axis=axis)


@apply_operation.register
def apply_cnot(op: qml.CNOT, state, is_state_batched: bool = False, debugger=None, **_):
"""Apply cnot gate to state."""
target_axes = (op.wires[1] - 1 if op.wires[1] > op.wires[0] else op.wires[1]) + is_state_batched
control_axes = op.wires[0] + is_state_batched
n_dim = math.ndim(state)

if n_dim >= 9 and math.get_interface(state) == "tensorflow":
if n_dim >= 9 and math.get_interface(state) == "tensorflow": # pragma: no cover
return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

sl_0 = _get_slice(0, control_axes, n_dim)
Expand Down
4 changes: 2 additions & 2 deletions tests/devices/qubit/test_apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ def test_broadcasted_op(self, op, method, ml_framework):
@pytest.mark.parametrize("op", unbroadcasted_ops)
def test_broadcasted_state(self, op, method, ml_framework):
"""Tests that unbatched operations are applied correctly to a batched state."""
state = np.ones((3, 2, 2, 2)) / np.sqrt(8)
state = np.ones((3, 2, 2, 2), dtype=complex) / np.sqrt(8)

res = method(op, qml.math.asarray(state, like=ml_framework), is_state_batched=True)
missing_wires = 3 - len(op.wires)
Expand All @@ -813,7 +813,7 @@ def test_broadcasted_op_broadcasted_state(self, op, method, ml_framework):
if method is apply_operation_tensordot:
pytest.skip("Tensordot doesn't support batched operator and batched state.")

state = np.ones((3, 2, 2, 2)) / np.sqrt(8)
state = np.ones((3, 2, 2, 2), dtype=complex) / np.sqrt(8)

res = method(op, qml.math.asarray(state, like=ml_framework), is_state_batched=True)
missing_wires = 3 - len(op.wires)
Expand Down
Loading