Skip to content

Commit

Permalink
Adjoint metric tensor uses devices.qubit instead of DefaultQubit
Browse files Browse the repository at this point in the history
…private methods (#4456)

* adjoint metric tensor uses devices.qubit

* Update doc/releases/changelog-dev.md

* black

* update legacy tests

* Update pennylane/transforms/adjoint_metric_tensor.py

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>

* Update pennylane/transforms/adjoint_metric_tensor.py

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>

* unmove things that didn't need to move

* fix tests

* Update pennylane/qnode.py

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>

* no need to convert generator matrix

---------

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
  • Loading branch information
albi3ro and dwierichs authored Aug 14, 2023
1 parent 63a63e3 commit a4c31d6
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 485 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ array([False, False])
* Provide users access to the logging configuration file path and improve the logging configuration structure.
[(#4377)](https://github.com/PennyLaneAI/pennylane/pull/4377)

* `qml.transforms.adjoint_metric_tensor` now uses the simulation tools in `pennylane.devices.qubit` instead of
private methods of `pennylane.devices.DefaultQubit`.
[(#4456)](https://github.com/PennyLaneAI/pennylane/pull/4456)

* Updated `Device.default_expand_fn()` to decompose `StatePrep` operations present in the middle of a provided circuit.
[(#4437)](https://github.com/PennyLaneAI/pennylane/pull/4437)

Expand Down
5 changes: 4 additions & 1 deletion pennylane/ops/functions/map_wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ def map_wires(
measurements = [qml.map_wires(m, wire_map) for m in input.measurements]
prep = [qml.map_wires(p, wire_map) for p in input._prep] # pylint: disable=protected-access

return input.__class__(ops=ops, measurements=measurements, prep=prep, shots=input.shots)
out = input.__class__(ops=ops, measurements=measurements, prep=prep, shots=input.shots)
out.trainable_params = input.trainable_params
out._qfunc_output = input._qfunc_output # pylint: disable=protected-access
return out

if callable(input):
func = input.func if isinstance(input, QNode) else input
Expand Down
96 changes: 38 additions & 58 deletions pennylane/transforms/adjoint_metric_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,41 +20,13 @@
import pennylane as qml
from pennylane import numpy as np

# pylint: disable=protected-access
# pylint: disable=too-many-statements
from pennylane.transforms.metric_tensor import _contract_metric_tensor_with_cjac


def _apply_operations(state, op, device, invert=False):
"""Wrapper that allows to apply a variety of operations---or groups
of operations---to a state or to prepare a new state.
If ``invert=True``, this function makes sure not to alter the operations.
The state of the device, however may be altered, depending on the
device and performed operation(s).
"""
# pylint: disable=protected-access
if isinstance(op, (list, np.ndarray)):
if invert:
op = op[::-1]
for _op in op:
state = _apply_operations(state, _op, device, invert)
return state

if isinstance(op, qml.QubitStateVector):
if invert:
raise ValueError("Can't invert state preparation.")
device._apply_state_vector(op.parameters[0], op.wires)
return device._state

if isinstance(op, qml.BasisState):
if invert:
raise ValueError("Can't invert state preparation.")
device._apply_basis_state(op.parameters[0], op.wires)
return device._state

apply_op = qml.adjoint(op) if invert else op
state = device._apply_operation(state, apply_op)

return state
def _reshape_real_imag(state, dim):
state = qml.math.reshape(state, (dim,))
return qml.math.real(state), qml.math.imag(state)


def _group_operations(tape):
Expand Down Expand Up @@ -157,56 +129,58 @@ def circuit(weights):
shot simulations.
"""
if isinstance(circuit, qml.tape.QuantumScript):
return _adjoint_metric_tensor_tape(circuit, device)
return _adjoint_metric_tensor_tape(circuit)
if isinstance(circuit, (qml.QNode, qml.ExpvalCost)):
return _adjoint_metric_tensor_qnode(circuit, device, hybrid)

raise qml.QuantumFunctionError("The passed object is not a QuantumTape or QNode.")


def _adjoint_metric_tensor_tape(tape, device):
def _adjoint_metric_tensor_tape(tape):
"""Computes the metric tensor of a tape using the adjoint method and a given device."""
# pylint: disable=protected-access
if device.shots is not None:
if tape.shots:
raise ValueError(
"The adjoint method for the metric tensor is only implemented for shots=None"
)
if set(tape.wires) != set(range(tape.num_wires)):
wire_map = {w: i for i, w in enumerate(tape.wires)}
tape = qml.map_wires(tape, wire_map)
tape = qml.transforms.expand_trainable_multipar(tape)

# Divide all operations of a tape into trainable operations and blocks
# of untrainable operations after each trainable one.
trainable_operations, group_after_trainable_op = _group_operations(tape)

dim = 2**device.num_wires
dim = 2**tape.num_wires
# generate and extract initial state
psi = device._create_basis_state(0)
prep = tape[0] if len(tape) > 0 and isinstance(tape[0], qml.operation.StatePrep) else None

interface = qml.math.get_interface(*tape.get_parameters(trainable_only=False))
psi = qml.devices.qubit.create_initial_state(tape.wires, prep, like=interface)

# initialize metric tensor components (which all will be real-valued)
like_real = qml.math.real(psi[0])
L = qml.math.convert_like(qml.math.zeros((tape.num_params, tape.num_params)), like_real)
T = qml.math.convert_like(qml.math.zeros((tape.num_params,)), like_real)

psi = _apply_operations(psi, group_after_trainable_op[-1], device)
for op in group_after_trainable_op[-1]:
psi = qml.devices.qubit.apply_operation(op, psi)

for j, outer_op in enumerate(trainable_operations):
generator_1, prefactor_1 = qml.generator(outer_op)
generator_1 = qml.matrix(generator_1)

# the state vector phi is missing a factor of 1j * prefactor_1
phi = device._apply_unitary(
psi, qml.math.convert_like(generator_1, like_real), outer_op.wires
)
phi = qml.devices.qubit.apply_operation(generator_1, psi)

phi_real = qml.math.reshape(qml.math.real(phi), (dim,))
phi_imag = qml.math.reshape(qml.math.imag(phi), (dim,))
phi_real, phi_imag = _reshape_real_imag(phi, dim)
diag_value = prefactor_1**2 * (
qml.math.dot(phi_real, phi_real) + qml.math.dot(phi_imag, phi_imag)
)
L = qml.math.scatter_element_add(L, (j, j), diag_value)

lam = psi * 1.0
lam_real = qml.math.reshape(qml.math.real(lam), (dim,))
lam_imag = qml.math.reshape(qml.math.imag(lam), (dim,))
lam_real, lam_imag = _reshape_real_imag(lam, dim)

# this entry is missing a factor of 1j
value = prefactor_1 * (qml.math.dot(lam_real, phi_real) + qml.math.dot(lam_imag, phi_imag))
Expand All @@ -215,20 +189,23 @@ def _adjoint_metric_tensor_tape(tape, device):
for i in range(j - 1, -1, -1):
# after first iteration of inner loop: apply U_{i+1}^\dagger
if i < j - 1:
phi = _apply_operations(phi, trainable_operations[i + 1], device, invert=True)
phi = qml.devices.qubit.apply_operation(
qml.adjoint(trainable_operations[i + 1], lazy=False), phi
)
# apply V_{i}^\dagger
phi = _apply_operations(phi, group_after_trainable_op[i], device, invert=True)
lam = _apply_operations(lam, group_after_trainable_op[i], device, invert=True)
for op in reversed(group_after_trainable_op[i]):
adj_op = qml.adjoint(op, lazy=False)
phi = qml.devices.qubit.apply_operation(adj_op, phi)
lam = qml.devices.qubit.apply_operation(adj_op, lam)

inner_op = trainable_operations[i]
# extract and apply G_i
generator_2, prefactor_2 = qml.generator(inner_op)
generator_2 = qml.matrix(generator_2)
# this state vector is missing a factor of 1j * prefactor_2
mu = device._apply_unitary(lam, qml.math.convert_like(generator_2, lam), inner_op.wires)
phi_real = qml.math.reshape(qml.math.real(phi), (dim,))
phi_imag = qml.math.reshape(qml.math.imag(phi), (dim,))
mu_real = qml.math.reshape(qml.math.real(mu), (dim,))
mu_imag = qml.math.reshape(qml.math.imag(mu), (dim,))
mu = qml.devices.qubit.apply_operation(generator_2, lam)

phi_real, phi_imag = _reshape_real_imag(phi, dim)
mu_real, mu_imag = _reshape_real_imag(mu, dim)
# this entry is missing a factor of 1j * (-1j) = 1, i.e. none
value = (
prefactor_1
Expand All @@ -239,10 +216,12 @@ def _adjoint_metric_tensor_tape(tape, device):
L, [(i, j), (j, i)], value * qml.math.convert_like(qml.math.ones((2,)), value)
)
# apply U_i^\dagger
lam = _apply_operations(lam, inner_op, device, invert=True)
lam = qml.devices.qubit.apply_operation(qml.adjoint(inner_op, lazy=False), lam)

# apply U_j and V_j
psi = _apply_operations(psi, [outer_op, *group_after_trainable_op[j]], device)
psi = qml.devices.qubit.apply_operation(outer_op, psi)
for op in group_after_trainable_op[j]:
psi = qml.devices.qubit.apply_operation(op, psi)

# postprocessing: combine L and T into the metric tensor.
# We require outer(conj(T), T) here, but as we skipped the factor 1j above,
Expand Down Expand Up @@ -278,7 +257,8 @@ def wrapper(*args, **kwargs):
)

qnode.construct(args, kwargs)
mt = _adjoint_metric_tensor_tape(qnode.qtape, device)
batch, _, _ = qml.devices.qubit.preprocess((qnode.tape,))
mt = _adjoint_metric_tensor_tape(batch[0])

if old_interface == "auto":
qnode.interface = "auto"
Expand Down
161 changes: 1 addition & 160 deletions tests/legacy/test_legacy_adjoint_metric_tensor_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,165 +18,6 @@
from pennylane import numpy as np
import pennylane as qml

from pennylane.transforms.adjoint_metric_tensor import _apply_operations


class TestApplyOperations:
"""Tests the application of operations via the helper function
_apply_operations used in the adjoint metric tensor."""

device = qml.device("default.qubit", wires=2)
x = 0.5

def test_simple_operation(self):
"""Test that an operation is applied correctly."""
op = qml.RX(self.x, wires=0)
out = _apply_operations(self.device._state, op, self.device)
out = qml.math.reshape(out, 4)
exp = np.array([np.cos(self.x / 2), 0.0, -1j * np.sin(self.x / 2), 0.0])
assert np.allclose(out, exp)

def test_operation_group(self):
"""Test that a group of operations with is applied correctly
but does not alter the operations (in particular their order) that are used."""
op = [qml.adjoint(qml.RX(self.x, wires=0)), qml.Hadamard(wires=1), qml.CNOT(wires=[1, 0])]
out = _apply_operations(self.device._state, op, self.device)
out = qml.math.reshape(out, 4)
exp = np.array(
[
np.cos(self.x / 2) / np.sqrt(2),
1j * np.sin(self.x / 2) / np.sqrt(2),
1j * np.sin(self.x / 2) / np.sqrt(2),
np.cos(self.x / 2) / np.sqrt(2),
]
)
assert np.allclose(out, exp)
assert qml.equal(op[0], qml.adjoint(qml.RX(self.x, wires=0)))
assert isinstance(op[1], qml.Hadamard)
assert isinstance(op[2], qml.CNOT)

def test_qubit_statevector(self):
"""Test that a statevector preparation is applied correctly."""
state = np.array([0.4, 1.2 - 0.2j, 9.5, -0.3 + 1.1j])
state /= np.linalg.norm(state, ord=2)
op = qml.QubitStateVector(state, wires=self.device.wires)
out = _apply_operations(None, op, self.device, invert=False)
out = qml.math.reshape(out, 4)
assert np.allclose(out, state)

def test_error_qubit_statevector(self):
"""Test that an error is raised for a statevector preparation with invert=True."""
state = np.array([0.4, 1.2 - 0.2j, 9.5, -0.3 + 1.1j])
state = np.array([0.4, 1.2 - 0.2j, 9.5, -0.3 + 1.1j])
state /= np.linalg.norm(state, ord=2)
op = qml.QubitStateVector(state, wires=self.device.wires)
with pytest.raises(ValueError, match="Can't invert state preparation."):
_apply_operations(None, op, self.device, invert=True)

def test_basisstate(self):
"""Test that a basis state preparation is applied correctly."""
op = qml.BasisState(np.array([1, 0]), wires=self.device.wires)
out = _apply_operations(None, op, self.device, invert=False)
out = qml.math.reshape(out, 4)
exp = np.array([0.0, 0.0, 1.0, 0.0])
assert np.allclose(out, exp)

def test_error_basisstate(self):
"""Test that an error is raised for a basis state preparation with invert=True."""
op = qml.BasisState(np.array([1, 0]), wires=self.device.wires)
with pytest.raises(ValueError, match="Can't invert state preparation."):
_apply_operations(None, op, self.device, invert=True)


@pytest.mark.parametrize("invert", [False, True])
class TestApplyOperationsDifferentiability:
"""Tests the differentiability of applying operations via the helper function
_apply_operations used in the adjoint metric tensor."""

x = 0.5

@pytest.mark.autograd
def test_simple_operation_autograd(self, invert):
"""Test differentiability for a simple operation with Autograd."""
device = qml.device("default.qubit.autograd", wires=2)
x = np.array(self.x, requires_grad=True)
r_fn = lambda x: qml.math.real(
_apply_operations(device._state, qml.RX(x, wires=0), device, invert)
)
i_fn = lambda x: qml.math.imag(
_apply_operations(device._state, qml.RX(x, wires=0), device, invert)
)
out = qml.jacobian(r_fn)(x) + 1j * qml.jacobian(i_fn)(x)
exp = (
np.array([[-np.sin(self.x / 2), 0.0], [-1j * (-1) ** invert * np.cos(self.x / 2), 0.0]])
/ 2
)
assert np.allclose(out, exp)

@pytest.mark.jax
def test_simple_operation_jax(self, invert):
"""Test differentiability for a simple operation with JAX."""
import jax

device = qml.device("default.qubit.jax", wires=2)
x = jax.numpy.array(self.x)
r_fn = lambda x: qml.math.real(
_apply_operations(device._state, qml.RX(x, wires=0), device, invert)
)
i_fn = lambda x: qml.math.imag(
_apply_operations(device._state, qml.RX(x, wires=0), device, invert)
)
out = jax.jacobian(r_fn)(x) + 1j * jax.jacobian(i_fn)(x)
exp = (
np.array([[-np.sin(self.x / 2), 0.0], [-1j * (-1) ** invert * np.cos(self.x / 2), 0.0]])
/ 2
)
assert np.allclose(out, exp)

@pytest.mark.tf
def test_simple_operation_tf(self, invert):
"""Test differentiability for a simple operation with TensorFlow."""
import tensorflow as tf

device = qml.device("default.qubit.tf", wires=2)
x = tf.Variable(self.x, dtype=tf.float64)
r_fn = lambda x: qml.math.real(
_apply_operations(device._state, qml.RX(x, wires=0), device, invert)
)
i_fn = lambda x: qml.math.imag(
_apply_operations(device._state, qml.RX(x, wires=0), device, invert)
)
with tf.GradientTape(persistent=True) as tape:
r_state = r_fn(x)
i_state = i_fn(x)
out = qml.math.complex(tape.jacobian(r_state, x), tape.jacobian(i_state, x))
exp = (
np.array([[-np.sin(self.x / 2), 0.0], [-1j * (-1) ** invert * np.cos(self.x / 2), 0.0]])
/ 2
)
assert np.allclose(out, exp)

@pytest.mark.torch
def test_simple_operation_torch(self, invert):
"""Test differentiability for a simple operation with Torch."""
import torch

jac_fn = torch.autograd.functional.jacobian
device = qml.device("default.qubit.torch", wires=2)
x = torch.tensor(self.x, requires_grad=True)
r_fn = lambda x: qml.math.real(
_apply_operations(device._state, qml.RX(x, wires=0), device, invert)
)
i_fn = lambda x: qml.math.imag(
_apply_operations(device._state, qml.RX(x, wires=0), device, invert)
)
out = jac_fn(r_fn, x) + 1j * jac_fn(i_fn, x)
exp = (
np.array([[-np.sin(self.x / 2), 0.0], [-1j * (-1) ** invert * np.cos(self.x / 2), 0.0]])
/ 2
)
assert np.allclose(out, exp)


fixed_pars = [-0.2, 0.2, 0.5, 0.3, 0.7]

Expand Down Expand Up @@ -802,7 +643,7 @@ def test_error_finite_shots(self):
with qml.queuing.AnnotatedQueue() as q:
qml.RX(0.2, wires=0)
qml.RY(1.9, wires=1)
tape = qml.tape.QuantumScript.from_queue(q)
tape = qml.tape.QuantumScript.from_queue(q, shots=1)
dev = qml.device("default.qubit", wires=2, shots=1)

with pytest.raises(ValueError, match="The adjoint method for the metric tensor"):
Expand Down
4 changes: 4 additions & 0 deletions tests/ops/functions/test_map_wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,12 @@ def test_map_wires_tape(self, shots):
build_op()

tape = QuantumScript.from_queue(q_tape, shots=shots)
tape.trainable_params = [0, 2]
# TODO: Use qml.equal when supported

s_tape = qml.map_wires(tape, wire_map=wire_map)
assert len(s_tape) == 1
assert s_tape.trainable_params == [0, 2]
s_op = s_tape[0]
assert isinstance(s_op, qml.ops.Prod)
assert s_op.data == mapped_op.data
Expand All @@ -125,9 +127,11 @@ def test_execute_mapped_tape(self, shots):
qml.expval(op=qml.PauliZ(1))

tape = QuantumScript.from_queue(q_tape, shots=shots)
tape._qfunc_output = (qml.expval(qml.PauliZ(1)),) # pylint: disable=protected-access
# TODO: Use qml.equal when supported

m_tape = qml.map_wires(tape, wire_map=wire_map)
assert m_tape._qfunc_output is tape._qfunc_output # pylint: disable=protected-access
m_op = m_tape.operations[0]
m_obs = m_tape.observables[0]
assert isinstance(m_op, qml.ops.Prod)
Expand Down
Loading

0 comments on commit a4c31d6

Please sign in to comment.