Skip to content

Commit

Permalink
Faster, better and differentiable Pauli decompose (#4395)
Browse files Browse the repository at this point in the history
* faster pauli decompose

* add `non_square`

* change docs

* Update conversion.py

* merging functions attempt

* merge function attempt 2

* add tests for general matrices

* add changelog

* fix black

* make pylint happy?

* restore a typo

* use `qml.math` instead of `numpy`

* remove `numpy` import

* make `xor`ing differentiable

* run `black`

* fix casting

* separate `phase` and `walsh_hadamard` calculations

* remove comments

* make `torch` work

* make compatible with `tensorflow`

* fix padding for `torch`

* fix `pylint` for tests

* make `torch` work :)

* fix `black`

* add differentiability tests

* add new examples

* update docstrings

* update docstring

* update docstring

* remove newline

* use review comments

* update warning

* fix typo

* add `builtins` support

* update `requirements`
  • Loading branch information
obliviateandsurrender authored Aug 10, 2023
1 parent a025e29 commit b5789db
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 51 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ array([False, False])
* When given a callable, `qml.ctrl` now does its custom pre-processing on all queued operators from the callable.
[(#4370)](https://github.com/PennyLaneAI/pennylane/pull/4370)

* `qml.pauli_decompose` is now differentiable and works with any non-Hermitian and non-square matrices.
[(#4395)](https://github.com/PennyLaneAI/pennylane/pull/4395)

* `qml.interfaces.set_shots` accepts `Shots` object as well as `int`'s and tuples of `int`'s.
[(#4388)](https://github.com/PennyLaneAI/pennylane/pull/4388)

Expand Down Expand Up @@ -318,6 +321,7 @@ array([False, False])

This release contains contributions from (in alphabetical order):

Utkarsh Azad,
Isaac De Vlugt,
Stepan Fomichev,
Lillian M. A. Frederiksen,
Expand Down
4 changes: 2 additions & 2 deletions pennylane/ops/qubit/matrix_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ def _walsh_hadamard_transform(D, n=None):
new_shape = (orig_shape[0],) + (2,) * n
else:
new_shape = (2,) * n
D = D.reshape(new_shape)
D = qml.math.reshape(D, new_shape)
# Apply Hadamard transform to each axis, shifted by one for broadcasting
for i in range(broadcasted, n + broadcasted):
D = qml.math.tensordot(_walsh_hadamard_matrix, D, axes=[[1], [i]])
# The axes are in reverted order after all matrix multiplications, so we need to transpose;
# If D was broadcasted, this moves the broadcasting axis to first position as well.
# Finally, reshape to original shape
return qml.math.transpose(D).reshape(orig_shape)
return qml.math.reshape(qml.math.transpose(D), orig_shape)


class QubitUnitary(Operation):
Expand Down
164 changes: 127 additions & 37 deletions pennylane/pauli/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,52 @@
from operator import matmul
from typing import Union

import numpy as np

import pennylane as qml
from pennylane.operation import Tensor
from pennylane.ops import Hamiltonian, Identity, PauliX, PauliY, PauliZ, Prod, SProd, Sum
from pennylane.ops.qubit.matrix_ops import _walsh_hadamard_transform

from .pauli_arithmetic import I, PauliSentence, PauliWord, X, Y, Z, mat_map, op_map
from .pauli_arithmetic import I, PauliSentence, PauliWord, X, Y, Z, op_map
from .utils import is_pauli_word


# pylint: disable=too-many-branches
def pauli_decompose(
H, hide_identity=False, wire_order=None, pauli=False
matrix,
hide_identity=False,
wire_order=None,
pauli=False,
padding=False,
) -> Union[Hamiltonian, PauliSentence]:
r"""Decomposes a Hermitian matrix into a linear combination of Pauli operators.
r"""Decomposes a matrix into a linear combination of Pauli operators.
This method converts any matrix to a weighted sum of Pauli words acting on :math:`n` qubits
in time :math:`O(n 4^n)`. The input matrix is first padded with zeros if its dimensions are not
:math:`2^n\times 2^n` and written as a quantum state in the computational basis following the
`channel-state duality <https://en.wikipedia.org/wiki/Channel-state_duality>`_.
A Bell basis transformation is then performed using the
`Walsh-Hadamard transform <https://en.wikipedia.org/wiki/Hadamard_transform>`_, after which
coefficients for each of the :math:`4^n` Pauli words are computed while accounting for the
phase from each ``PauliY`` term occurring in the word.
Args:
H (array[complex]): a Hermitian matrix of dimension :math:`2^n\times 2^n`.
matrix (tensor_like[complex]): any matrix M, the keyword argument ``padding=True``
should be provided if the dimension of M is not :math:`2^n\times 2^n`.
hide_identity (bool): does not include the Identity observable within
the tensor products of the decomposition if ``True``.
wire_order (list[Union[int, str]]): the ordered list of wires with respect
to which the operator is represented as a matrix.
pauli (bool): return a PauliSentence instance if ``True``.
padding (bool): makes the function compatible with rectangular matrices and square matrices
that are not of shape :math:`2^n\times 2^n` by padding them with zeros if ``True``.
Returns:
Union[~.Hamiltonian, ~.PauliSentence]: the matrix decomposed as a linear combination
of Pauli operators, either as a :class:`~.Hamiltonian` or :class:`~.PauliSentence` instance.
**Example:**
We can use this function to compute the Pauli operator decomposition of an arbitrary Hermitian
matrix:
We can use this function to compute the Pauli operator decomposition of an arbitrary matrix:
>>> A = np.array(
... [[-2, -2+1j, -2, -2], [-2-1j, 0, 0, -1], [-2, 0, -2, -1], [-2, -1, -1, 0]])
Expand Down Expand Up @@ -94,52 +110,126 @@ def pauli_decompose(
+ 1.0 * Y(a) @ Y(b)
+ -0.5 * Z(a) @ X(b)
+ -0.5 * Z(a) @ Y(b)
"""
n = int(np.log2(len(H)))
N = 2**n
if wire_order is not None and len(wire_order) != n:
raise ValueError(
f"number of wires {len(wire_order)} is not compatible with number of qubits {n}"
)
.. details::
:title: Usage Details
:href: usage-decompose-operation
if wire_order is None:
wire_order = range(n)
For non-square matrices, we need to provide the ``padding=True`` keyword argument:
>>> A = np.array([[-2, -2 + 1j]])
>>> H = qml.pauli_decompose(A, padding=True)
>>> print(H)
((-1+0j)) [I0]
+ ((-1+0.5j)) [X0]
+ ((-1+0j)) [Z0]
+ ((-0.5-1j)) [Y0]
We can also use the method within a differentiable workflow and obtain gradients:
>>> A = qml.numpy.array([[-2, -2 + 1j]], requires_grad=True)
>>> dev = qml.device("default.qubit", wires=1)
>>> @qml.qnode(dev)
... def circuit(A):
... decomp = qml.pauli_decompose(A, padding=True)
... qml.RX(decomp.coeffs[2], 0)
... return qml.expval(qml.PauliZ(0))
>>> grad_numpy = qml.grad(circuit)(A)
tensor([[-2.+0.j, -2.+1.j]], requires_grad=True)
if H.shape != (N, N):
raise ValueError("The matrix should have shape (2**n, 2**n), for any qubit number n>=1")
"""
# Ensuring original matrix is not manipulated and we support builtin types.
matrix = qml.math.convert_like(matrix, matrix)

# Pad with zeros to make the matrix shape equal and a power of two.
if padding:
shape = qml.math.shape(matrix)
num_qubits = int(qml.math.ceil(qml.math.log2(qml.math.max(shape))))
if shape[0] != shape[1] or shape[0] != 2**num_qubits:
padd_diffs = qml.math.abs(qml.math.array(shape) - 2**num_qubits)
padding = (
((0, padd_diffs[0]), (0, padd_diffs[1]))
if qml.math.get_interface(matrix) != "torch"
else ((padd_diffs[0], 0), (padd_diffs[1], 0))
)
matrix = qml.math.pad(matrix, padding, mode="constant", constant_values=0)

if not np.allclose(H, H.conj().T):
raise ValueError("The matrix is not Hermitian")
shape = qml.math.shape(matrix)
if shape[0] != shape[1]:
raise ValueError(
f"The matrix should be square, got {shape}. Use 'padding=True' for rectangular matrices."
)

obs_lst = []
coeffs = []
num_qubits = int(qml.math.log2(shape[0]))
if shape[0] != 2**num_qubits:
raise ValueError(
f"Dimension of the matrix should be a power of 2, got {shape}. Use 'padding=True' for these matrices."
)

for term in product([I, X, Y, Z], repeat=n):
matrices = [mat_map[i] for i in term]
coeff = np.trace(reduce(np.kron, matrices) @ H) / N
coeff = np.real_if_close(coeff).item()
if wire_order is not None and len(wire_order) != num_qubits:
raise ValueError(
f"number of wires {len(wire_order)} is not compatible with the number of qubits {num_qubits}"
)

if not np.allclose(coeff, 0):
obs_term = (
[(o, w) for w, o in zip(wire_order, term) if o != I]
if hide_identity and not all(t == I for t in term)
else [(o, w) for w, o in zip(wire_order, term)]
if wire_order is None:
wire_order = range(num_qubits)

# Permute by XORing
indices = [qml.math.array(range(shape[0]))]
for idx in range(shape[0] - 1):
indices.append(qml.math.bitwise_xor(indices[-1], (idx + 1) ^ (idx)))
term_mat = qml.math.cast(
qml.math.stack(
[qml.math.gather(matrix[idx], indice) for idx, indice in enumerate(indices)]
),
complex,
)

# Perform Hadamard transformation on coloumns
hadamard_transform_mat = _walsh_hadamard_transform(qml.math.transpose(term_mat))

# Account for the phases from Y
phase_mat = qml.math.ones(shape, dtype=complex).reshape((2,) * (2 * num_qubits))
for idx in range(num_qubits):
index = [slice(None)] * (2 * num_qubits)
index[idx] = index[idx + num_qubits] = 1
phase_mat[tuple(index)] *= 1j
phase_mat = qml.math.convert_like(qml.math.reshape(phase_mat, shape), matrix)

# c_00 + c_11 -> I; c_00 - c_11 -> Z; c_01 + c_10 -> X; 1j*(c_10 - c_01) -> Y
# https://quantumcomputing.stackexchange.com/a/31790
term_mat = qml.math.transpose(qml.math.multiply(hadamard_transform_mat, phase_mat))

# Obtain the coefficients for each Pauli word
coeffs, obs = [], []
for pauli_rep in product("IXYZ", repeat=num_qubits):
bit_array = qml.math.array(
[[(rep in "YZ"), (rep in "XY")] for rep in pauli_rep], dtype=int
).T
coefficient = term_mat[tuple(int("".join(map(str, x)), 2) for x in bit_array)]

if not qml.math.allclose(coefficient, 0):
observables = (
[(o, w) for w, o in zip(wire_order, pauli_rep) if o != I]
if hide_identity and not all(t == I for t in pauli_rep)
else [(o, w) for w, o in zip(wire_order, pauli_rep)]
)
if observables:
coeffs.append(coefficient)
obs.append(observables)

if obs_term:
coeffs.append(coeff)
obs_lst.append(obs_term)
coeffs = qml.math.stack(coeffs)

# Convert to Hamiltonian and PauliSentence
if pauli:
return PauliSentence(
{
PauliWord({w: o for o, w in obs_n_wires}): coeff
for coeff, obs_n_wires in zip(coeffs, obs_lst)
for coeff, obs_n_wires in zip(coeffs, obs)
}
)

obs = [reduce(matmul, [op_map[o](w) for o, w in obs_term]) for obs_term in obs_lst]
obs = [reduce(matmul, [op_map[o](w) for o, w in obs_term]) for obs_term in obs]
return Hamiltonian(coeffs, obs)


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ toml~=0.10
appdirs~=1.4
semantic_version~=2.10
dask[delayed]~=2022.4.1
autoray==0.3.1
autoray>=0.3.1
matplotlib~=3.5
opt_einsum~=3.3
requests~=2.31.0
Expand Down
Loading

0 comments on commit b5789db

Please sign in to comment.