From aebe7fcc0e7a6a3d423faf03b19829bc81109d6a Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 17 Sep 2024 14:38:09 -0400 Subject: [PATCH] Add logic to verify that csc_dot_product is usable --- pennylane/devices/qubit/measure.py | 22 ++++++++++++++++++++-- pennylane/operation.py | 12 ++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/pennylane/devices/qubit/measure.py b/pennylane/devices/qubit/measure.py index ae2ddb07c02..9547945d36a 100644 --- a/pennylane/devices/qubit/measure.py +++ b/pennylane/devices/qubit/measure.py @@ -18,6 +18,7 @@ from scipy.sparse import csr_matrix +import pennylane as qml from pennylane import math from pennylane.measurements import ( ExpectationMP, @@ -195,13 +196,30 @@ def get_measurement_function( backprop_mode = math.get_interface(state, *measurementprocess.obs.data) != "numpy" if isinstance(measurementprocess.obs, (Hamiltonian, LinearCombination)): - # need to work out thresholds for when its faster to use "backprop mode" measurements - return sum_of_terms_method if backprop_mode else csr_dot_products + + # need to work out thresholds for when it's faster to use "backprop mode" + if backprop_mode: + return sum_of_terms_method + + if not all(obs.has_sparse_matrix for obs in measurementprocess.obs.terms()[1]): + return sum_of_terms_method + + if isinstance(measurementprocess.obs, Hamiltonian) and any( + any(len(o.wires) > 1 for o in qml.operation.Tensor(op).obs) + for op in measurementprocess.obs.ops + ): + return sum_of_terms_method + + return csr_dot_products if isinstance(measurementprocess.obs, Sum): if backprop_mode: # always use sum_of_terms_method for Sum observables in backprop mode return sum_of_terms_method + + if not all(obs.has_sparse_matrix() for _, obs in measurementprocess.obs.terms()): + return sum_of_terms_method + if ( measurementprocess.obs.has_overlapping_wires and len(measurementprocess.obs.wires) > 7 diff --git a/pennylane/operation.py b/pennylane/operation.py index 77f989e6383..46a8d4e5bb7 100644 --- a/pennylane/operation.py +++ b/pennylane/operation.py @@ -870,6 +870,18 @@ def compute_sparse_matrix( """ raise SparseMatrixUndefinedError + # pylint: disable=no-self-argument, comparison-with-callable + @classproperty + def has_sparse_matrix(cls) -> bool: + r"""Bool: Whether the Operator returns a defined sparse matrix. + + Note: Child classes may have this as an instance property instead of as a class property. + """ + return ( + cls.compute_sparse_matrix != Operator.compute_sparse_matrix + or cls.sparse_matrix != Operator.sparse_matrix + ) + def sparse_matrix(self, wire_order: Optional[WiresLike] = None) -> csr_matrix: r"""Representation of the operator as a sparse matrix in the computational basis.