Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix derivatives of merge_rotations and single_qubit_fusion #6033

Merged
merged 41 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
d0126df
new fuse_rot_angles function, incl test for evaluation and Jacobian, …
dwierichs Jul 23, 2024
961d6ca
changelog
dwierichs Jul 23, 2024
5c231ab
pr number
dwierichs Jul 23, 2024
8cb80a5
undo merge_rotations change
dwierichs Jul 23, 2024
fc65bc9
merge
dwierichs Jul 23, 2024
102a7a4
ml interfaces tests, precision comment
dwierichs Jul 23, 2024
ce1b97f
merge
dwierichs Jul 23, 2024
c1e918e
docstring dtype
dwierichs Jul 23, 2024
0e4085d
prints
dwierichs Jul 23, 2024
4dc44a5
math derivation in docstring
dwierichs Jul 24, 2024
ea3fc47
Merge branch 'master' into rot_math
dwierichs Jul 24, 2024
b436aba
mixed batching test
dwierichs Jul 24, 2024
3148643
documentation
dwierichs Jul 24, 2024
976a59e
documentation
dwierichs Jul 24, 2024
343d2fb
format
dwierichs Jul 24, 2024
91b2f78
no_fuse simplifications. JIT test
dwierichs Jul 24, 2024
a3a752c
fixed
dwierichs Jul 24, 2024
cf3e621
fix precision problems in _try_no_fuse
dwierichs Jul 24, 2024
1967a3c
clean out imports
dwierichs Jul 24, 2024
954535a
import fix
dwierichs Jul 24, 2024
538c436
Merge branch 'rot_math' into fix_merge_rotations
dwierichs Jul 24, 2024
d23e447
clean out imports
dwierichs Jul 24, 2024
c326471
import stuff
dwierichs Jul 24, 2024
39d0827
fix, rename angles, comment
dwierichs Jul 24, 2024
c4d0132
Merge branch 'rot_math' into fix_merge_rotations
dwierichs Jul 24, 2024
65c3e8b
Apply suggestions from code review
dwierichs Jul 25, 2024
42c80d6
changelog
dwierichs Jul 25, 2024
ccfcb34
squash tests
dwierichs Jul 25, 2024
f4f1c7f
tiny
dwierichs Jul 26, 2024
7c37e14
expand doc
dwierichs Jul 26, 2024
d172a89
Apply suggestions from code review
dwierichs Jul 29, 2024
b3faa45
Apply suggestions from code review
dwierichs Jul 30, 2024
c988692
merge
dwierichs Aug 1, 2024
b735418
interface tests
dwierichs Aug 1, 2024
931da7d
update torch
dwierichs Aug 2, 2024
6f239ee
merge
dwierichs Aug 2, 2024
cb2b20a
lint
dwierichs Aug 2, 2024
5bb1540
merge
dwierichs Aug 7, 2024
02d513c
Merge branch 'rot_math' into fix_merge_rotations
dwierichs Aug 7, 2024
1fc17c5
code
dwierichs Aug 7, 2024
12fbf15
Merge branch 'master' into fix_merge_rotations
dwierichs Aug 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@
* `fuse_rot_angles` now respects the global phase of the combined rotations.
[(#6031)](https://github.com/PennyLaneAI/pennylane/pull/6031)

* `fuse_rot_angles` now respects the global phase of the combined rotations.
[(#6031)](https://github.com/PennyLaneAI/pennylane/pull/6031)

* `QNGOptimizer` now supports cost functions with multiple arguments, updating each argument independently.
[(#5926)](https://github.com/PennyLaneAI/pennylane/pull/5926)

Expand Down
47 changes: 17 additions & 30 deletions pennylane/transforms/optimization/merge_rotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# pylint: disable=too-many-branches

import pennylane as qml
from pennylane.math import allclose, cast_like, get_interface, is_abstract, stack, zeros
from pennylane.ops.op_math import Adjoint
from pennylane.ops.qubit.attributes import composable_rotations
from pennylane.queuing import QueuingManager
Expand Down Expand Up @@ -177,58 +176,46 @@ def stop_at(obj):
continue

# We need to use stack to get this to work and be differentiable in all interfaces
cumulative_angles = stack(current_gate.parameters)
interface = get_interface(cumulative_angles)
cumulative_angles = qml.math.stack(current_gate.parameters)
interface = qml.math.get_interface(cumulative_angles)
# As long as there is a valid next gate, check if we can merge the angles
while next_gate_idx is not None:
# Get the next gate
next_gate = list_copy[next_gate_idx + 1]

# If next gate is of the same type, we can merge the angles
if current_gate.name == next_gate.name and current_gate.wires == next_gate.wires:
if isinstance(current_gate, type(next_gate)) and current_gate.wires == next_gate.wires:
list_copy.pop(next_gate_idx + 1)
next_params = qml.math.stack(next_gate.parameters, like=interface)
# jax-jit does not support cast_like
if not qml.math.is_abstract(cumulative_angles):
next_params = qml.math.cast_like(next_params, cumulative_angles)

# The Rot gate must be treated separately
if current_gate.name == "Rot":
if is_abstract(cumulative_angles):
# jax-jit does not support cast_like
cumulative_angles = cumulative_angles + stack(next_gate.parameters)
else:
cumulative_angles = fuse_rot_angles(
cumulative_angles,
cast_like(
stack(next_gate.parameters, like=interface), cumulative_angles
),
)
if isinstance(current_gate, qml.Rot):
cumulative_angles = fuse_rot_angles(cumulative_angles, next_params)
# Other, single-parameter rotation gates just have the angle summed
else:
if is_abstract(cumulative_angles):
# jax-jit does not support cast_like
cumulative_angles = cumulative_angles + stack(next_gate.parameters)
else:
cumulative_angles = cumulative_angles + cast_like(
stack(next_gate.parameters, like=interface), cumulative_angles
)
cumulative_angles = cumulative_angles + next_params
# If it is not, we need to stop
else:
break

# If we did merge, look now at the next gate
next_gate_idx = find_next_gate(current_gate.wires, list_copy[1:])

# If we are tracing/jitting, don't perform any conditional checks and
# If we are tracing/jitting or differentiating, don't perform any conditional checks and
# apply the operation regardless of the angles. Otherwise, only apply if
# the rotation angle is non-trivial.
if is_abstract(cumulative_angles):
if (
qml.math.is_abstract(cumulative_angles)
or qml.math.requires_grad(cumulative_angles)
or not qml.math.allclose(cumulative_angles, 0.0, atol=atol, rtol=0)
):
with QueuingManager.stop_recording():
new_operations.append(
current_gate.__class__(*cumulative_angles, wires=current_gate.wires)
)
else:
if not allclose(cumulative_angles, zeros(len(cumulative_angles)), atol=atol, rtol=0):
with QueuingManager.stop_recording():
new_operations.append(
current_gate.__class__(*cumulative_angles, wires=current_gate.wires)
)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

# Remove the first gate from the working list
list_copy.pop(0)
Expand Down
35 changes: 17 additions & 18 deletions pennylane/transforms/optimization/single_qubit_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Transform for fusing sequences of single-qubit gates."""
# pylint: disable=too-many-branches

from pennylane.math import allclose, is_abstract, stack
import pennylane as qml
from pennylane.ops.qubit import Rot
from pennylane.queuing import QueuingManager
from pennylane.tape import QuantumTape, QuantumTapeBatch
Expand Down Expand Up @@ -257,7 +257,7 @@ def qfunc(r1, r2):
# Look for single_qubit_rot_angles; if not available, queue and move on.
# If available, grab the angles and try to fuse.
try:
cumulative_angles = stack(current_gate.single_qubit_rot_angles())
cumulative_angles = qml.math.stack(current_gate.single_qubit_rot_angles())
except (NotImplementedError, AttributeError):
new_operations.append(current_gate)
list_copy.pop(0)
Expand Down Expand Up @@ -295,31 +295,30 @@ def qfunc(r1, r2):
# the gate in question, only valid single-qubit gates on the same
# wire as the current gate will be fused.
try:
next_gate_angles = stack(next_gate.single_qubit_rot_angles())
next_gate_angles = qml.math.stack(next_gate.single_qubit_rot_angles())
except (NotImplementedError, AttributeError):
break

cumulative_angles = fuse_rot_angles(cumulative_angles, stack(next_gate_angles))
cumulative_angles = fuse_rot_angles(cumulative_angles, next_gate_angles)

list_copy.pop(next_gate_idx + 1)
next_gate_idx = find_next_gate(current_gate.wires, list_copy[1:])

# If we are tracing/jitting, don't perform any conditional checks and
# If we are tracing/jitting or differentiating, don't perform any conditional checks and
# apply the rotation regardless of the angles.
if is_abstract(cumulative_angles):
with QueuingManager.stop_recording():
new_operations.append(Rot(*cumulative_angles, wires=current_gate.wires))
# If not tracing, check whether all angles are 0 (or equivalently, if the RY
# angle is close to 0, and so is the sum of the RZ angles
else:
if not allclose(
stack([cumulative_angles[0] + cumulative_angles[2], cumulative_angles[1]]),
[0.0, 0.0],
# If not tracing or differentiating, check whether total rotation is trivial by checking
# if the RY angle and the sum of the RZ angles are close to 0
if (
qml.math.is_abstract(cumulative_angles)
or qml.math.requires_grad(cumulative_angles)
or not qml.math.allclose(
qml.math.stack([cumulative_angles[0] + cumulative_angles[2], cumulative_angles[1]]),
0.0,
atol=atol,
rtol=0,
):
with QueuingManager.stop_recording():
new_operations.append(Rot(*cumulative_angles, wires=current_gate.wires))
)
):
with QueuingManager.stop_recording():
new_operations.append(Rot(*cumulative_angles, wires=current_gate.wires))
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

# Remove the starting gate from the list
list_copy.pop(0)
Expand Down
10 changes: 4 additions & 6 deletions tests/transforms/test_optimization/test_merge_rotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,13 +424,11 @@ def test_merge_rotations_jax_jit(self):
@qml.qnode(qml.device("default.qubit", wires=["w1", "w2"]), interface="jax")
@merge_rotations
def qfunc():
qml.Rot(jax.numpy.array(0.1), jax.numpy.array(0.2), jax.numpy.array(0.3), wires=["w1"])
qml.Rot(
jax.numpy.array(-0.1), jax.numpy.array(-0.2), jax.numpy.array(-0.3), wires=["w1"]
)
qml.Rot(*jax.numpy.array([0.1, 0.2, 0.3]), wires=["w1"])
qml.Rot(*jax.numpy.array([-0.3, -0.2, -0.1]), wires=["w1"])
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
qml.CRX(jax.numpy.array(0.2), wires=["w1", "w2"])
qml.CRX(jax.numpy.array(-0.2), wires=["w1", "w2"])
return qml.expval(qml.PauliZ("w1"))
return qml.expval(qml.PauliZ("w2"))

res = qfunc()

Expand Down Expand Up @@ -525,7 +523,7 @@ def test_qnode(self):

@pytest.mark.xfail
def test_merge_rotations_non_commuting_observables():
"""Test that merge_roatations works with non-commuting observables."""
"""Test that merge_rotations works with non-commuting observables."""

@qml.transforms.merge_rotations
def circuit(x):
Expand Down
Loading