Skip to content

Commit

Permalink
Fix TF tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentmr committed Jun 20, 2024
1 parent ea4e256 commit ee5542b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
17 changes: 9 additions & 8 deletions pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,14 +403,9 @@ def apply_pauliz(op: qml.Z, state, is_state_batched: bool = False, debugger=None
def apply_phaseshift(op: qml.PhaseShift, state, is_state_batched: bool = False, debugger=None, **_):
"""Apply PhaseShift to state."""

axis = op.wires[0] + is_state_batched
n_dim = math.ndim(state)

if n_dim >= 9 and math.get_interface(state) == "tensorflow":
return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

params = qml.math.atleast_1d(op.parameters[0])
params = op.parameters[0]
if is_state_batched or (op.batch_size is not None and len(params) > 1):
params = math.atleast_1d(params)
slices = []
for i, p in enumerate(params):
slices.append(
Expand All @@ -422,11 +417,17 @@ def apply_phaseshift(op: qml.PhaseShift, state, is_state_batched: bool = False,
)
return math.stack(slices, axis=0)

axis = op.wires[0] + is_state_batched
n_dim = math.ndim(state)

if n_dim >= 9 and math.get_interface(state) == "tensorflow":
return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

sl_0 = _get_slice(0, axis, n_dim)
sl_1 = _get_slice(1, axis, n_dim)

# must be first state and then -1 because it breaks otherwise
state1 = math.cast(state[sl_1], dtype=complex) * math.exp(1.0j * params)
state1 = math.multiply(state[sl_1], math.exp(1j * params))
state = math.stack([state[sl_0], state1], axis=axis)
if op.batch_size == 1:
state = math.stack([state], axis=0)
Expand Down
4 changes: 2 additions & 2 deletions tests/devices/qubit/test_apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ def test_broadcasted_op(self, op, method, ml_framework):
@pytest.mark.parametrize("op", unbroadcasted_ops)
def test_broadcasted_state(self, op, method, ml_framework):
"""Tests that unbatched operations are applied correctly to a batched state."""
state = np.ones((3, 2, 2, 2)) / np.sqrt(8)
state = np.ones((3, 2, 2, 2), dtype=complex) / np.sqrt(8)

res = method(op, qml.math.asarray(state, like=ml_framework), is_state_batched=True)
missing_wires = 3 - len(op.wires)
Expand All @@ -819,7 +819,7 @@ def test_broadcasted_op_broadcasted_state(self, op, method, ml_framework):
if method is apply_operation_tensordot:
pytest.skip("Tensordot doesn't support batched operator and batched state.")

state = np.ones((3, 2, 2, 2)) / np.sqrt(8)
state = np.ones((3, 2, 2, 2), dtype=complex) / np.sqrt(8)

res = method(op, qml.math.asarray(state, like=ml_framework), is_state_batched=True)
missing_wires = 3 - len(op.wires)
Expand Down

0 comments on commit ee5542b

Please sign in to comment.