Skip to content

Commit

Permalink
Remove op.is_hermitian check in expval, counts, sample to allow jit…
Browse files Browse the repository at this point in the history
… tracing (#5506)

Solves #5505 and also
fixes the same issue in catalyst

tldr: The `is_hermitian` check is breaking jit-compilation
  • Loading branch information
Qottmann committed Apr 15, 2024
1 parent 35de283 commit 4e785e5
Show file tree
Hide file tree
Showing 19 changed files with 34 additions and 176 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@
(which is not currently compatible with `KerasLayer`), linking to instructions to enable Keras 2.
[(#5488)](https://github.com/PennyLaneAI/pennylane/pull/5488)

* Removed the warning that an observable might not be hermitian in `qnode` executions. This enables jit-compilation.
[(#5506)](https://github.com/PennyLaneAI/pennylane/pull/5506)

<h3>Breaking changes 💔</h3>

* The private functions `_pauli_mult`, `_binary_matrix` and `_get_pauli_map` from the `pauli` module have been removed. The same functionality can be achieved using newer features in the ``pauli`` module.
Expand Down
4 changes: 0 additions & 4 deletions pennylane/measurements/counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""
This module contains the qml.counts measurement.
"""
import warnings
from typing import Sequence, Tuple, Optional
import numpy as np

Expand Down Expand Up @@ -146,9 +145,6 @@ def circuit():

return CountsMP(obs=op, all_outcomes=all_outcomes)

if op is not None and not op.is_hermitian: # None type is also allowed for op
warnings.warn(f"{op.name} might not be hermitian.")

if wires is not None:
if op is not None:
raise ValueError(
Expand Down
4 changes: 0 additions & 4 deletions pennylane/measurements/expval.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""
This module contains the qml.expval measurement.
"""
import warnings
from typing import Sequence, Tuple, Union

import pennylane as qml
Expand Down Expand Up @@ -71,9 +70,6 @@ def circuit(x):
"Expectation values of qml.Identity() without wires are currently not allowed."
)

if not op.is_hermitian:
warnings.warn(f"{op.name} might not be hermitian.")

return ExpectationMP(obs=op)


Expand Down
4 changes: 0 additions & 4 deletions pennylane/measurements/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
This module contains the qml.sample measurement.
"""
import functools
import warnings
from typing import Sequence, Tuple, Optional, Union

import numpy as np
Expand Down Expand Up @@ -173,9 +172,6 @@ def __init__(self, obs=None, wires=None, eigvals=None, id=None):
super().__init__(obs=obs)
return

if obs is not None and not obs.is_hermitian: # None type is also allowed for op
warnings.warn(f"{obs.name} might not be hermitian.")

if wires is not None:
if obs is not None:
raise ValueError(
Expand Down
4 changes: 0 additions & 4 deletions pennylane/measurements/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""
This module contains the qml.var measurement.
"""
import warnings
from typing import Sequence, Tuple, Union

import pennylane as qml
Expand Down Expand Up @@ -64,9 +63,6 @@ def circuit(x):
"qml.var does not support measuring sequences of measurements or observables"
)

if not op.is_hermitian:
warnings.warn(f"{op.name} might not be hermitian.")

return VarianceMP(obs=op)


Expand Down
28 changes: 28 additions & 0 deletions tests/devices/default_qubit/test_default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2052,6 +2052,34 @@ def circuit():

assert circuit() == expected

@pytest.mark.jax
@pytest.mark.parametrize("measurement_func", [qml.expval, qml.var])
def test_differentiate_jitted_qnode(self, measurement_func):
"""Test that a jitted qnode can be correctly differentiated"""
import jax

dev = DefaultQubit()

def qfunc(x, y):
qml.RX(x, 0)
return measurement_func(qml.Hamiltonian(y, [qml.Z(0)]))

qnode = qml.QNode(qfunc, dev, interface="jax")
qnode_jit = jax.jit(qml.QNode(qfunc, dev, interface="jax"))

x = jax.numpy.array(0.5)
y = jax.numpy.array([0.5])

res = qnode(x, y)
res_jit = qnode_jit(x, y)

assert qml.math.allclose(res, res_jit)

grad = jax.grad(qnode)(x, y)
grad_jit = jax.grad(qnode_jit)(x, y)

assert qml.math.allclose(grad, grad_jit)


@pytest.mark.parametrize("max_workers", max_workers_list)
def test_broadcasted_parameter(max_workers):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1520,6 +1520,7 @@ def test_hamiltonian_expansion_analytic(
spy = mocker.spy(qml.transforms, "hamiltonian_expand")
obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)]

@jax.jit
@qnode(
dev,
interface=interface,
Expand Down
18 changes: 0 additions & 18 deletions tests/measurements/legacy/test_expval_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,24 +80,6 @@ def circuit(x):

custom_measurement_process(new_dev, spy)

def test_not_an_observable(self, mocker):
"""Test that a warning is raised if the provided
argument might not be hermitian."""
dev = qml.device("default.qubit.legacy", wires=2)

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return qml.expval(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

new_dev = circuit.device
spy = mocker.spy(qml.QubitDevice, "expval")

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()

custom_measurement_process(new_dev, spy)

def test_observable_return_type_is_expectation(self, mocker):
"""Test that the return type of the observable is :attr:`ObservableReturnTypes.Expectation`"""
dev = qml.device("default.qubit.legacy", wires=2)
Expand Down
18 changes: 0 additions & 18 deletions tests/measurements/legacy/test_measurements_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
Shots,
StateMeasurement,
StateMP,
expval,
var,
)

from pennylane.wires import Wires
Expand Down Expand Up @@ -66,22 +64,6 @@ def test_shape_unrecognized_error():
mp.shape(dev, Shots(None))


@pytest.mark.parametrize("stat_func", [expval, var])
def test_not_an_observable(stat_func):
"""Test that a UserWarning is raised if the provided
argument might not be hermitian."""

dev = qml.device("default.qubit.legacy", wires=2)

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return stat_func(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()


class TestSampleMeasurement:
"""Tests for the SampleMeasurement class."""

Expand Down
16 changes: 0 additions & 16 deletions tests/measurements/legacy/test_sample_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,22 +251,6 @@ def circuit():

custom_measurement_process(dev, spy)

def test_not_an_observable(self, mocker):
"""Test that a UserWarning is raised if the provided
argument might not be hermitian."""
dev = qml.device("default.qubit.legacy", wires=2, shots=10)
spy = mocker.spy(qml.QubitDevice, "sample")

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return qml.sample(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()

custom_measurement_process(dev, spy)

def test_observable_return_type_is_sample(self, mocker):
"""Test that the return type of the observable is :attr:`ObservableReturnTypes.Sample`"""
n_shots = 10
Expand Down
16 changes: 0 additions & 16 deletions tests/measurements/legacy/test_var_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,6 @@ def circuit(x):

custom_measurement_process(dev, spy)

def test_not_an_observable(self, mocker):
"""Test that a UserWarning is raised if the provided
argument might not be hermitian."""
dev = qml.device("default.qubit.legacy", wires=2)
spy = mocker.spy(qml.QubitDevice, "var")

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return qml.var(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()

custom_measurement_process(dev, spy)

def test_observable_return_type_is_variance(self, mocker):
"""Test that the return type of the observable is :attr:`ObservableReturnTypes.Variance`"""
dev = qml.device("default.qubit.legacy", wires=2)
Expand Down
7 changes: 0 additions & 7 deletions tests/measurements/test_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,6 @@ def test_providing_observable_and_wires(self):
):
qml.counts(qml.PauliZ(0), wires=[0, 1])

def test_observable_might_not_be_hermitian(self):
"""Test that a UserWarning is raised if the provided
argument might not be hermitian."""

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
qml.counts(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

def test_hash(self):
"""Test that the hash property includes the all_outcomes property."""
m1 = qml.counts(all_outcomes=True)
Expand Down
13 changes: 0 additions & 13 deletions tests/measurements/test_expval.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,6 @@ def circuit(x):
else:
assert res.dtype == r_dtype

def test_not_an_observable(self):
"""Test that a warning is raised if the provided
argument might not be hermitian."""
dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return qml.expval(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()

def test_observable_return_type_is_expectation(self):
"""Test that the return type of the observable is :attr:`ObservableReturnTypes.Expectation`"""
dev = qml.device("default.qubit", wires=2)
Expand Down
16 changes: 0 additions & 16 deletions tests/measurements/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,22 +293,6 @@ def test_queueing_tensor_observable(self, op1, op2, stat_func, return_type):
assert isinstance(meas_proc, MeasurementProcess)
assert meas_proc.return_type == return_type

def test_not_an_observable(self, stat_func, return_type): # pylint: disable=unused-argument
"""Test that a UserWarning is raised if the provided
argument might not be hermitian."""
if stat_func is sample:
pytest.skip("Sampling is not yet supported with symbolic operators.")

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return stat_func(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()


class TestProperties:
"""Test for the properties"""
Expand Down
13 changes: 0 additions & 13 deletions tests/measurements/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,6 @@ def circuit():
assert result[2].dtype == np.dtype("int")
assert np.array_equal(result[2].shape, (n_sample,))

def test_not_an_observable(self):
"""Test that a UserWarning is raised if the provided
argument might not be hermitian."""
dev = qml.device("default.qubit", wires=2, shots=10)

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return qml.sample(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()

def test_observable_return_type_is_sample(self):
"""Test that the return type of the observable is :attr:`ObservableReturnTypes.Sample`"""
n_shots = 10
Expand Down
13 changes: 0 additions & 13 deletions tests/measurements/test_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,6 @@ def circuit(x):
else:
assert res.dtype == r_dtype

def test_not_an_observable(self):
"""Test that a UserWarning is raised if the provided
argument might not be hermitian."""
dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit():
qml.RX(0.52, wires=0)
return qml.var(qml.prod(qml.PauliX(0), qml.PauliZ(0)))

with pytest.warns(UserWarning, match="Prod might not be hermitian."):
_ = circuit()

def test_observable_return_type_is_variance(self):
"""Test that the return type of the observable is :attr:`ObservableReturnTypes.Variance`"""
dev = qml.device("default.qubit", wires=2)
Expand Down
4 changes: 2 additions & 2 deletions tests/ops/op_math/test_prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,8 +1485,8 @@ def circuit(weights):
true_grad = -qnp.sqrt(2) * qnp.cos(weights[0] / 2) * qnp.sin(weights[0] / 2)
assert qnp.allclose(grad, true_grad)

def test_non_hermitian_obs_not_supported(self):
"""Test that non-hermitian ops in a measurement process will raise a warning."""
def test_non_supported_obs_not_supported(self):
"""Test that non-supported ops in a measurement process will raise an error."""
wires = [0, 1]
dev = qml.device("default.qubit", wires=wires)
prod_op = Prod(qml.RX(1.23, wires=0), qml.Identity(wires=1))
Expand Down
14 changes: 0 additions & 14 deletions tests/ops/op_math/test_sprod.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,20 +1084,6 @@ def circuit(weights):
true_grad = 100 * -qnp.sqrt(2) * qnp.cos(weights[0] / 2) * qnp.sin(weights[0] / 2)
assert qnp.allclose(grad, true_grad)

def test_non_hermitian_obs_not_supported(self):
"""Test that non-hermitian ops in a measurement process will raise a warning."""
wires = [0, 1]
dev = qml.device("default.qubit", wires=wires)
sprod_op = SProd(1.0 + 2.0j, qml.RX(1.23, wires=0))

@qml.qnode(dev)
def my_circ():
qml.PauliX(0)
return qml.expval(sprod_op)

with pytest.raises(NotImplementedError):
my_circ()

@pytest.mark.torch
@pytest.mark.parametrize("diff_method", ("parameter-shift", "backprop"))
def test_torch(self, diff_method):
Expand Down
14 changes: 0 additions & 14 deletions tests/ops/op_math/test_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,20 +1267,6 @@ def circuit(weights):
true_grad = qnp.array([-0.09347337, -0.18884787, -0.28818254])
assert qnp.allclose(grad, true_grad)

def test_non_hermitian_op_in_measurement_process(self):
"""Test that non-hermitian ops in a measurement process will raise a warning."""
wires = [0, 1]
dev = qml.device("default.qubit", wires=wires)
sum_op = Sum(Prod(qml.RX(1.23, wires=0), qml.Identity(wires=1)), qml.Identity(wires=1))

@qml.qnode(dev, interface=None)
def my_circ():
qml.PauliX(0)
return qml.expval(sum_op)

with pytest.warns(UserWarning, match="Sum might not be hermitian."):
my_circ()

def test_params_can_be_considered_trainable(self):
"""Tests that the parameters of a Sum are considered trainable."""
dev = qml.device("default.qubit", wires=2)
Expand Down

0 comments on commit 4e785e5

Please sign in to comment.