Skip to content

Commit

Permalink
Merge branch 'master' into plxpr-capture-assertion-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
albi3ro committed Jun 3, 2024
2 parents ae18650 + 92ce59d commit 3e3f81e
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 73 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

<h3>Improvements 🛠</h3>

* The wires for the `default.tensor` device are selected at runtime if they are not provided by user.
[(#5744)](https://github.com/PennyLaneAI/pennylane/pull/5744)

* Added `packaging` in the required list of packages.
[(#5769)](https://github.com/PennyLaneAI/pennylane/pull/5769).

Expand Down
86 changes: 42 additions & 44 deletions pennylane/devices/default_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from pennylane.tape import QuantumScript, QuantumTape
from pennylane.transforms.core import TransformProgram
from pennylane.typing import Result, ResultBatch, TensorLike
from pennylane.wires import WireError

Result_or_ResultBatch = Union[Result, ResultBatch]
QuantumTapeBatch = Sequence[QuantumTape]
Expand Down Expand Up @@ -257,16 +258,12 @@ def circuit(theta, phi, num_qubits):

def __init__(
self,
wires,
*,
wires=None,
method="mps",
dtype=np.complex128,
**kwargs,
) -> None:

if wires is None:
raise TypeError("Wires must be provided for the default.tensor device.")

if not has_quimb:
raise ImportError(
"This feature requires quimb, a library for tensor network manipulations. "
Expand All @@ -293,28 +290,11 @@ def __init__(
self._cutoff = kwargs.get("cutoff", np.finfo(self._dtype).eps)
self._contract = kwargs.get("contract", "auto-mps")

device_options = self._setup_execution_config().device_options

self._init_state_opts = {
"binary": "0" * (len(self._wires) if self._wires else 1),
"dtype": self._dtype.__name__,
"tags": [str(l) for l in self._wires.labels] if self._wires else None,
}

self._gate_opts = {
"parametrize": None,
"contract": device_options["contract"],
"cutoff": device_options["cutoff"],
"max_bond": device_options["max_bond_dim"],
}

self._expval_opts = {
"dtype": self._dtype.__name__,
"simplify_sequence": "ADCRS",
"simplify_atol": 0.0,
}

self._circuitMPS = qtn.CircuitMPS(psi0=self._initial_mps())
# The `quimb` state is a class attribute so that we can implement methods
# that access it as soon as the device is created without running a circuit.
# The state is reset every time a new circuit is executed, and number of wires
# can be established at runtime to match the circuit.
self._quimb_mps = qtn.CircuitMPS(psi0=self._initial_mps(self.wires))

for arg in kwargs:
if arg not in self._device_options:
Expand All @@ -337,24 +317,39 @@ def dtype(self):
"""Tensor complex data type."""
return self._dtype

def _reset_state(self) -> None:
def _reset_mps(self, wires: qml.wires.Wires) -> None:
"""
Reset the MPS.
Reset the MPS associated with the circuit.
Internally, it uses `quimb`'s `CircuitMPS` class.
This method modifies the tensor state of the device.
Args:
wires (Wires): The wires to reset the MPS.
"""
self._circuitMPS = qtn.CircuitMPS(psi0=self._initial_mps())
self._quimb_mps = qtn.CircuitMPS(
psi0=self._initial_mps(wires),
max_bond=self._max_bond_dim,
gate_contract=self._contract,
cutoff=self._cutoff,
)

def _initial_mps(self) -> "qtn.MatrixProductState":
def _initial_mps(self, wires: qml.wires.Wires) -> "qtn.MatrixProductState":
r"""
Return an initial state to :math:`\ket{0}`.
Internally, it uses `quimb`'s `MPS_computational_state` method.
Args:
wires (Wires): The wires to initialize the MPS.
Returns:
MatrixProductState: The initial MPS of a circuit.
"""
return qtn.MPS_computational_state(**self._init_state_opts)
return qtn.MPS_computational_state(
binary="0" * (len(wires) if wires else 1),
dtype=self._dtype.__name__,
tags=[str(l) for l in wires.labels] if wires else None,
)

def _setup_execution_config(
self, config: Optional[ExecutionConfig] = DefaultExecutionConfig
Expand Down Expand Up @@ -429,10 +424,11 @@ def execute(

results = []
for circuit in circuits:
# we need to check if the wires of the circuit are compatible with the wires of the device
# since the initial tensor state is created with the wires of the device
if not self.wires.contains_wires(circuit.wires):
raise AttributeError(
if self.wires is not None and not self.wires.contains_wires(circuit.wires):
# quimb raises a cryptic error if the circuit has wires that are not in the device,
# so we raise a more informative error here
raise WireError(
"Mismatch between circuit and device wires. "
f"Circuit has wires {circuit.wires.tolist()}. "
f"Tensor on device has wires {self.wires.tolist()}"
)
Expand All @@ -451,7 +447,9 @@ def simulate(self, circuit: QuantumScript) -> Result:
Tuple[TensorLike]: The results of the simulation.
"""

self._reset_state()
wires = circuit.wires if self.wires is None else self.wires

self._reset_mps(wires)

for op in circuit.operations:
self._apply_operation(op)
Expand All @@ -472,9 +470,7 @@ def _apply_operation(self, op: qml.operation.Operator) -> None:
op (Operator): The operation to apply.
"""

self._circuitMPS.apply_gate(
qml.matrix(op).astype(self._dtype), *op.wires, **self._gate_opts
)
self._quimb_mps.apply_gate(qml.matrix(op).astype(self._dtype), *op.wires, parametrize=None)

def measurement(self, measurementprocess: MeasurementProcess) -> TensorLike:
"""Measure the measurement required by the circuit over the MPS.
Expand Down Expand Up @@ -555,13 +551,15 @@ def _local_expectation(self, matrix, wires) -> float:
Local expectation value of the matrix on the MPS.
"""

# We need to copy the MPS to avoid modifying the original state
qc = copy.deepcopy(self._circuitMPS)
# We need to copy the MPS since `local_expectation` modifies the state
qc = copy.deepcopy(self._quimb_mps)

exp_val = qc.local_expectation(
matrix,
wires,
**self._expval_opts,
dtype=self._dtype.__name__,
simplify_sequence="ADCRS",
simplify_atol=0.0,
)

return float(np.real(exp_val))
Expand Down
1 change: 1 addition & 0 deletions tests/capture/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ def qfunc(probs):
assert len(q) == 1
assert q.queue[0] == qml.QuantumMonteCarlo(probs, **kwargs)

@pytest.mark.usefixtures("new_opmath_only")
def test_qubitization(self):
"""Test the primitive bind call of Qubitization."""

Expand Down
53 changes: 24 additions & 29 deletions tests/devices/default_tensor/test_default_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from scipy.sparse import csr_matrix

import pennylane as qml
from pennylane.wires import WireError

quimb = pytest.importorskip("quimb")

Expand Down Expand Up @@ -125,42 +126,36 @@

def test_name():
"""Test the name of DefaultTensor."""
assert qml.device("default.tensor", wires=0).name == "default.tensor"
assert qml.device("default.tensor").name == "default.tensor"


def test_wires():
"""Test that a device can be created with wires."""
assert qml.device("default.tensor", wires=0).wires is not None
assert qml.device("default.tensor").wires is None
assert qml.device("default.tensor", wires=2).wires == qml.wires.Wires([0, 1])
assert qml.device("default.tensor", wires=[0, 2]).wires == qml.wires.Wires([0, 2])

with pytest.raises(AttributeError):
qml.device("default.tensor", wires=0).wires = [0, 1]
qml.device("default.tensor").wires = [0, 1]


def test_wires_error():
"""Test that an error is raised if the wires are not provided."""
with pytest.raises(TypeError):
qml.device("default.tensor")
def test_wires_runtime():
"""Test that this device can execute a tape with wires determined at runtime if they are not provided."""
dev = qml.device("default.tensor")
ops = [qml.Identity(0), qml.Identity((0, 1)), qml.RX(2, 0), qml.RY(1, 5), qml.RX(2, 1)]
measurements = [qml.expval(qml.PauliZ(15))]
tape = qml.tape.QuantumScript(ops, measurements)
assert dev.execute(tape) == 1.0

with pytest.raises(TypeError):
qml.device("default.tensor", wires=None)


def test_wires_execution_error():
"""Test that this device cannot execute a tape if its wires do not match the wires on the device."""
dev = qml.device("default.tensor", wires=3)
ops = [
qml.Identity(0),
qml.Identity((0, 1)),
qml.RX(2, 0),
qml.RY(1, 5),
qml.RX(2, 1),
]

def test_wires_runtime_error():
"""Test that this device raises an error if the wires are provided by user and there is a mismatch."""
dev = qml.device("default.tensor", wires=1)
ops = [qml.Identity(0), qml.Identity((0, 1)), qml.RX(2, 0), qml.RY(1, 5), qml.RX(2, 1)]
measurements = [qml.expval(qml.PauliZ(15))]
tape = qml.tape.QuantumScript(ops, measurements)

with pytest.raises(AttributeError):
with pytest.raises(WireError):
dev.execute(tape)


Expand Down Expand Up @@ -192,7 +187,7 @@ def test_invalid_kwarg():

def test_method():
"""Test the device method."""
assert qml.device("default.tensor", wires=0).method == "mps"
assert qml.device("default.tensor").method == "mps"


def test_invalid_method():
Expand Down Expand Up @@ -272,12 +267,12 @@ class TestSupportsDerivatives:

def test_support_derivatives(self):
"""Test that the device does not support derivatives yet."""
dev = qml.device("default.tensor", wires=0)
dev = qml.device("default.tensor")
assert not dev.supports_derivatives()

def test_compute_derivatives(self):
"""Test that an error is raised if the `compute_derivatives` method is called."""
dev = qml.device("default.tensor", wires=0)
dev = qml.device("default.tensor")
with pytest.raises(
NotImplementedError,
match="The computation of derivatives has yet to be implemented for the default.tensor device.",
Expand All @@ -286,7 +281,7 @@ def test_compute_derivatives(self):

def test_execute_and_compute_derivatives(self):
"""Test that an error is raised if `execute_and_compute_derivative` method is called."""
dev = qml.device("default.tensor", wires=0)
dev = qml.device("default.tensor")
with pytest.raises(
NotImplementedError,
match="The computation of derivatives has yet to be implemented for the default.tensor device.",
Expand All @@ -295,12 +290,12 @@ def test_execute_and_compute_derivatives(self):

def test_supports_vjp(self):
"""Test that the device does not support VJP yet."""
dev = qml.device("default.tensor", wires=0)
dev = qml.device("default.tensor")
assert not dev.supports_vjp()

def test_compute_vjp(self):
"""Test that an error is raised if `compute_vjp` method is called."""
dev = qml.device("default.tensor", wires=0)
dev = qml.device("default.tensor")
with pytest.raises(
NotImplementedError,
match="The computation of vector-Jacobian product has yet to be implemented for the default.tensor device.",
Expand All @@ -309,7 +304,7 @@ def test_compute_vjp(self):

def test_execute_and_compute_vjp(self):
"""Test that an error is raised if `execute_and_compute_vjp` method is called."""
dev = qml.device("default.tensor", wires=0)
dev = qml.device("default.tensor")
with pytest.raises(
NotImplementedError,
match="The computation of vector-Jacobian product has yet to be implemented for the default.tensor device.",
Expand Down

0 comments on commit 3e3f81e

Please sign in to comment.