Skip to content

Commit

Permalink
Deal with PhaseShift batched.
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentmr committed Jun 20, 2024
1 parent d4600b7 commit 2d9bfe4
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,19 @@ def apply_pauliz(op: qml.Z, state, is_state_batched: bool = False, debugger=None
def apply_phaseshift(op: qml.PhaseShift, state, is_state_batched: bool = False, debugger=None, **_):
"""Apply PhaseShift to state."""

params = op.parameters[0]
if is_state_batched or (op.batch_size is not None and len(params) > 1):
slices = []
for i, p in enumerate(op.parameters[0]):
slices.append(
apply_phaseshift(
qml.PhaseShift(p, wires=op.wires),
state[i] if is_state_batched else state,
is_state_batched=False,
)
)
return math.stack(slices, axis=0)

axis = op.wires[0] + is_state_batched
n_dim = math.ndim(state)

Expand All @@ -414,7 +427,10 @@ def apply_phaseshift(op: qml.PhaseShift, state, is_state_batched: bool = False,

# must be first state and then -1 because it breaks otherwise
state1 = math.multiply(state[sl_1], math.exp(1j * op.parameters[0]))
return math.stack([state[sl_0], state1], axis=axis)
state = math.stack([state[sl_0], state1], axis=axis)
if op.batch_size == 1:
state = math.stack([state], axis=0)
return state


@apply_operation.register
Expand Down

0 comments on commit 2d9bfe4

Please sign in to comment.