Skip to content

Commit

Permalink
fix Tensor with jitted projector (#5720)
Browse files Browse the repository at this point in the history
**Context:**

We are getting legacy op math failures:


https://github.com/PennyLaneAI/pennylane/actions/runs/9167752473/job/25205404890

Due to PR #5595 . When `Projector` started letting tracers through,
`Tensor` was unable to handle them.

**Description of the Change:**

Updates `Tensor.matrix` to handle matrices that have an ML interface.

**Benefits:**

No more bugs.

**Possible Drawbacks:**

**Related GitHub Issues:**
  • Loading branch information
albi3ro authored May 22, 2024
1 parent 3c08049 commit 48f21d8
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@

<h3>Bug fixes 🐛</h3>

* The legacy `Tensor` class can now handle a `Projector` with abstract tracer input.
[(#5720)](https://github.com/PennyLaneAI/pennylane/pull/5720)

* Fixed a bug that raised an error regarding expected vs actual `dtype` when using `JAX-JIT` on a circuit that
returned samples of observables containing the `qml.Identity` operator.
[(#5607)](https://github.com/PennyLaneAI/pennylane/pull/5607)
Expand Down
4 changes: 2 additions & 2 deletions pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2409,7 +2409,7 @@ def matrix(self, wire_order=None):
# append diagonalizing unitary for specific wire to U_list
U_list.append(mats[0])

mat_size = np.prod([np.shape(mat)[0] for mat in U_list])
mat_size = np.prod([qml.math.shape(mat)[0] for mat in U_list])
wire_size = 2 ** len(self.wires)
if mat_size != wire_size:
if partial_overlap:
Expand All @@ -2428,7 +2428,7 @@ def matrix(self, wire_order=None):

# Return the Hermitian matrix representing the observable
# over the defined wires.
return functools.reduce(np.kron, U_list)
return functools.reduce(qml.math.kron, U_list)

def check_wires_partial_overlap(self):
r"""Tests whether any two observables in the Tensor have partially
Expand Down
14 changes: 14 additions & 0 deletions tests/test_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2061,6 +2061,20 @@ def test_matmul_not_implemented(self):
with pytest.raises(TypeError, match="unsupported operand type"):
_ = op @ 1.0

@pytest.mark.jax
def test_matrix_jax_projector(self):
"""Test that matrix can be computed with a jax projector."""

import jax

def f(state):
op = qml.Projector(state, wires=0)
return qml.operation.Tensor(op, qml.Z(1)).matrix()

res = jax.jit(f)(jax.numpy.array([0, 1]))
expected = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, -1]])
assert qml.math.allclose(res, expected)


with qml.operation.disable_new_opmath_cm():
equal_obs = [
Expand Down

0 comments on commit 48f21d8

Please sign in to comment.