Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Michaud-Rioux <vincentm@nanoacademic.com>
  • Loading branch information
albi3ro and vincentmr authored May 10, 2024
1 parent d7534ab commit e620559
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions pennylane/ops/qubit/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,15 +505,15 @@ def compute_matrix(basis_state): # pylint: disable=arguments-differ
[0. 0. 0. 0.]
[0. 0. 0. 0.]]
"""
shape = (2 ** len(basis_state), 2 ** len(basis_state))
if qml.math.get_interface(basis_state) == "jax":
basis_state += 0 # convert to int
n = len(basis_state)
mask = qml.math.flip(2 ** qml.math.arange(n))
idx = qml.math.sum(mask * basis_state)
mat = qml.math.zeros((2**n, 2**n), like=basis_state)
idx = 0
for i, m in enumerate(basis_state):
idx = idx + (m << (len(basis_state) - i - 1))
mat = qml.math.zeros(shape, like=basis_state)
return mat.at[idx, idx].set(1.0)

m = np.zeros((2 ** len(basis_state), 2 ** len(basis_state)))
m = np.zeros(shape)
idx = int("".join(str(i) for i in basis_state), 2)
m[idx, idx] = 1
return m
Expand Down Expand Up @@ -545,11 +545,9 @@ def compute_eigvals(basis_state): # pylint: disable=arguments-differ
[0. 1. 0. 0.]
"""
if qml.math.get_interface(basis_state) == "jax":
basis_state += 0 # convert to int
mask = 2 ** np.arange(len(basis_state) - 1, -1, -1)
mask = qml.math.asarray(mask, like=basis_state)
mask = qml.math.cast_like(mask, basis_state)
idx = qml.math.sum(mask * basis_state)
idx = 0
for i, m in enumerate(basis_state):
idx = idx + (m << (len(basis_state) - i - 1))
eigvals = qml.math.zeros(2 ** len(basis_state), like=basis_state)
return eigvals.at[idx].set(1.0)
w = np.zeros(2 ** len(basis_state))
Expand Down

0 comments on commit e620559

Please sign in to comment.