diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py index 1fd21dfaee4..c64725f9e9b 100644 --- a/pennylane/devices/qubit/apply_operation.py +++ b/pennylane/devices/qubit/apply_operation.py @@ -403,14 +403,9 @@ def apply_pauliz(op: qml.Z, state, is_state_batched: bool = False, debugger=None def apply_phaseshift(op: qml.PhaseShift, state, is_state_batched: bool = False, debugger=None, **_): """Apply PhaseShift to state.""" - axis = op.wires[0] + is_state_batched - n_dim = math.ndim(state) - - if n_dim >= 9 and math.get_interface(state) == "tensorflow": - return apply_operation_tensordot(op, state, is_state_batched=is_state_batched) - - params = qml.math.atleast_1d(op.parameters[0]) + params = op.parameters[0] if is_state_batched or (op.batch_size is not None and len(params) > 1): + params = math.atleast_1d(params) slices = [] for i, p in enumerate(params): slices.append( @@ -422,11 +417,17 @@ def apply_phaseshift(op: qml.PhaseShift, state, is_state_batched: bool = False, ) return math.stack(slices, axis=0) + axis = op.wires[0] + is_state_batched + n_dim = math.ndim(state) + + if n_dim >= 9 and math.get_interface(state) == "tensorflow": + return apply_operation_tensordot(op, state, is_state_batched=is_state_batched) + sl_0 = _get_slice(0, axis, n_dim) sl_1 = _get_slice(1, axis, n_dim) # must be first state and then -1 because it breaks otherwise - state1 = math.cast(state[sl_1], dtype=complex) * math.exp(1.0j * params) + state1 = math.multiply(state[sl_1], math.exp(1j * params)) state = math.stack([state[sl_0], state1], axis=axis) if op.batch_size == 1: state = math.stack([state], axis=0) diff --git a/tests/devices/qubit/test_apply_operation.py b/tests/devices/qubit/test_apply_operation.py index 3e517879b78..5d6f1da11bf 100644 --- a/tests/devices/qubit/test_apply_operation.py +++ b/tests/devices/qubit/test_apply_operation.py @@ -802,7 +802,7 @@ def test_broadcasted_op(self, op, method, ml_framework): @pytest.mark.parametrize("op", unbroadcasted_ops) def test_broadcasted_state(self, op, method, ml_framework): """Tests that unbatched operations are applied correctly to a batched state.""" - state = np.ones((3, 2, 2, 2)) / np.sqrt(8) + state = np.ones((3, 2, 2, 2), dtype=complex) / np.sqrt(8) res = method(op, qml.math.asarray(state, like=ml_framework), is_state_batched=True) missing_wires = 3 - len(op.wires) @@ -819,7 +819,7 @@ def test_broadcasted_op_broadcasted_state(self, op, method, ml_framework): if method is apply_operation_tensordot: pytest.skip("Tensordot doesn't support batched operator and batched state.") - state = np.ones((3, 2, 2, 2)) / np.sqrt(8) + state = np.ones((3, 2, 2, 2), dtype=complex) / np.sqrt(8) res = method(op, qml.math.asarray(state, like=ml_framework), is_state_batched=True) missing_wires = 3 - len(op.wires)