Skip to content

Commit

Permalink
Speed up Pauli sentence sparse matrix calculation (#4411)
Browse files Browse the repository at this point in the history
* fix wires and formatting

* sum same sparse structure

* sum different structure Pauli words

* Consistent error for wire order

* fix wires test

* replace `wires.union(other_wires)` by `+`

* change `wires` kwarg in `Tensor.sparse_matrix()`

* Fix conversion test with new wires

* replace `wires` by `wire_order`

* Update changelog

* update technical details

* fix remove zeros and add some comments

* buffer kwarg and tests

* precision 1e-16

* Improve readability

Co-authored-by: Christina Lee <christina@xanadu.ai>

* Improve readability

Co-authored-by: Christina Lee <christina@xanadu.ai>

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
  • Loading branch information
BorjaRequena and albi3ro committed Aug 11, 2023
1 parent 16dfee8 commit f48eb5f
Show file tree
Hide file tree
Showing 6 changed files with 287 additions and 37 deletions.
3 changes: 2 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ array([False, False])
Instead, operators that need to be mutated are copied with new parameters.
[(#4220)](https://github.com/PennyLaneAI/pennylane/pull/4220)

* `PauliWord` sparse matrices are much faster, which directly improves `PauliSentence`.
* The calculation of `PauliWord` and `PauliSentence` sparse matrices are orders of magnitude faster.
[(#4272)](https://github.com/PennyLaneAI/pennylane/pull/4272)
[($4411)](https://github.com/PennyLaneAI/pennylane/pull/4411)

* Enable linting of all tests in CI and the pre-commit hook.
[(#4335)](https://github.com/PennyLaneAI/pennylane/pull/4335)
Expand Down
143 changes: 116 additions & 27 deletions pennylane/pauli/pauli_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from scipy import sparse

import pennylane as qml
from pennylane import math, wires
from pennylane import math
from pennylane.wires import Wires
from pennylane.operation import Tensor
from pennylane.ops import Hamiltonian, Identity, PauliX, PauliY, PauliZ, prod, s_prod

Expand Down Expand Up @@ -80,6 +81,20 @@ def _cached_arange(n):
return np.arange(n)


pauli_to_sparse_int = {I: 0, X: 1, Y: 1, Z: 0} # (I, Z) and (X, Y) have the same sparsity


def _ps_to_sparse_index(pauli_words, wires):
"""Represent the Pauli words sparse structure in a matrix of shape n_words x n_wires."""
indices = np.zeros((len(pauli_words), len(wires)))
for i, pw in enumerate(pauli_words):
if not pw.wires:
continue
wire_indices = np.array(wires.indices(pw.wires))
indices[i, wire_indices] = [pauli_to_sparse_int[pw[w]] for w in pw.wires]
return indices


_map_I = {
I: (1, I),
X: (1, X),
Expand Down Expand Up @@ -200,30 +215,36 @@ def __repr__(self):
@property
def wires(self):
"""Track wires in a PauliWord."""
return set(self)
return Wires(self)

def to_mat(self, wire_order, format="dense", coeff=1.0):
def to_mat(self, wire_order=None, format="dense", coeff=1.0):
"""Returns the matrix representation.
Keyword Args:
wire_order (iterable or None): The order of qubits in the tensor product.
format (str): The format of the matrix ("dense" by default), if not a dense
matrix, then the format for the sparse representation of the matrix.
coeff (float): Coefficient multiplying the resulting sparse matrix.
format (str): The format of the matrix. It is "dense" by default. Use "csr" for sparse.
coeff (float): Coefficient multiplying the resulting matrix.
Returns:
(Union[NumpyArray, ScipySparseArray]): Matrix representation of the Pauliword.
(Union[NumpyArray, ScipySparseArray]): Matrix representation of the Pauli word.
Raises:
ValueError: Can't get the matrix of an empty PauliWord.
"""
wire_order = self.wires if wire_order is None else Wires(wire_order)
if not wire_order.contains_wires(self.wires):
raise ValueError(
"Can't get the matrix for the specified wire order because it "
f"does not contain all the Pauli word's wires {self.wires}"
)

if len(self) == 0:
if wire_order is None or wire_order == wires.Wires([]):
if not wire_order:
raise ValueError("Can't get the matrix of an empty PauliWord.")
return (
np.diag([coeff] * 2 ** len(wire_order))
if format == "dense"
else coeff * sparse.eye(2 ** len(wire_order), format=format)
else coeff * sparse.eye(2 ** len(wire_order), format=format, dtype="complex128")
)

if format == "dense":
Expand All @@ -233,7 +254,7 @@ def to_mat(self, wire_order, format="dense", coeff=1.0):

def _to_sparse_mat(self, wire_order, coeff):
"""Compute the sparse matrix of the Pauli word times a coefficient, given a wire order.
See pauli_word_sparse_matrix.md for the technical details of the implementation."""
See pauli_sparse_matrices.md for the technical details of the implementation."""
full_word = [self[wire] for wire in wire_order]
matrix_size = 2 ** len(wire_order)
data = np.empty(matrix_size, dtype=np.complex128) # Non-zero values
Expand Down Expand Up @@ -261,14 +282,14 @@ def _to_sparse_mat(self, wire_order, coeff):
indices[current_size : 2 * current_size] = indices[:current_size] + current_size
current_size *= 2
# Avoid checks and copies in __init__ by directly setting the attributes of an empty matrix
matrix = sparse.csr_matrix((matrix_size, matrix_size), dtype=np.complex128)
matrix = sparse.csr_matrix((matrix_size, matrix_size), dtype="complex128")
matrix.data, matrix.indices, matrix.indptr = data, indices, indptr
return matrix

def operation(self, wire_order=None, get_as_tensor=False):
"""Returns a native PennyLane :class:`~pennylane.operation.Operation` representing the PauliWord."""
if len(self) == 0:
if wire_order in (None, [], wires.Wires([])):
if wire_order in (None, [], Wires([])):
raise ValueError("Can't get the operation for an empty PauliWord.")
return Identity(wires=wire_order)

Expand All @@ -281,7 +302,7 @@ def operation(self, wire_order=None, get_as_tensor=False):
def hamiltonian(self, wire_order=None):
"""Return :class:`~pennylane.Hamiltonian` representing the PauliWord."""
if len(self) == 0:
if wire_order in (None, [], wires.Wires([])):
if wire_order in (None, [], Wires([])):
raise ValueError("Can't get the Hamiltonian for an empty PauliWord.")
return Hamiltonian([1], [Identity(wires=wire_order)])

Expand Down Expand Up @@ -365,40 +386,48 @@ def __repr__(self):
@property
def wires(self):
"""Track wires of the PauliSentence."""
return set().union(*(pw.wires for pw in self.keys()))
return Wires(set().union(*(pw.wires for pw in self.keys())))

def to_mat(self, wire_order, format="dense"):
def to_mat(self, wire_order=None, format="dense", buffer_size=None):
"""Returns the matrix representation.
Keyword Args:
wire_order (iterable or None): The order of qubits in the tensor product.
format (str): The format of the matrix ("dense" by default), if not a dense
matrix, then the format for the sparse representation of the matrix.
format (str): The format of the matrix. It is "dense" by default. Use "csr" for sparse.
buffer_size (int or None): The maximum allowed memory in bytes to store intermediate results
in the calculation of sparse matrices. It defaults to ``2 ** 30`` bytes that make
1GB of memory. In general, larger buffers allow faster computations.
Returns:
(Union[NumpyArray, ScipySparseArray]): Matrix representation of the PauliSentence.
(Union[NumpyArray, ScipySparseArray]): Matrix representation of the Pauli sentence.
Rasies:
ValueError: Can't get the matrix of an empty PauliSentence.
"""
wire_order = self.wires if wire_order is None else Wires(wire_order)
if not wire_order.contains_wires(self.wires):
raise ValueError(
"Can't get the matrix for the specified wire order because it "
f"does not contain all the Pauli sentence's wires {self.wires}"
)

def _pw_wires(w: Iterable) -> wires.Wires:
def _pw_wires(w: Iterable) -> Wires:
"""Return the native Wires instance for a list of wire labels.
w represents the wires of the PauliWord being processed. In case
the PauliWord is empty ({}), choose any arbitrary wire from the
PauliSentence it is composed in.
"""
if w:
return wires.Wires(w)

return wires.Wires(list(self.wires)[0]) if len(self.wires) > 0 else wires.Wires([])
return w or Wires(self.wires[0]) if self.wires else self.wires

if len(self) == 0:
if wire_order is None or wire_order == wires.Wires([]):
if not wire_order:
raise ValueError("Can't get the matrix of an empty PauliSentence.")
if format == "dense":
return np.eye(2 ** len(wire_order))
return sparse.eye(2 ** len(wire_order), format=format)
return sparse.eye(2 ** len(wire_order), format=format, dtype="complex128")

if format != "dense":
return self._to_sparse_mat(wire_order, buffer_size=buffer_size)

mats_and_wires_gen = (
(
Expand All @@ -414,10 +443,70 @@ def _pw_wires(w: Iterable) -> wires.Wires:

return math.expand_matrix(reduced_mat, result_wire_order, wire_order=wire_order)

def _to_sparse_mat(self, wire_order, buffer_size=None):
"""Compute the sparse matrix of the Pauli sentence by efficiently adding the Pauli words
that conform it. See pauli_sparse_matrices.md for the technical details."""
pauli_words = list(self) # Ensure consistent ordering
n_wires = len(wire_order)
matrix_size = 2**n_wires
matrix = sparse.csr_matrix((matrix_size, matrix_size), dtype="complex128")
op_sparse_idx = _ps_to_sparse_index(pauli_words, wire_order)
_, unique_sparse_structures, unique_invs = np.unique(
op_sparse_idx, axis=0, return_index=True, return_inverse=True
)
pw_sparse_structures = unique_sparse_structures[unique_invs]

buffer_size = buffer_size or 2**30 # Default to 1GB of memory
# Convert bytes to number of matrices:
# complex128 (16) for each data entry and int64 (8) for each indices entry
buffer_size = max(1, buffer_size // ((16 + 8) * matrix_size))
mat_data = np.empty((matrix_size, buffer_size), dtype=np.complex128)
mat_indices = np.empty((matrix_size, buffer_size), dtype=np.int64)
n_matrices_in_buffer = 0
for sparse_structure in unique_sparse_structures:
indices, *_ = np.nonzero(pw_sparse_structures == sparse_structure)
mat = self._sum_same_structure_pws([pauli_words[i] for i in indices], wire_order)
mat_data[:, n_matrices_in_buffer] = mat.data
mat_indices[:, n_matrices_in_buffer] = mat.indices

n_matrices_in_buffer += 1
if n_matrices_in_buffer == buffer_size:
# Add partial results in batches to control the memory usage
matrix += self._sum_different_structure_pws(mat_indices, mat_data)
n_matrices_in_buffer = 0

matrix += self._sum_different_structure_pws(
mat_indices[:, :n_matrices_in_buffer], mat_data[:, :n_matrices_in_buffer]
)
return matrix

def _sum_same_structure_pws(self, pauli_words, wire_order):
"""Sums Pauli words with the same sparse structure."""
mat = pauli_words[0].to_mat(wire_order, coeff=self[pauli_words[0]], format="csr")
for word in pauli_words[1:]:
mat.data += word.to_mat(wire_order, coeff=self[word], format="csr").data
return mat

@staticmethod
def _sum_different_structure_pws(indices, data):
"""Sums Pauli words with different parse structures."""
size = indices.shape[0]
idx = np.argsort(indices, axis=1)
matrix = sparse.csr_matrix((size, size), dtype="complex128")
matrix.indices = np.take_along_axis(indices, idx, axis=1).ravel()
matrix.data = np.take_along_axis(data, idx, axis=1).ravel()
num_entries_per_row = indices.shape[1]
matrix.indptr = _cached_arange(size + 1) * num_entries_per_row

# remove zeros and things sufficiently close to zero
matrix.data[np.abs(matrix.data) < 1e-16] = 0 # Faster than np.isclose(matrix.data, 0)
matrix.eliminate_zeros()
return matrix

def operation(self, wire_order=None):
"""Returns a native PennyLane :class:`~pennylane.operation.Operation` representing the PauliSentence."""
if len(self) == 0:
if wire_order in (None, [], wires.Wires([])):
if wire_order in (None, [], Wires([])):
raise ValueError("Can't get the operation for an empty PauliSentence.")
return qml.s_prod(0, Identity(wires=wire_order))

Expand All @@ -431,7 +520,7 @@ def operation(self, wire_order=None):
def hamiltonian(self, wire_order=None):
"""Returns a native PennyLane :class:`~pennylane.Hamiltonian` representing the PauliSentence."""
if len(self) == 0:
if wire_order in (None, [], wires.Wires([])):
if wire_order in (None, [], Wires([])):
raise ValueError("Can't get the Hamiltonian for an empty PauliSentence.")
return Hamiltonian([], [])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,130 @@ Z &= \begin{cases}
\texttt{col}\leftarrow [\texttt{col}, \texttt{col}+m]
\end{cases}
\end{align*}$$

# Pauli sentence sparse representation

Pauli sentences are linear combinations of Pauli words.
Again, we can exploit the properties of Pauli words to add their sparse representations in the most efficient way, which is the key aspect to compute the sparse matrix of Pauli sentences.

In general, adding two arbitrary sparse matrices requires us to check all their NNZ entries.
If they happen to be in the same position, we add their values together, otherwise, we add them as a new entry to the resulting matrix.
Therefore, we can exploit the information about the matching NNZ entries between two Pauli words to directly manipulate their data.

## Add matrices with the same sparse structure first

As we have seen before, $I$ and $Z$ have the same sparse structure (same $\texttt{col}$ and $\texttt{row}$), and so do $X$ and $Y$.
Thus, words with the same combination of $I$ or $Z$ with $X$ or $Y$ have the NNZ entries in the same positions.
This means that all their NNZ will be in the same positions, which allows us to directly add their $\texttt{val}$ without any additional consideration!

For example, consider the case of the Pauli word $XZI$ from above, with $\texttt{val}=\left[1, 1, -1, -1, 1, 1, -1, -1\right]$ and $\texttt{col}=\left[4, 5, 6, 7, 0, 1, 2, 3\right]$.
The Pauli word $YZZ$ has the same sparse structure:

$$ YZZ = \left(\begin{array}{cccc|cccc}
& & & & -i & 0 & 0 & 0 \\
& & & & 0 & i & 0 & 0 \\
& & & & 0 & 0 & i & 0 \\
& & & & 0 & 0 & 0 & -i \\
\hline
i & 0 & 0 & 0 & & & & \\
0 & -i & 0 & 0 & & & & \\
0 & 0 & -i & 0 & & & & \\
0 & 0 & 0 & i & & & & \\
\end{array}\right),$$

with $\texttt{val}=\left[-i, i, i, -i, i, -i, -i, i\right]$ and $\texttt{col}=\left[4, 5, 6, 7, 0, 1, 2, 3\right]$.
The resulting matrix from $XZI + YZZ$ will have the same number of NNZ entries, preserving $\texttt{col}$ and $\texttt{row}$, and its $\texttt{val}$ will be the elementwise addition of the constituent $\texttt{val}$ arrays.

To identify Pauli words with the same sparse structure, we represnt the Pauli sentence as a matrix in which every row represents a Pauli word, and every column corresponds to a wire.
The entries denote the sparse structure of the Pauli operator acting on each wire: `0` for $I$ and $Z$, `1` for $X$ and $Y$.

For example:

$$XZI + YZZ + IIX + YYX \rightarrow
\begin{pmatrix}
1 & 0 & 0 \\
1 & 0 & 0 \\
0 & 0 & 1 \\
1 & 1 & 1
\end{pmatrix},
$$

which allows us to identify common patterns to add the matrices at nearly zero cost.

## Add matrices with different structure last

Given that Pauli words have a single NNZ per row and column, any pair of words that differ in (at least) one operator will not have any matching NNZ entry.

Hence, once we have added all the words with the same sparse structure, we can be certain that the intermediate results do not have any matching NNZ entries between themselves.
Again, we can exploit this information to add them by directly manipulating their representation data.

In this case, all the NNZ are new entries in the result matrix so we need to combine the $\texttt{col}$ and $\texttt{val}$ arrays of the summands.
In order to see this clearly, let us break down the addition of two small Pauli words with different structures, $XI$ and $ZZ$:

$$
\begin{align*}
XI &= \left(\begin{array}{cc|cc}
& & 1 & 0 \\
& & 0 & 1 \\
\hline
1 & 0 & & \\
0 & 1 & &
\end{array}\right)
&&= \begin{cases}
\texttt{val}=\left[1, 1, 1, 1\right]\\
\texttt{col}=\left[2, 3, 0, 1\right]
\end{cases} \\
\\
ZZ &= \left(\begin{array}{cc|cc}
1 & 0 & & \\
0 & -1 & & \\
\hline
& & -1 & 0 \\
& & 0 & 1
\end{array}\right)
&&= \begin{cases}
\texttt{val}=\left[1, -1, -1, 1\right]\\
\texttt{col}=\left[0, 1, 2, 3\right]
\end{cases}
\end{align*}
$$

The resulting matrix is

$$
XI + ZZ = \begin{pmatrix}
1 & 0 & 1 & 0 \\
0 & -1 & 0 & 1 \\
1 & 0 & -1 & 0 \\
0 & 1 & 0 & 1
\end{pmatrix}
=
\begin{cases}
\texttt{val}=\left[1, 1, -1, 1, 1, -1, 1, 1\right]\\
\texttt{col}=\left[0, 2, 1, 3, 0, 2, 1, 3\right]
\end{cases}
$$

The result is the entry-wise concatenation of $\texttt{col}$ and $\texttt{val}$ sorted by $\texttt{col}$.
To do so, we start by arranging the $\texttt{col}$ arrays in a matrix and sorting them row-wise.

$$
\texttt{col}_{XI, ZZ} =
\begin{bmatrix}
2 & 0 \\
3 & 1 \\
0 & 2 \\
1 & 3
\end{bmatrix}
\rightarrow
\begin{bmatrix}
0 & 2 \\
1 & 3 \\
0 & 2 \\
1 & 3
\end{bmatrix}
$$

Then, we concatenate the resulting matrix rows to obtain the final $\texttt{col}$.
We do the same with the $\texttt{val}$ arrays sorting them according to the $\texttt{col}$.
In `numpy` terms, we "take along axis" with the sorted $\texttt{col}$ indices.
Loading

0 comments on commit f48eb5f

Please sign in to comment.