Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a sorting order in parameter-shift terms #5583

Merged
merged 6 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

<h3>Improvements 🛠</h3>

* The sorting order of parameter-shift terms is now guaranteed to resolve ties in the absolute value with the sign of the shifts.
[(#5582)](https://github.com/PennyLaneAI/pennylane/pull/5582)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this the wrong PR number?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh dear. Yes!


<h4>Mid-circuit measurements and dynamic circuits</h4>

* The `dynamic_one_shot` transform can be compiled with `jax.jit`.
Expand Down Expand Up @@ -79,7 +82,7 @@

* ``qml.from_qasm_file`` has been removed. The user can open files and load their content using `qml.from_qasm`.
[(#5659)](https://github.com/PennyLaneAI/pennylane/pull/5659)

* ``qml.load`` has been removed in favour of more specific functions, such as ``qml.from_qiskit``, etc.
[(#5654)](https://github.com/PennyLaneAI/pennylane/pull/5654)

Expand Down
6 changes: 4 additions & 2 deletions pennylane/gradients/general_shift_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def process_shifts(rule, tol=1e-10, batch_duplicates=True):

- Finally, the terms are sorted according to the absolute value of ``shift``,
This ensures that a zero-shift term, if it exists, is returned first.
For equal absolute values of two shifts, the positive shift is sorted to come first.
"""
# set all small coefficients, multipliers if present, and shifts to zero.
rule[np.abs(rule) < tol] = 0
Expand All @@ -78,8 +79,9 @@ def process_shifts(rule, tol=1e-10, batch_duplicates=True):
coeffs = [np.sum(rule[slc, 0]) for slc in matches.T]
rule = np.hstack([np.stack(coeffs)[:, np.newaxis], unique_mods])

# sort columns according to abs(shift)
return rule[np.argsort(np.abs(rule[:, -1]), kind="stable")]
# sort columns according to abs(shift), ties are resolved with the sign,
# positive shifts being returned before negative shifts.
return rule[np.lexsort((-np.sign(rule[:, -1]), np.abs(rule[:, -1])))]


@functools.lru_cache(maxsize=None)
Expand Down
12 changes: 6 additions & 6 deletions tests/gradients/core/test_general_shift_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def test_second_order_two_term_shift_rule_custom_shifts(self):
properly simplified when custom shift values are provided"""
frequencies = (1,)
generated_terms = generate_shift_rule(frequencies, shifts=(np.pi / 4,), order=2)
correct_terms = [[-1, 0], [0.5, -np.pi / 2], [0.5, np.pi / 2]]
correct_terms = [[-1, 0], [0.5, np.pi / 2], [0.5, -np.pi / 2]]
assert np.allclose(generated_terms, correct_terms)

def test_second_order_four_term_shift_rule(self):
Expand All @@ -297,8 +297,8 @@ def test_second_order_four_term_shift_rule(self):
generated_terms = generate_shift_rule(frequencies, order=2)
correct_terms = [
[-0.375, 0],
[0.25, -np.pi],
[0.25, np.pi],
[0.25, -np.pi],
[-0.125, -2 * np.pi],
]
assert np.allclose(generated_terms, correct_terms)
Expand All @@ -310,12 +310,12 @@ def test_second_order_non_equidistant_shift_rule(self):
generated_terms = generate_shift_rule(frequencies, order=2)
correct_terms = [
[-6, 0],
[3.91421356, -np.pi / 4],
[3.91421356, np.pi / 4],
[-1, -np.pi / 2],
[3.91421356, -np.pi / 4],
[-1, np.pi / 2],
[0.08578644, -3 * np.pi / 4],
[-1, -np.pi / 2],
[0.08578644, 3 * np.pi / 4],
[0.08578644, -3 * np.pi / 4],
]
assert np.allclose(generated_terms, correct_terms)

Expand All @@ -335,7 +335,7 @@ def test_single_parameter(self):
assert np.allclose(res, expected)

res = generate_multi_shift_rule([(1,)], orders=[2], shifts=[(np.pi / 4,)])
expected = [[-1, 0], [0.5, -np.pi / 2], [0.5, np.pi / 2]]
expected = [[-1, 0], [0.5, np.pi / 2], [0.5, -np.pi / 2]]
assert np.allclose(res, expected)

def test_two_single_frequency(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ class DummyOp(qml.RX):
two_term_2nd_order = [(-0.5, 1.0, 0.0), (0.5, 1.0, -np.pi)]
four_term_2nd_order = [
(-0.375, 1.0, 0),
(0.25, 1.0, -np.pi),
(0.25, 1.0, np.pi),
(0.25, 1.0, -np.pi),
(-0.125, 1.0, -2 * np.pi),
]

Expand Down Expand Up @@ -1606,8 +1606,8 @@ def circuit(x):
# - 1 for second diagonal.
assert len(tapes) == 1 + 2 + 4 + 1
assert np.allclose(tapes[0].get_parameters(), x)
assert np.allclose(tapes[1].get_parameters(), x + np.array([-2 * np.pi / 3, 0.0]))
assert np.allclose(tapes[2].get_parameters(), x + np.array([2 * np.pi / 3, 0.0]))
assert np.allclose(tapes[1].get_parameters(), x + np.array([2 * np.pi / 3, 0.0]))
assert np.allclose(tapes[2].get_parameters(), x + np.array([-2 * np.pi / 3, 0.0]))
assert np.allclose(tapes[-1].get_parameters(), x + np.array([0.0, -np.pi]))
expected_shifts = np.array([[1, 1], [1, -1], [-1, 1], [-1, -1]]) * (np.pi / 2)
for _tape, exp_shift in zip(tapes[3:-1], expected_shifts):
Expand Down Expand Up @@ -1662,7 +1662,7 @@ def circuit(x):
assert np.allclose(_tape.get_parameters(), x + exp_shift)

# Check that the vanilla diagonal rule is used for the second diagonal entry
shift_order = [-1, 1, -2]
shift_order = [1, -1, -2]
for mult, _tape in zip(shift_order, tapes[10:]):
assert np.allclose(_tape.get_parameters(), x + np.array([0.0, np.pi * mult]))

Expand Down
Loading