Skip to content

Commit

Permalink
Add logic to verify that csc_dot_product is usable
Browse files Browse the repository at this point in the history
  • Loading branch information
astralcai committed Sep 17, 2024
1 parent 7e21a76 commit aebe7fc
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
22 changes: 20 additions & 2 deletions pennylane/devices/qubit/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from scipy.sparse import csr_matrix

import pennylane as qml
from pennylane import math
from pennylane.measurements import (
ExpectationMP,
Expand Down Expand Up @@ -195,13 +196,30 @@ def get_measurement_function(

backprop_mode = math.get_interface(state, *measurementprocess.obs.data) != "numpy"
if isinstance(measurementprocess.obs, (Hamiltonian, LinearCombination)):
# need to work out thresholds for when its faster to use "backprop mode" measurements
return sum_of_terms_method if backprop_mode else csr_dot_products

# need to work out thresholds for when it's faster to use "backprop mode"
if backprop_mode:
return sum_of_terms_method

if not all(obs.has_sparse_matrix for obs in measurementprocess.obs.terms()[1]):
return sum_of_terms_method

if isinstance(measurementprocess.obs, Hamiltonian) and any(
any(len(o.wires) > 1 for o in qml.operation.Tensor(op).obs)
for op in measurementprocess.obs.ops
):
return sum_of_terms_method

return csr_dot_products

if isinstance(measurementprocess.obs, Sum):
if backprop_mode:
# always use sum_of_terms_method for Sum observables in backprop mode
return sum_of_terms_method

if not all(obs.has_sparse_matrix() for _, obs in measurementprocess.obs.terms()):
return sum_of_terms_method

if (
measurementprocess.obs.has_overlapping_wires
and len(measurementprocess.obs.wires) > 7
Expand Down
12 changes: 12 additions & 0 deletions pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,18 @@ def compute_sparse_matrix(
"""
raise SparseMatrixUndefinedError

# pylint: disable=no-self-argument, comparison-with-callable
@classproperty
def has_sparse_matrix(cls) -> bool:
r"""Bool: Whether the Operator returns a defined sparse matrix.
Note: Child classes may have this as an instance property instead of as a class property.
"""
return (
cls.compute_sparse_matrix != Operator.compute_sparse_matrix
or cls.sparse_matrix != Operator.sparse_matrix
)

def sparse_matrix(self, wire_order: Optional[WiresLike] = None) -> csr_matrix:
r"""Representation of the operator as a sparse matrix in the computational basis.
Expand Down

0 comments on commit aebe7fc

Please sign in to comment.