diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py index 6526bc22771..1fd21dfaee4 100644 --- a/pennylane/devices/qubit/apply_operation.py +++ b/pennylane/devices/qubit/apply_operation.py @@ -403,10 +403,16 @@ def apply_pauliz(op: qml.Z, state, is_state_batched: bool = False, debugger=None def apply_phaseshift(op: qml.PhaseShift, state, is_state_batched: bool = False, debugger=None, **_): """Apply PhaseShift to state.""" - params = op.parameters[0] + axis = op.wires[0] + is_state_batched + n_dim = math.ndim(state) + + if n_dim >= 9 and math.get_interface(state) == "tensorflow": + return apply_operation_tensordot(op, state, is_state_batched=is_state_batched) + + params = qml.math.atleast_1d(op.parameters[0]) if is_state_batched or (op.batch_size is not None and len(params) > 1): slices = [] - for i, p in enumerate(op.parameters[0]): + for i, p in enumerate(params): slices.append( apply_phaseshift( qml.PhaseShift(p, wires=op.wires), @@ -416,17 +422,11 @@ def apply_phaseshift(op: qml.PhaseShift, state, is_state_batched: bool = False, ) return math.stack(slices, axis=0) - axis = op.wires[0] + is_state_batched - n_dim = math.ndim(state) - - if n_dim >= 9 and math.get_interface(state) == "tensorflow": - return apply_operation_tensordot(op, state, is_state_batched=is_state_batched) - sl_0 = _get_slice(0, axis, n_dim) sl_1 = _get_slice(1, axis, n_dim) # must be first state and then -1 because it breaks otherwise - state1 = math.multiply(state[sl_1], math.exp(1j * op.parameters[0])) + state1 = math.cast(state[sl_1], dtype=complex) * math.exp(1.0j * params) state = math.stack([state[sl_0], state1], axis=axis) if op.batch_size == 1: state = math.stack([state], axis=0)