Skip to content

Commit

Permalink
Make Projector jit compatible (#5595)
Browse files Browse the repository at this point in the history
**Description of the Change:**

Add some jit-friendly logic to `BasisStateProjector` and
`StateVectorProjector`.

**Benefits:**

Projector is now jit compatible.

We might finally be able to move forward with #4969 if we ever actually
had time to do that.

**Possible Drawbacks:**

Not testing trainability, as it doesn't really make sense.

Could potentially be made more efficient.

**Related GitHub Issues:**

Fixes #4977 [sc-52777]

---------

Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Co-authored-by: Astral Cai <astral.cai@xanadu.ai>
Co-authored-by: Vincent Michaud-Rioux <vincentm@nanoacademic.com>
Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
  • Loading branch information
5 people committed May 10, 2024
1 parent 9c9b6ba commit 762b337
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 18 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@

<h3>Bug fixes 🐛</h3>

* `qml.Projector` is now compatible with jax-jit.
[(#5595)](https://github.com/PennyLaneAI/pennylane/pull/5595)

* Finite shot circuits with a `qml.probs` measurement, both with a `wires` or `op` argument, can now be compiled with `jax.jit`.
[(#5619)](https://github.com/PennyLaneAI/pennylane/pull/5619)

Expand Down
1 change: 1 addition & 0 deletions pennylane/math/single_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def _take_autograd(tensor, indices, axis=None):
ar.autoray._SUBMODULE_ALIASES["tensorflow", "isclose"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "atleast_1d"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "all"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "vstack"] = "tensorflow.experimental.numpy"

tf_fft_functions = [
"fft",
Expand Down
5 changes: 3 additions & 2 deletions pennylane/ops/op_math/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,9 @@ def eigvals(self):
list(self.wires),
)
)

return self._math_op(math.asarray(eigvals, like=math.get_deep_interface(eigvals)), axis=0)
framework = math.get_deep_interface(eigvals)
eigvals = [math.asarray(ei, like=framework) for ei in eigvals]
return self._math_op(math.vstack(eigvals), axis=0)

@abc.abstractmethod
def matrix(self, wire_order=None):
Expand Down
64 changes: 53 additions & 11 deletions pennylane/ops/qubit/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,14 +435,23 @@ class BasisStateProjector(Projector, Operation):
r"""Observable corresponding to the state projector :math:`P=\ket{\phi}\bra{\phi}`, where
:math:`\phi` denotes a basis state."""

grad_method = None

# The call signature should be the same as Projector.__new__ for the positional
# arguments, but with free key word arguments.
def __init__(self, state, wires, id=None):
wires = Wires(wires)
state = list(qml.math.toarray(state).astype(int))

if not set(state).issubset({0, 1}):
raise ValueError(f"Basis state must only consist of 0s and 1s; got {state}")
if qml.math.get_interface(state) == "jax":
dtype = qml.math.dtype(state)
if not (np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, bool)):
raise ValueError("Basis state must consist of integers or booleans.")
else:
# state is index into data, rather than data, so cast it to built-ins when
# no need for tracing
state = tuple(qml.math.toarray(state).astype(int))
if not set(state).issubset({0, 1}):
raise ValueError(f"Basis state must only consist of 0s and 1s; got {state}")

super().__init__(state, wires=wires, id=id)

Expand Down Expand Up @@ -471,7 +480,7 @@ def label(self, decimals=None, base_label=None, cache=None):

if base_label is not None:
return base_label
basis_string = "".join(str(int(i)) for i in self.parameters[0])
basis_string = "".join(str(int(i)) for i in self.data[0])
return f"|{basis_string}⟩⟨{basis_string}|"

@staticmethod
Expand All @@ -497,7 +506,15 @@ def compute_matrix(basis_state): # pylint: disable=arguments-differ
[0. 0. 0. 0.]
[0. 0. 0. 0.]]
"""
m = np.zeros((2 ** len(basis_state), 2 ** len(basis_state)))
shape = (2 ** len(basis_state), 2 ** len(basis_state))
if qml.math.get_interface(basis_state) == "jax":
idx = 0
for i, m in enumerate(basis_state):
idx = idx + (m << (len(basis_state) - i - 1))
mat = qml.math.zeros(shape, like=basis_state)
return mat.at[idx, idx].set(1.0)

m = np.zeros(shape)
idx = int("".join(str(i) for i in basis_state), 2)
m[idx, idx] = 1
return m
Expand Down Expand Up @@ -528,6 +545,12 @@ def compute_eigvals(basis_state): # pylint: disable=arguments-differ
>>> BasisStateProjector.compute_eigvals([0, 1])
[0. 1. 0. 0.]
"""
if qml.math.get_interface(basis_state) == "jax":
idx = 0
for i, m in enumerate(basis_state):
idx = idx + (m << (len(basis_state) - i - 1))
eigvals = qml.math.zeros(2 ** len(basis_state), like=basis_state)
return eigvals.at[idx].set(1.0)
w = np.zeros(2 ** len(basis_state))
idx = int("".join(str(i) for i in basis_state), 2)
w[idx] = 1
Expand Down Expand Up @@ -566,12 +589,12 @@ class StateVectorProjector(Projector):
r"""Observable corresponding to the state projector :math:`P=\ket{\phi}\bra{\phi}`, where
:math:`\phi` denotes a state."""

grad_method = None

# The call signature should be the same as Projector.__new__ for the positional
# arguments, but with free key word arguments.
def __init__(self, state, wires, id=None):
wires = Wires(wires)
state = list(qml.math.toarray(state))

super().__init__(state, wires=wires, id=id)

def __new__(cls, *_, **__): # pylint: disable=arguments-differ
Expand Down Expand Up @@ -680,9 +703,11 @@ def compute_eigvals(state_vector): # pylint: disable=arguments-differ,arguments
>>> StateVectorProjector.compute_eigvals([0, 0, 1, 0])
array([1, 0, 0, 0])
"""
w = qml.math.zeros_like(state_vector)
precision = qml.math.get_dtype_name(state_vector)[-2:]
dtype = f"float{precision}" if precision in {"32", "64"} else "float64"
w = np.zeros(qml.math.shape(state_vector), dtype=dtype)
w[0] = 1
return w
return qml.math.convert_like(w, state_vector)

@staticmethod
def compute_diagonalizing_gates(
Expand Down Expand Up @@ -715,10 +740,27 @@ def compute_diagonalizing_gates(
# Adapting the approach discussed in the link below to work with arbitrary complex-valued state vectors.
# Alternatively, we could take the adjoint of the Mottonen decomposition for the state vector.
# https://quantumcomputing.stackexchange.com/questions/10239/how-can-i-fill-a-unitary-knowing-only-its-first-column
phase = qml.math.exp(-1j * qml.math.angle(state_vector[0]))

if qml.math.get_interface(state_vector) == "tensorflow":
dtype_name = qml.math.get_dtype_name(state_vector)
if dtype_name == "int32":
state_vector = qml.math.cast(state_vector, np.complex64)
elif dtype_name == "int64":
state_vector = qml.math.cast(state_vector, np.complex128)

angle = qml.math.angle(state_vector[0])
if qml.math.get_interface(angle) == "tensorflow":
if qml.math.get_dtype_name(angle) == "float32":
angle = qml.math.cast(angle, np.complex64)
else:
angle = qml.math.cast(angle, np.complex128)

phase = qml.math.exp(-1.0j * angle)
psi = phase * state_vector
denominator = qml.math.sqrt(2 + 2 * psi[0])
psi = qml.math.set_index(psi, 0, psi[0] + 1) # psi[0] += 1, but JAX-JIT compatible
summed_array = np.zeros(qml.math.shape(psi), dtype=qml.math.get_dtype_name(psi))
summed_array[0] = 1.0
psi = psi + summed_array
psi /= denominator
u = 2 * qml.math.outer(psi, qml.math.conj(psi)) - qml.math.eye(len(psi))
return [QubitUnitary(u, wires=wires)]
2 changes: 1 addition & 1 deletion pennylane/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,4 @@ def expand_vector(vector, original_wires, expanded_wires):
expanded_tensor, tuple(original_indices), tuple(wire_indices)
)

return qml.math.reshape(expanded_tensor, qudit_order**M)
return qml.math.reshape(expanded_tensor, (qudit_order**M,))
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Integration tests for using the TensorFlow interface with a QNode"""
import numpy as np

# pylint: disable=too-many-arguments,too-few-public-methods,comparison-with-callable
# pylint: disable=too-many-arguments,too-few-public-methods,comparison-with-callable, use-implicit-booleaness-not-comparison
import pytest

import pennylane as qml
Expand Down Expand Up @@ -955,8 +955,9 @@ def cost_fn(x, y):
assert np.allclose(grad, expected, atol=tol, rtol=0)

@pytest.mark.parametrize("state", [[1], [0, 1]]) # Basis state and state vector
@pytest.mark.parametrize("dtype", ("int32", "int64"))
def test_projector(
self, state, dev, diff_method, grad_on_execution, device_vjp, tol, interface
self, state, dev, diff_method, grad_on_execution, device_vjp, tol, interface, dtype
):
"""Test that the variance of a projector is correctly returned"""
kwargs = {
Expand All @@ -974,7 +975,7 @@ def test_projector(
kwargs["num_directions"] = 20
tol = TOL_FOR_SPSA

P = tf.constant(state)
P = tf.constant(state, dtype=dtype)

x, y = 0.765, -0.654
weights = tf.Variable([x, y], dtype=tf.float64)
Expand Down
58 changes: 57 additions & 1 deletion tests/ops/qubit/test_observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for qubit observables."""
# pylint: disable=protected-access
# pylint: disable=protected-access, use-implicit-booleaness-not-comparison
import functools
import pickle

Expand Down Expand Up @@ -540,6 +540,44 @@ def test_serialization(self):
assert qml.equal(new_proj, proj)
assert new_proj.id == proj.id # Ensure they are identical

@pytest.mark.jax
def test_jit_measurement(self):
"""Test that the measurement of a projector can be jitted."""

import jax

@jax.jit
@qml.qnode(qml.device("default.qubit"))
def circuit(state):
qml.X(1)
return qml.expval(qml.Projector(state, wires=(0, 1)))

state00 = jax.numpy.array([0, 0])
out00 = circuit(state00)
assert qml.math.allclose(out00, 0)
state01 = jax.numpy.array([0, 1])
out01 = circuit(state01)
assert qml.math.allclose(out01, 1)
state10 = jax.numpy.array([True, False])
out10 = circuit(state10)
assert qml.math.allclose(out10, 0)

with pytest.raises(ValueError, match=r"Basis state must consist of integers or booleans."):
circuit(jax.numpy.array([0.5, 0.6]))

@pytest.mark.jax
def test_jit_matrix(self):
"""Test that computing the matrix of a projector is jittable."""

import jax

basis_state = jax.numpy.array([0, 1])
f = jax.jit(BasisStateProjector.compute_matrix)
out = f(basis_state)

expected = np.array([[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
assert qml.math.allclose(out, expected)


class TestBasisStateProjector:
"""Tests for the basis state projector observable."""
Expand Down Expand Up @@ -716,6 +754,24 @@ def test_label_matrices_list_in_cache(self, projector):
assert len(cache["matrices"]) == 2
assert np.allclose(cache["matrices"][1], projector.parameters[0])

@pytest.mark.jax
def test_jit_execution(self):
"""Test that executing a StateVectorProjector can be jitted."""

import jax

@jax.jit
@qml.qnode(qml.device("default.qubit"))
def circuit(state):
return qml.expval(qml.Projector(state, wires=(0, 1)))

basis_state = jax.numpy.array([1, 1, 1, 1.0]) / 2
out = circuit(basis_state)
assert qml.math.allclose(out, 0.25)

basis_state2 = jax.numpy.array([0, 0, 0, 0])
assert qml.math.allclose(circuit(basis_state2), 0)


label_data = [
(qml.Hermitian(np.eye(2), wires=1), "𝓗"),
Expand Down

0 comments on commit 762b337

Please sign in to comment.