Skip to content

Commit

Permalink
[BUG FIX] Add support for adjoint operations in default.qutrit (#4348)
Browse files Browse the repository at this point in the history
* Fixed adjoint op bug

* Added tests

* Update doc/releases/changelog-dev.md

* linting

* Update doc/releases/changelog-dev.md

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>

* Added integration test

* Linting

---------

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
  • Loading branch information
mudit2812 and timmysilv authored Jul 12, 2023
1 parent 18d508a commit 712f99c
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 4 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@
* `qml.qinfo.purity` now produces correct results with custom wire labels.
[#4331](https://github.com/PennyLaneAI/pennylane/pull/4331)

* `default.qutrit` now supports all qutrit operations used with `qml.adjoint`.
[(#4348)](https://github.com/PennyLaneAI/pennylane/pull/4348)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):
Expand Down
12 changes: 11 additions & 1 deletion pennylane/devices/default_qutrit.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,14 @@ class DefaultQutrit(QutritDevice):
"QutritUnitary",
"ControlledQutritUnitary",
"TShift",
"Adjoint(TShift)",
"TClock",
"Adjoint(TClock)",
"TAdd",
"Adjoint(TAdd)",
"TSWAP",
"THadamard",
"Adjoint(THadamard)",
"TRX",
"TRY",
"TRZ",
Expand Down Expand Up @@ -205,9 +209,15 @@ def _apply_operation(self, state, operation):
return state
wires = operation.wires

if operation.name in self._apply_ops:
if operation.name in self._apply_ops: # pylint: disable=no-else-return
axes = self.wires.indices(wires)
return self._apply_ops[operation.name](state, axes)
elif (
isinstance(operation, qml.ops.Adjoint) # pylint: disable=no-member
and operation.base.name in self._apply_ops
):
axes = self.wires.indices(wires)
return self._apply_ops[operation.base.name](state, axes, inverse=True)

matrix = self._asarray(self._get_unitary_matrix(operation), dtype=self.C_DTYPE)

Expand Down
117 changes: 114 additions & 3 deletions tests/devices/test_default_qutrit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytest
from flaky import flaky
from gate_data import OMEGA, TSHIFT, TCLOCK, TSWAP, TADD, GELL_MANN
from scipy.stats import unitary_group
import pennylane as qml
from pennylane import numpy as np, DeviceError
from pennylane.wires import Wires, WireError
Expand Down Expand Up @@ -106,6 +107,24 @@ def test_apply_operation_single_wire_no_parameters(
assert np.allclose(qutrit_device_1_wire._state, np.array(expected_output), atol=tol, rtol=0)
assert qutrit_device_1_wire._state.dtype == qutrit_device_1_wire.C_DTYPE

@pytest.mark.parametrize("operation, expected_output, input, subspace", test_data_no_parameters)
def test_apply_operation_single_wire_no_parameters_adjoint(
self, qutrit_device_1_wire, tol, operation, input, expected_output, subspace
):
"""Tests that applying an adjoint operation yields the expected output state for single wire
operations that have no parameters."""
qutrit_device_1_wire._state = np.array(input, dtype=qutrit_device_1_wire.C_DTYPE)
qutrit_device_1_wire.apply(
[
qml.adjoint(operation(wires=[0]))
if subspace is None
else qml.adjoint(operation(wires=[0], subspace=subspace))
]
)

assert np.allclose(qutrit_device_1_wire._state, np.array(expected_output), atol=tol, rtol=0)
assert qutrit_device_1_wire._state.dtype == qutrit_device_1_wire.C_DTYPE

test_data_two_wires_no_parameters = [
(qml.TSWAP, [0, 1, 0, 0, 0, 0, 0, 0, 0], np.array([0, 0, 0, 1, 0, 0, 0, 0, 0]), None),
(
Expand Down Expand Up @@ -165,6 +184,30 @@ def test_apply_operation_two_wires_no_parameters(
)
assert qutrit_device_2_wires._state.dtype == qutrit_device_2_wires.C_DTYPE

@pytest.mark.parametrize(
"operation,expected_output,input, subspace", all_two_wires_no_parameters
)
def test_apply_operation_two_wires_no_parameters_adjoint(
self, qutrit_device_2_wires, tol, operation, input, expected_output, subspace
):
"""Tests that applying an adjoint operation yields the expected output state for two wire
operations that have no parameters."""
qutrit_device_2_wires._state = np.array(input, dtype=qutrit_device_2_wires.C_DTYPE).reshape(
(3, 3)
)
qutrit_device_2_wires.apply(
[
qml.adjoint(operation(wires=[0, 1]))
if subspace is None
else qml.adjoint(operation(wires=[0, 1], subspace=subspace))
]
)

assert np.allclose(
qutrit_device_2_wires._state.flatten(), np.array(expected_output), atol=tol, rtol=0
)
assert qutrit_device_2_wires._state.dtype == qutrit_device_2_wires.C_DTYPE

# TODO: Add more data as parametric ops get added
test_data_single_wire_with_parameters = [
(qml.QutritUnitary, [1, 0, 0], [1, 1, 0] / np.sqrt(2), [U_thadamard_01], None),
Expand Down Expand Up @@ -214,6 +257,23 @@ def test_apply_operation_single_wire_with_parameters(
assert np.allclose(qutrit_device_1_wire._state, np.array(expected_output), atol=tol, rtol=0)
assert qutrit_device_1_wire._state.dtype == qutrit_device_1_wire.C_DTYPE

@pytest.mark.parametrize(
"operation, expected_output, input, par, subspace", test_data_single_wire_with_parameters
)
def test_apply_operation_single_wire_with_parameters_adjoint(
self, qutrit_device_1_wire, tol, operation, input, expected_output, par, subspace
):
"""Tests that applying an adjoint operation yields the expected output state for single wire
operations that have parameters."""

qutrit_device_1_wire._state = np.array(input, dtype=qutrit_device_1_wire.C_DTYPE)

kwargs = {} if subspace is None else {"subspace": subspace}
qutrit_device_1_wire.apply([qml.adjoint(operation(*par, wires=[0], **kwargs))])

assert np.allclose(qutrit_device_1_wire._state, np.array(expected_output), atol=tol, rtol=0)
assert qutrit_device_1_wire._state.dtype == qutrit_device_1_wire.C_DTYPE

# TODO: Add more ops as parametric operations get added
test_data_two_wires_with_parameters = [
(qml.QutritUnitary, [0, 0, 0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1, 0], [TSWAP]),
Expand Down Expand Up @@ -253,6 +313,25 @@ def test_apply_operation_two_wires_with_parameters(
)
assert qutrit_device_2_wires._state.dtype == qutrit_device_2_wires.C_DTYPE

@pytest.mark.parametrize(
"operation,expected_output,input,par", test_data_two_wires_with_parameters
)
def test_apply_operation_two_wires_with_parameters_adjoint(
self, qutrit_device_2_wires, tol, operation, input, expected_output, par
):
"""Tests that applying an adjoint operation yields the expected output state for two wire
operations that have parameters."""

qutrit_device_2_wires._state = np.array(input, dtype=qutrit_device_2_wires.C_DTYPE).reshape(
(3, 3)
)
qutrit_device_2_wires.apply([qml.adjoint(operation(*par, wires=[0, 1]))])

assert np.allclose(
qutrit_device_2_wires._state.flatten(), np.array(expected_output), atol=tol, rtol=0
)
assert qutrit_device_2_wires._state.dtype == qutrit_device_2_wires.C_DTYPE

def test_apply_rotations_one_wire(self, qutrit_device_1_wire):
"""Tests that rotations are applied in correct order after operations"""

Expand Down Expand Up @@ -706,6 +785,35 @@ def circuit(mat):
state = circuit(mat)
assert np.allclose(state, expected_out, atol=tol)

def test_qutrit_circuit_adjoint_integration(self):
"""Test that using qml.adjoint in a `default.qutrit` qnode works as expected."""
dev = qml.device("default.qutrit", wires=3)

def ansatz(phi, theta, omega, U):
qml.TShift(0)
qml.TAdd([0, 1])
qml.TRX(phi, wires=2, subspace=(0, 1))
qml.TClock(1)
qml.TRY(theta, wires=0, subspace=(1, 2))
qml.TSWAP([0, 2])
qml.QutritUnitary(U, wires=[2, 1])
qml.TRZ(omega, wires=1, subspace=(0, 2))

@qml.qnode(dev)
def circuit():
phi, theta, omega = np.random.rand(3) * 2 * np.pi
U = unitary_group.rvs(9, random_state=10)

ansatz(phi, theta, omega, U)
qml.adjoint(ansatz)(phi, theta, omega, U)
return qml.state()

expected = np.zeros(27)
expected[0] = 1
res = circuit()

assert np.allclose(res, expected)


class TestTensorExpval:
"""Test tensor expectation values"""
Expand Down Expand Up @@ -1201,7 +1309,9 @@ def compute_matrix(*params, **hyperparams):

# Set the internal _apply_unitary_tensordot
history = []
mock_apply_tensordot = lambda state, matrix, wires: history.append((state, matrix, wires))

def mock_apply_tensordot(state, matrix, wires):
history.append((state, matrix, wires))

with monkeypatch.context() as m:
m.setattr(dev, "_apply_unitary", mock_apply_tensordot)
Expand Down Expand Up @@ -1241,11 +1351,12 @@ def test_internal_apply_ops_case(self, monkeypatch, mocker):

# Create a dummy operation
expected_test_output = np.ones(1)
supported_gate_application = lambda *args, **kwargs: expected_test_output

with monkeypatch.context() as m:
# Set the internal ops implementations dict
m.setattr(dev, "_apply_ops", {"QutritUnitary": supported_gate_application})
m.setattr(
dev, "_apply_ops", {"QutritUnitary": lambda *args, **kwargs: expected_test_output}
)

test_state = np.array([1, 0, 0])
op = qml.QutritUnitary(TSHIFT, wires=0)
Expand Down

0 comments on commit 712f99c

Please sign in to comment.