Skip to content

Commit

Permalink
Support parameter broadcasting with GlobalPhase (#5923)
Browse files Browse the repository at this point in the history
**Context:**
`GlobalPhase` does not support broadcasting yet.

**Description of the Change:**
Introduce broadcasting to `GlobalPhase`, using the code from `PauliRot`
with `set(pauli_word) == {"I"}`.
The latter falls back to the `GlobalPhase` implementation.

Also fixes a small bug where the global phase returned by
`one_qubit_decomposition` gains a broadcasting dimension even if the
input matrix does not have one.

**Benefits:**
Broadcasting support & Bug fix
Unlocks #4460

**Possible Drawbacks:**

**Related GitHub Issues:**
Implements #5815 
Fixes #5880

[sc-65316]
  • Loading branch information
dwierichs committed Jul 30, 2024
1 parent 6e122ae commit 0515647
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 56 deletions.
10 changes: 9 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@

<h3>Improvements 🛠</h3>

* `GlobalPhase` now supports parameter broadcasting.
[(#5923)](https://github.com/PennyLaneAI/pennylane/pull/5923)

* `qml.devices.LegacyDeviceFacade` has been added to map the legacy devices to the new
device interface.
[(#5927)](https://github.com/PennyLaneAI/pennylane/pull/5927)
Expand Down Expand Up @@ -223,9 +226,13 @@
[(#5974)](https://github.com/PennyLaneAI/pennylane/pull/5974)

<h3>Bug fixes 🐛</h3>

* Fix `jax.grad` + `jax.jit` not working for `AmplitudeEmbedding`, `StatePrep` and `MottonenStatePreparation`.
[(#5620)](https://github.com/PennyLaneAI/pennylane/pull/5620)

* Fix a bug where the global phase returned by `one_qubit_decomposition` gained a broadcasting dimension.
[(#5923)](https://github.com/PennyLaneAI/pennylane/pull/5923)

* Fixed a bug in `qml.SPSAOptimizer` that ignored keyword arguments in the objective function.
[(#6027)](https://github.com/PennyLaneAI/pennylane/pull/6027)

Expand Down Expand Up @@ -275,4 +282,5 @@ Vincent Michaud-Rioux,
Anurav Modak,
Mudit Pandey,
Erik Schultheis,
nate stemen.
nate stemen,
David Wierichs,
49 changes: 40 additions & 9 deletions pennylane/ops/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
from scipy import sparse

import pennylane as qml
from pennylane.operation import AllWires, AnyWires, CVObservable, Operation
from pennylane.operation import (
AllWires,
AnyWires,
CVObservable,
Operation,
SparseMatrixUndefinedError,
)


class Identity(CVObservable, Operation):
Expand Down Expand Up @@ -227,6 +233,7 @@ class GlobalPhase(Operation):
* Number of wires: All (the operation acts on all wires)
* Number of parameters: 1
* Number of dimensions per parameter: (0,)
* Gradient recipe: None
Args:
Expand Down Expand Up @@ -282,11 +289,17 @@ def circuit():
"""

grad_method = None
num_params = 1
num_wires = AllWires
"""int: Number of wires that the operator acts on."""

num_params = 1
"""int: Number of trainable parameters that the operator depends on."""

ndim_params = (0,)
"""tuple[int]: Number of dimensions per trainable parameter that the operator depends on."""

grad_method = None

def __init__(self, phi, wires=None, id=None):
super().__init__(phi, wires=[] if wires is None else wires, id=id)

Expand All @@ -313,7 +326,15 @@ def compute_eigvals(phi, n_wires=1): # pylint: disable=arguments-differ
>>> qml.GlobalPhase.compute_eigvals(np.pi/2)
array([6.123234e-17+1.j, 6.123234e-17+1.j])
"""
return qml.math.exp(-1j * phi) * qml.math.ones(2**n_wires)
exp = qml.math.exp(-1j * phi)
ones = qml.math.ones(2**n_wires, like=phi)
if qml.math.get_interface(phi) == "tensorflow":
ones = qml.math.cast_like(ones, 1j)

if qml.math.ndim(phi) == 0:
return exp * ones

return qml.math.tensordot(exp, ones, axes=0)

@staticmethod
def compute_matrix(phi, n_wires=1): # pylint: disable=arguments-differ
Expand All @@ -333,15 +354,22 @@ def compute_matrix(phi, n_wires=1): # pylint: disable=arguments-differ
[0. +0.j , 0.70710678-0.70710678j]])
"""
interface = qml.math.get_interface(phi)
eye = qml.math.eye(2**n_wires, like=phi)
exp = qml.math.exp(-1j * qml.math.cast(phi, complex))
if interface == "tensorflow":
return qml.math.exp(-1j * qml.math.cast(phi, complex)) * qml.math.eye(int(2**n_wires))
return qml.math.exp(-1j * qml.math.cast(phi, complex)) * qml.math.eye(
int(2**n_wires), like=interface
)
eye = qml.math.cast_like(eye, 1j)
elif interface == "torch":
eye = eye.to(exp.device)

if qml.math.ndim(phi) == 0:
return exp * eye
return qml.math.tensordot(exp, eye, axes=0)

@staticmethod
def compute_sparse_matrix(phi, n_wires=1): # pylint: disable=arguments-differ
return qml.math.exp(-1j * phi) * sparse.eye(int(2**n_wires), format="csr")
if qml.math.ndim(phi) > 0:
raise SparseMatrixUndefinedError("Sparse matrices do not support broadcasting")
return qml.math.exp(-1j * phi) * sparse.eye(2**n_wires, format="csr")

@staticmethod
def compute_diagonalizing_gates(
Expand Down Expand Up @@ -403,6 +431,9 @@ def compute_decomposition(phi, wires=None): # pylint:disable=arguments-differ,u
"""
return []

def eigvals(self):
return self.compute_eigvals(self.data[0], n_wires=len(self.wires))

def matrix(self, wire_order=None):
n_wires = len(wire_order) if wire_order else len(self.wires)
return self.compute_matrix(self.data[0], n_wires=n_wires)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def _rot_decomposition(U, wire, return_global_phase=False):
angle = 2 * math.angle(U_det1[0, 1, 1]) % (4 * np.pi)
operations = [qml.RZ(angle, wires=wire)]
if return_global_phase:
alphas = math.squeeze(alphas)
operations.append(qml.GlobalPhase(-alphas))
return operations

Expand Down
23 changes: 2 additions & 21 deletions pennylane/ops/qubit/parametric_ops_multi_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,18 +368,7 @@ def compute_matrix(theta, pauli_word): # pylint: disable=arguments-differ

# Simplest case is if the Pauli is the identity matrix
if set(pauli_word) == {"I"}:
exp = qml.math.exp(-0.5j * theta)
iden = qml.math.eye(2 ** len(pauli_word), like=theta)
if qml.math.get_interface(theta) == "tensorflow":
iden = qml.math.cast_like(iden, 1j)
if qml.math.get_interface(theta) == "torch":
td = exp.device
iden = iden.to(td)

if qml.math.ndim(theta) == 0:
return exp * iden

return qml.math.stack([e * iden for e in exp])
return qml.GlobalPhase.compute_matrix(0.5 * theta, n_wires=len(pauli_word))

# We first generate the matrix excluding the identity parts and expand it afterwards.
# To this end, we have to store on which wires the non-identity parts act
Expand Down Expand Up @@ -445,15 +434,7 @@ def compute_eigvals(theta, pauli_word): # pylint: disable=arguments-differ

# Identity must be treated specially because its eigenvalues are all the same
if set(pauli_word) == {"I"}:
exp = qml.math.exp(-0.5j * theta)
ones = qml.math.ones(2 ** len(pauli_word), like=theta)
if qml.math.get_interface(theta) == "tensorflow":
ones = qml.math.cast_like(ones, 1j)

if qml.math.ndim(theta) == 0:
return exp * ones

return qml.math.tensordot(exp, ones, axes=0)
return qml.GlobalPhase.compute_eigvals(0.5 * theta, n_wires=len(pauli_word))

return MultiRZ.compute_eigvals(theta, len(pauli_word))

Expand Down
88 changes: 63 additions & 25 deletions tests/ops/qubit/test_parametric_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,13 @@ def test_phase_decomposition(self, phi, tol):
assert np.allclose(res[0].data[0], phi)

decomposed_matrix = res[0].matrix()
global_phase = np.exp(-1j * phi / 2)[..., np.newaxis, np.newaxis]
global_phase = np.exp(1j * phi / 2)[..., np.newaxis, np.newaxis]

assert res[1].name == "GlobalPhase"
assert np.allclose(qml.matrix(res[1]), np.exp(1j * phi / 2))
assert np.allclose(qml.matrix(res[1]), global_phase)

assert np.allclose(decomposed_matrix, global_phase * op.matrix(), atol=tol, rtol=0)
if qml.math.shape(phi) == (): # GlobalPhase matrix doesn't support batching
assert np.allclose(op.matrix(), qml.prod(*res[::-1]).matrix())
assert np.allclose(decomposed_matrix, (1 / global_phase) * op.matrix(), atol=tol, rtol=0)
assert np.allclose(op.matrix(), qml.prod(*res[::-1]).matrix())

def test_Rot_decomposition(self):
"""Test the decomposition of Rot."""
Expand Down Expand Up @@ -710,20 +709,42 @@ def test_phase_shift(self, tol):
assert np.allclose(qml.PhaseShift.compute_matrix(phi), expected, atol=tol, rtol=0)
assert np.allclose(qml.U1.compute_matrix(phi), expected, atol=tol, rtol=0)

def test_global_phase(self, tol):
@pytest.mark.parametrize("n_wires", [0, 1, 2])
def test_global_phase(self, tol, n_wires):
"""Test GlobalPhase matrix is correct"""

wires = list(range(n_wires))
eye = np.eye(2**n_wires)
eye2 = np.eye(2)
# test identity for theta=0
assert np.allclose(qml.GlobalPhase.compute_matrix(0), np.identity(2), atol=tol, rtol=0)
assert np.allclose(
qml.GlobalPhase(0).matrix(wire_order=[0]), np.identity(2), atol=tol, rtol=0
qml.GlobalPhase.compute_matrix(0, n_wires=n_wires), eye, atol=tol, rtol=0
)
assert np.allclose(qml.GlobalPhase.compute_matrix(0), eye2, atol=tol, rtol=0)
assert np.allclose(qml.GlobalPhase(0).matrix(wire_order=wires), eye, atol=tol, rtol=0)

# test arbitrary phase shift
# test arbitrary global phase
phi = 0.5432
expected = np.array([[qml.math.exp(-1j * phi), 0], [0, qml.math.exp(-1j * phi)]])
assert np.allclose(qml.GlobalPhase.compute_matrix(phi), expected, atol=tol, rtol=0)
assert np.allclose(qml.GlobalPhase(phi).matrix(wire_order=[0]), expected, atol=tol, rtol=0)
exp = np.exp(-1j * phi)
assert np.allclose(
qml.GlobalPhase.compute_matrix(phi, n_wires=n_wires), exp * eye, atol=tol, rtol=0
)
assert np.allclose(qml.GlobalPhase.compute_matrix(phi), exp * eye2, atol=tol, rtol=0)
assert np.allclose(
qml.GlobalPhase(phi).matrix(wire_order=wires), exp * eye, atol=tol, rtol=0
)

# test arbitrary broadcasted global phase with non-default n_wires=0
phi = np.array([0.5, 0.4, 0.3])
expected = np.tensordot(np.exp(-1j * phi), eye, axes=0)
expected2 = np.tensordot(np.exp(-1j * phi), eye2, axes=0)
assert np.allclose(
qml.GlobalPhase.compute_matrix(phi, n_wires=n_wires), expected, atol=tol, rtol=0
)
assert np.allclose(qml.GlobalPhase.compute_matrix(phi), expected2, atol=tol, rtol=0)
assert np.allclose(
qml.GlobalPhase(phi).matrix(wire_order=wires), expected, atol=tol, rtol=0
)

def test_identity(self, tol):
"""Test Identity matrix is correct with no wires"""
Expand Down Expand Up @@ -1698,19 +1719,29 @@ def test_pcphase_eigvals(self):
)
assert np.allclose(op.eigvals(), expected)

def test_global_phase_eigvals(self):
@pytest.mark.parametrize("n_wires", [0, 1, 2])
def test_global_phase_eigvals(self, n_wires):
"""Test GlobalPhase eigenvalues are correct"""

dim = 2**n_wires
# test identity for theta=0
op = qml.GlobalPhase(0.0)
assert np.allclose(op.compute_eigvals(*op.parameters, **op.hyperparameters), np.ones(2))
assert np.allclose(op.eigvals(), np.ones(2))
phi = 0.0
op = qml.GlobalPhase(phi, wires=list(range(n_wires)))
assert np.allclose(op.compute_eigvals(phi, n_wires=n_wires), np.ones(dim))
assert np.allclose(op.eigvals(), np.ones(dim))

# test arbitrary phase shift
# test arbitrary global phase
phi = 0.5432
op = qml.GlobalPhase(phi)
expected = np.array([np.exp(-1j * phi), np.exp(-1j * phi)])
assert np.allclose(op.compute_eigvals(*op.parameters, **op.hyperparameters), expected)
op = qml.GlobalPhase(phi, wires=list(range(n_wires)))
expected = np.array([np.exp(-1j * phi)] * dim)
assert np.allclose(op.compute_eigvals(phi, n_wires=n_wires), expected)
assert np.allclose(op.eigvals(), expected)

# test arbitrary broadcasted global phase
phi = np.array([0.5, 0.4, 0.3])
op = qml.GlobalPhase(phi, wires=list(range(n_wires)))
expected = np.array([np.exp(-1j * p) * np.ones(dim) for p in phi])
assert np.allclose(op.compute_eigvals(phi, n_wires=n_wires), expected)
assert np.allclose(op.eigvals(), expected)


Expand Down Expand Up @@ -2790,10 +2821,8 @@ def test_PauliRot_all_Identity_broadcasted(self):
qml.assert_equal(decomp_op, qml.GlobalPhase(theta / 2))

op_matrices = op.matrix()
decomp_op_matrices = decomp_op.matrix().T
assert len(op_matrices) == len(decomp_op_matrices)
for op_matrix, decomp_phase in zip(op_matrices, decomp_op_matrices):
assert qml.math.allclose(op_matrix, decomp_phase * np.eye(4))
decomp_op_matrices = decomp_op.matrix(wire_order=[0, 1])
assert qml.math.allclose(op_matrices, decomp_op_matrices)

@pytest.mark.parametrize("theta", [0.4, np.array([np.pi / 3, 0.1, -0.9])])
def test_PauliRot_decomposition_ZZ(self, theta):
Expand Down Expand Up @@ -3922,14 +3951,23 @@ def test_diagonalization_static_global_phase():
@pytest.mark.parametrize("phi", [0.123, np.pi / 4, 0])
@pytest.mark.parametrize("n_wires", [0, 1, 2])
def test_global_phase_compute_sparse_matrix(phi, n_wires):
"""Test that compute_sparse_matrix"""
"""Test compute_sparse_matrix"""

sparse_matrix = qml.GlobalPhase.compute_sparse_matrix(phi, n_wires=n_wires)
expected = np.exp(-1j * phi) * sparse.eye(int(2**n_wires), format="csr")

assert np.allclose(sparse_matrix.todense(), expected.todense())


@pytest.mark.parametrize("n_wires", [0, 1, 2])
def test_global_phase_compute_sparse_matrix_broadcasted_raises(n_wires):
"""Test that compute_sparse_matrix raises an error for broadcasted GlobalPhase"""

phi = np.array([0.123, np.pi / 4, 0])
with pytest.raises(qml.operation.SparseMatrixUndefinedError, match="broadcasting"):
_ = qml.GlobalPhase.compute_sparse_matrix(phi, n_wires=n_wires)


def test_decomposition():
"""Test the decomposition of the GlobalPhase operation."""

Expand Down

0 comments on commit 0515647

Please sign in to comment.