diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 56989c6ca54..9fe50728861 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -33,6 +33,9 @@
Improvements ðŸ›
+* Added the `compute_sparse_matrix` method for `qml.ops.qubit.BasisStateProjector`.
+ [(#5790)](https://github.com/PennyLaneAI/pennylane/pull/5790)
+
* `StateMP.process_state` defines rules in `cast_to_complex` for complex casting, avoiding a superfluous state vector copy in Lightning simulations
[(#5995)](https://github.com/PennyLaneAI/pennylane/pull/5995)
@@ -225,9 +228,9 @@ Josh Izaac,
Soran Jahangiri,
Christina Lee,
Austin Huang,
-Christina Lee,
William Maxwell,
Vincent Michaud-Rioux,
+Anurav Modak,
Mudit Pandey,
Erik Schultheis,
nate stemen.
diff --git a/pennylane/ops/qubit/observables.py b/pennylane/ops/qubit/observables.py
index b20e96dd478..9fd4ee43e65 100644
--- a/pennylane/ops/qubit/observables.py
+++ b/pennylane/ops/qubit/observables.py
@@ -584,6 +584,24 @@ def compute_diagonalizing_gates(
"""
return []
+ @staticmethod
+ def compute_sparse_matrix(basis_state): # pylint: disable=arguments-differ,unused-argument
+ """
+ Computes the sparse CSR matrix representation of the projector onto the basis state.
+
+ Args:
+ basis_state (Iterable): The basis state as an iterable of integers (0 or 1).
+
+ Returns:
+ scipy.sparse.csr_matrix: The sparse CSR matrix representation of the projector.
+ """
+
+ num_qubits = len(basis_state)
+ data = [1]
+ rows = [int("".join(str(bit) for bit in basis_state), 2)]
+ cols = rows
+ return csr_matrix((data, (rows, cols)), shape=(2**num_qubits, 2**num_qubits))
+
class StateVectorProjector(Projector):
r"""Observable corresponding to the state projector :math:`P=\ket{\phi}\bra{\phi}`, where
diff --git a/tests/ops/qubit/test_observables.py b/tests/ops/qubit/test_observables.py
index 3e291fe74c0..448080e8e25 100644
--- a/tests/ops/qubit/test_observables.py
+++ b/tests/ops/qubit/test_observables.py
@@ -19,6 +19,7 @@
import numpy as np
import pytest
from gate_data import H, I, X, Y, Z
+from scipy.sparse import csr_matrix
import pennylane as qml
from pennylane.ops.qubit.observables import BasisStateProjector, StateVectorProjector
@@ -540,6 +541,75 @@ def test_serialization(self):
qml.assert_equal(new_proj, proj)
assert new_proj.id == proj.id # Ensure they are identical
+ def test_single_qubit_basis_state_0(self):
+ """Tests the function with a single-qubit basis state |0>."""
+ basis_state = [0]
+ data = [1]
+ row_indices = [0]
+ col_indices = [0]
+ expected_matrix = csr_matrix((data, (row_indices, col_indices)), shape=(2, 2))
+
+ actual_matrix = BasisStateProjector.compute_sparse_matrix(basis_state)
+ actual_matrix = BasisStateProjector.compute_sparse_matrix(basis_state)
+
+ assert np.array_equal(expected_matrix.toarray(), actual_matrix.toarray())
+
+ def test_single_qubit_basis_state_1(self):
+ """Tests the function with a single-qubit basis state |1>."""
+ basis_state = [1]
+ data = [1]
+ row_indices = [1]
+ col_indices = [1]
+ expected_matrix = csr_matrix((data, (row_indices, col_indices)), shape=(2, 2))
+ actual_matrix = BasisStateProjector.compute_sparse_matrix(basis_state)
+ assert np.array_equal(expected_matrix.toarray(), actual_matrix.toarray())
+
+ def test_two_qubit_basis_state_10(self):
+ """Tests the function with a two-qubits basis state |10>."""
+ basis_state = [1, 0]
+ data = [1]
+ row_indices = [2]
+ col_indices = [2]
+ expected_matrix = csr_matrix((data, (row_indices, col_indices)), shape=(4, 4))
+ actual_matrix = BasisStateProjector.compute_sparse_matrix(basis_state)
+ assert np.array_equal(expected_matrix.toarray(), actual_matrix.toarray())
+
+ def test_two_qubit_basis_state_01(self):
+ """Tests the function with a two-qubits basis state |01>."""
+ basis_state = [0, 1]
+ data = [1]
+ row_indices = [1]
+ col_indices = [1]
+ expected_matrix = csr_matrix((data, (row_indices, col_indices)), shape=(4, 4))
+ actual_matrix = BasisStateProjector.compute_sparse_matrix(basis_state)
+ assert np.array_equal(expected_matrix.toarray(), actual_matrix.toarray())
+
+ def test_two_qubit_basis_state_11(self):
+ """Tests the function with a two-qubits basis state |11>."""
+ basis_state = [1, 1]
+ data = [1]
+ row_indices = [3]
+ col_indices = [3]
+ expected_matrix = csr_matrix((data, (row_indices, col_indices)), shape=(4, 4))
+ actual_matrix = BasisStateProjector.compute_sparse_matrix(basis_state)
+ assert np.array_equal(expected_matrix.toarray(), actual_matrix.toarray())
+
+ def test_three_qubit_basis_state_101(self):
+ """Tests the function with a three-qubits basis state |101>."""
+ basis_state = [1, 0, 1]
+ data = [1]
+ row_indices = [5]
+ col_indices = [5]
+ expected_matrix = csr_matrix((data, (row_indices, col_indices)), shape=(8, 8))
+ actual_matrix = BasisStateProjector.compute_sparse_matrix(basis_state)
+ assert np.array_equal(expected_matrix.toarray(), actual_matrix.toarray())
+
+ def test_invalid_basis_state(self):
+ """Tests the function with an invalid state."""
+ basis_state = [0, 2] # Invalid basis state
+ with pytest.raises(ValueError):
+ BasisStateProjector.compute_sparse_matrix(basis_state)
+
@pytest.mark.jax
def test_jit_measurement(self):
"""Test that the measurement of a projector can be jitted."""