Skip to content

Commit

Permalink
Fix bug where split_non_commuting erases trainability of observables (
Browse files Browse the repository at this point in the history
#5838)

**Context:**
The post processing function of `split_non_commuting` incorrectly
converts the interface of results and erases the trainability of
`Hamiltonian` observables

**Description of the Change:**
Removes `convert_like` call.

**Related GitHub Issues:**
Fixes #5837
[sc-65649]
  • Loading branch information
astralcai authored Jun 11, 2024
1 parent d70f61a commit 4600c22
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 30 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@

* `qml.transforms.split_non_commuting` can now handle circuits containing measurements of multi-term observables.
[(#5729)](https://github.com/PennyLaneAI/pennylane/pull/5729)
[(#5853)](https://github.com/PennyLaneAI/pennylane/pull/5838)

* The qchem module has dedicated functions for calling `pyscf` and `openfermion` backends.
[(#5553)](https://github.com/PennyLaneAI/pennylane/pull/5553)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/transforms/split_non_commuting.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def _sum_terms(res: ResultBatch, coeffs: List[float], offset: float, shape: Tupl
if len(dot_products) == 0:
return qml.math.ones(shape) * offset
summed_dot_products = qml.math.sum(qml.math.stack(dot_products), axis=0)
return qml.math.convert_like(summed_dot_products + offset, res[0])
return summed_dot_products + offset


def _mp_to_obs(mp: MeasurementProcess, tape: qml.tape.QuantumScript) -> qml.operation.Operator:
Expand Down
120 changes: 91 additions & 29 deletions tests/transforms/test_split_non_commuting.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,11 @@ def test_tape_with_non_pauli_obs(self, non_pauli_obs):
"""Tests that the tape is split correctly when containing non-Pauli observables"""

obs_list = single_term_obs_list + non_pauli_obs

if not qml.operation.active_new_opmath():
non_pauli_obs = _convert_obs_to_legacy_opmath(non_pauli_obs)
obs_list = _convert_obs_to_legacy_opmath(obs_list)

measurements = [
qml.expval(c * o) for c, o in zip([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], obs_list)
]
Expand Down Expand Up @@ -956,9 +961,31 @@ def cost(theta, phi):
assert qml.math.allclose(grad1, expected_grad_1)
assert qml.math.allclose(grad2, expected_grad_2)

@pytest.mark.autograd
@pytest.mark.parametrize("grouping_strategy", [None, "default", "qwc", "wires"])
def test_trainable_hamiltonian_autograd(self, grouping_strategy):
"""Tests that measurements of trainable Hamiltonians are differentiable"""

import pennylane.numpy as pnp

dev = qml.device("default.qubit", wires=2, shots=50000)

@partial(split_non_commuting, grouping_strategy=grouping_strategy)
@qml.qnode(dev)
def circuit(coeff1, coeff2):
qml.RX(np.pi / 4, wires=0)
qml.RY(np.pi / 4, wires=1)
return qml.expval(qml.Hamiltonian([coeff1, coeff2], [qml.Y(0) @ qml.Z(1), qml.X(1)]))

params = pnp.array(pnp.pi / 4), pnp.array(3 * pnp.pi / 4)
actual = qml.jacobian(circuit)(*params)

assert qml.math.allclose(actual, [-0.5, np.cos(np.pi / 4)], rtol=0.05)

@pytest.mark.jax
@pytest.mark.parametrize("use_jit", [False, True])
@pytest.mark.parametrize("grouping_strategy", [None, "default", "qwc", "wires"])
def test_jax(self, grouping_strategy):
def test_jax(self, grouping_strategy, use_jit):
"""Tests that the output of ``split_non_commuting`` is differentiable with jax"""

import jax
Expand All @@ -979,6 +1006,9 @@ def circuit(theta, phi):
qml.RY(phi, wires=1)
return qml.probs(wires=[0, 1]), *[qml.expval(o) for o in obs_list]

if use_jit:
circuit = jax.jit(circuit)

def cost(theta, phi):
res = circuit(theta, phi)
return qml.math.concatenate([res[0], qml.math.stack(res[1:])], axis=0)
Expand All @@ -996,44 +1026,30 @@ def cost(theta, phi):
assert qml.math.allclose(grad2, expected_grad_2)

@pytest.mark.jax
@pytest.mark.parametrize("use_jit", [False, True])
@pytest.mark.parametrize("grouping_strategy", [None, "default", "qwc", "wires"])
def test_jax_jit(self, grouping_strategy):
"""Tests that the output of ``split_non_commuting`` is differentiable with jax and jit"""
def test_trainable_hamiltonian_jax(self, grouping_strategy, use_jit):
"""Tests that measurements of trainable Hamiltonians are differentiable with jax"""

import jax
import jax.numpy as jnp

dev = qml.device("default.qubit", wires=2)
dev = qml.device("default.qubit", wires=2, shots=50000)

obs_list = complex_obs_list
if not qml.operation.active_new_opmath():
obs_list = obs_list[:-1] # exclude the identity term

@jax.jit
@partial(split_non_commuting, grouping_strategy=grouping_strategy)
@qml.qnode(dev)
def circuit(theta, phi):
qml.RX(theta, wires=0)
qml.RY(phi, wires=0)
qml.RX(theta, wires=1)
qml.RY(phi, wires=1)
return qml.probs(wires=[0, 1]), *[qml.expval(o) for o in obs_list]
def circuit(coeff1, coeff2):
qml.RX(np.pi / 4, wires=0)
qml.RY(np.pi / 4, wires=1)
return qml.expval(qml.Hamiltonian([coeff1, coeff2], [qml.Y(0) @ qml.Z(1), qml.X(1)]))

def cost(theta, phi):
res = circuit(theta, phi)
return qml.math.concatenate([res[0], qml.math.stack(res[1:])], axis=0)
if use_jit:
circuit = jax.jit(circuit)

params = jnp.array(jnp.pi / 4), jnp.array(3 * jnp.pi / 4)
grad1, grad2 = jax.jacobian(cost, argnums=[0, 1])(*params)
params = jnp.array(np.pi / 4), jnp.array(3 * np.pi / 4)
actual = jax.jacobian(circuit, argnums=[0, 1])(*params)

expected_grad_1 = expected_grad_param_0
expected_grad_2 = expected_grad_param_1
if not qml.operation.active_new_opmath():
expected_grad_1 = expected_grad_param_0[:-1]
expected_grad_2 = expected_grad_param_1[:-1]

assert qml.math.allclose(grad1, expected_grad_1)
assert qml.math.allclose(grad2, expected_grad_2)
assert qml.math.allclose(actual, [-0.5, np.cos(np.pi / 4)], rtol=0.05)

@pytest.mark.torch
@pytest.mark.parametrize("grouping_strategy", [None, "default", "qwc", "wires"])
Expand Down Expand Up @@ -1074,10 +1090,32 @@ def cost(theta, phi):
assert qml.math.allclose(grad1, expected_grad_1, atol=1e-5)
assert qml.math.allclose(grad2, expected_grad_2, atol=1e-5)

@pytest.mark.torch
@pytest.mark.parametrize("grouping_strategy", [None, "default", "qwc", "wires"])
def test_trainable_hamiltonian_torch(self, grouping_strategy):
"""Tests that measurements of trainable Hamiltonians are differentiable with torch"""

import torch
from torch.autograd.functional import jacobian

dev = qml.device("default.qubit", wires=2, shots=50000)

@partial(split_non_commuting, grouping_strategy=grouping_strategy)
@qml.qnode(dev)
def circuit(coeff1, coeff2):
qml.RX(np.pi / 4, wires=0)
qml.RY(np.pi / 4, wires=1)
return qml.expval(qml.Hamiltonian([coeff1, coeff2], [qml.Y(0) @ qml.Z(1), qml.X(1)]))

params = torch.tensor(np.pi / 4), torch.tensor(3 * np.pi / 4)
actual = jacobian(circuit, params)

assert qml.math.allclose(actual, [-0.5, np.cos(np.pi / 4)], rtol=0.05)

@pytest.mark.tf
@pytest.mark.parametrize("grouping_strategy", [None, "default", "qwc", "wires"])
def test_tensorflow(self, grouping_strategy):
"""Tests that the output of ``split_non_commuting`` is differentiable with torch"""
"""Tests that the output of ``split_non_commuting`` is differentiable with tensorflow"""

import tensorflow as tf

Expand Down Expand Up @@ -1111,3 +1149,27 @@ def circuit(theta, phi):

assert qml.math.allclose(grad1, expected_grad_1, atol=1e-5)
assert qml.math.allclose(grad2, expected_grad_2, atol=1e-5)

@pytest.mark.tf
@pytest.mark.parametrize("grouping_strategy", [None, "default", "qwc", "wires"])
def test_trainable_hamiltonian_tensorflow(self, grouping_strategy):
"""Tests that measurements of trainable Hamiltonians are differentiable with tensorflow"""

import tensorflow as tf

dev = qml.device("default.qubit", wires=2, shots=50000)

@qml.qnode(dev)
def circuit(coeff1, coeff2):
qml.RX(np.pi / 4, wires=0)
qml.RY(np.pi / 4, wires=1)
return qml.expval(qml.Hamiltonian([coeff1, coeff2], [qml.Y(0) @ qml.Z(1), qml.X(1)]))

params = tf.Variable(np.pi / 4), tf.Variable(3 * np.pi / 4)

with tf.GradientTape() as tape:
cost = split_non_commuting(circuit, grouping_strategy=grouping_strategy)(*params)

actual = tape.jacobian(cost, params)

assert qml.math.allclose(actual, [-0.5, np.cos(np.pi / 4)], rtol=0.05)

0 comments on commit 4600c22

Please sign in to comment.