Skip to content

Commit

Permalink
Adding new method and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
PietropaoloFrisoni committed Jun 3, 2024
1 parent ec873db commit f3ef0ec
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 91 deletions.
82 changes: 58 additions & 24 deletions pennylane/devices/default_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
)
# The set of supported observables.

_methods = frozenset({"mps"})
_methods = frozenset({"mps", "tns"})
# The set of supported methods.


Expand Down Expand Up @@ -247,7 +247,6 @@ def circuit(theta, phi, num_qubits):

# pylint: disable=too-many-instance-attributes

# So far we just consider the options for MPS simulator
_device_options = (
"contract",
"cutoff",
Expand All @@ -272,7 +271,7 @@ def __init__(

if not accepted_methods(method):
raise ValueError(
f"Unsupported method: {method}. The only currently supported method is mps."
f"Unsupported method: {method}. Supported methods are 'mps' (Matrix Product State) and 'tns' (Exact Tensor Network)."
)

if dtype not in [np.complex64, np.complex128]:
Expand All @@ -285,16 +284,21 @@ def __init__(
self._method = method
self._dtype = dtype

# options both for MPS and TNS
# TODO: add options

# options for MPS
self._max_bond_dim = kwargs.get("max_bond_dim", None)
self._cutoff = kwargs.get("cutoff", np.finfo(self._dtype).eps)
self._contract = kwargs.get("contract", "auto-mps")

# The `quimb` state is a class attribute so that we can implement methods
# that access it as soon as the device is created without running a circuit.
# The `quimb` circuit is a class attribute so that we can implement methods
# that access it as soon as the device is created before running a circuit.
# The state is reset every time a new circuit is executed, and number of wires
# can be established at runtime to match the circuit.
self._quimb_mps = qtn.CircuitMPS(psi0=self._initial_mps(self.wires))
# can be established at runtime to match the circuit if not provided.
self._quimb_circuit = None

self._initialize_quimb_circuit(self.wires)

for arg in kwargs:
if arg not in self._device_options:
Expand All @@ -303,39 +307,50 @@ def __init__(
)

@property
def name(self):
def name(self) -> str:
"""The name of the device."""
return "default.tensor"

@property
def method(self):
def method(self) -> str:
"""Supported method."""
return self._method

@property
def dtype(self):
def dtype(self) -> type:
"""Tensor complex data type."""
return self._dtype

def _reset_mps(self, wires: qml.wires.Wires) -> None:
def _initialize_quimb_circuit(self, wires: qml.wires.Wires) -> None:
"""
Reset the MPS associated with the circuit.
Initialize the quimb circuit according to the method chosen.
Internally, it uses `quimb`'s `CircuitMPS` class.
Internally, it uses `quimb`'s `CircuitMPS` or `Circuit` class.
Args:
wires (Wires): The wires to reset the MPS.
wires (Wires): The wires to initialize the quimb circuit.
"""
self._quimb_mps = qtn.CircuitMPS(
psi0=self._initial_mps(wires),
max_bond=self._max_bond_dim,
gate_contract=self._contract,
cutoff=self._cutoff,
)

if self.method == "mps":
self._quimb_circuit = qtn.CircuitMPS(
psi0=self._initial_mps(wires),
max_bond=self._max_bond_dim,
gate_contract=self._contract,
cutoff=self._cutoff,
)

elif self.method == "tns":
self._quimb_circuit = qtn.Circuit(
psi0=self._initial_tns(wires),
# TODO: add options for TNS
)

else:
raise NotImplementedError # pragma: no cover

def _initial_mps(self, wires: qml.wires.Wires) -> "qtn.MatrixProductState":
r"""
Return an initial state to :math:`\ket{0}`.
Return an initial mps to :math:`\ket{0}`.
Internally, it uses `quimb`'s `MPS_computational_state` method.
Expand All @@ -351,6 +366,23 @@ def _initial_mps(self, wires: qml.wires.Wires) -> "qtn.MatrixProductState":
tags=[str(l) for l in wires.labels] if wires else None,
)

def _initial_tns(self, wires: qml.wires.Wires) -> "qtn.TensorNetwork":
r"""
Return an initial tensor network state to :math:`\ket{0}`.
Internally, it uses `quimb`'s `TN_from_sites_computational_state` method.
Args:
wires (Wires): The wires to initialize the tensor network.
Returns:
TensorNetwork: The initial tensor network of a circuit.
"""
return qtn.TN_from_sites_computational_state(
site_map={i: "0" for i in range(len(wires) if wires else 1)},
dtype=self._dtype.__name__,
)

def _setup_execution_config(
self, config: Optional[ExecutionConfig] = DefaultExecutionConfig
) -> ExecutionConfig:
Expand Down Expand Up @@ -449,7 +481,7 @@ def simulate(self, circuit: QuantumScript) -> Result:

wires = circuit.wires if self.wires is None else self.wires

self._reset_mps(wires)
self._initialize_quimb_circuit(wires)

for op in circuit.operations:
self._apply_operation(op)
Expand All @@ -470,7 +502,9 @@ def _apply_operation(self, op: qml.operation.Operator) -> None:
op (Operator): The operation to apply.
"""

self._quimb_mps.apply_gate(qml.matrix(op).astype(self._dtype), *op.wires, parametrize=None)
self._quimb_circuit.apply_gate(
qml.matrix(op).astype(self._dtype), *op.wires, parametrize=None
)

def measurement(self, measurementprocess: MeasurementProcess) -> TensorLike:
"""Measure the measurement required by the circuit over the MPS.
Expand Down Expand Up @@ -552,7 +586,7 @@ def _local_expectation(self, matrix, wires) -> float:
"""

# We need to copy the MPS since `local_expectation` modifies the state
qc = copy.deepcopy(self._quimb_mps)
qc = copy.deepcopy(self._quimb_circuit)

exp_val = qc.local_expectation(
matrix,
Expand Down
41 changes: 27 additions & 14 deletions tests/devices/default_tensor/test_default_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,21 @@ def test_wires_runtime_error():
@pytest.mark.parametrize("max_bond_dim", [None, 10])
@pytest.mark.parametrize("cutoff", [1e-16, 1e-12])
@pytest.mark.parametrize("contract", ["auto-mps", "nonlocal"])
def test_kwargs(max_bond_dim, cutoff, contract):
@pytest.mark.parametrize("method", ["mps", "tns"])
def test_kwargs(method, max_bond_dim, cutoff, contract):
"""Test the class initialization with different arguments and returned properties."""

kwargs = {"max_bond_dim": max_bond_dim, "cutoff": cutoff, "contract": contract}
kwargs = {
"method": method,
"max_bond_dim": max_bond_dim,
"cutoff": cutoff,
"contract": contract,
}

dev = qml.device("default.tensor", wires=0, **kwargs)

_, config = dev.preprocess()
assert config.device_options["method"] == "mps"
assert config.device_options["method"] == method
assert config.device_options["max_bond_dim"] == max_bond_dim
assert config.device_options["cutoff"] == cutoff
assert config.device_options["contract"] == contract
Expand All @@ -185,9 +191,10 @@ def test_invalid_kwarg():
qml.device("default.tensor", wires=0, fake_arg=None)


def test_method():
@pytest.mark.parametrize("method", ["mps", "tns"])
def test_method(method):
"""Test the device method."""
assert qml.device("default.tensor").method == "mps"
assert qml.device("default.tensor", method=method).method == method


def test_invalid_method():
Expand All @@ -209,14 +216,15 @@ def test_ivalid_data_type():
qml.device("default.tensor", wires=0, dtype=float)


@pytest.mark.parametrize("method", ["mps", "tns"])
class TestSupportedGatesAndObservables:
"""Test that the DefaultTensor device supports all gates and observables that it claims to support."""

@pytest.mark.parametrize("operation", all_ops)
def test_supported_gates_can_be_implemented(self, operation):
def test_supported_gates_can_be_implemented(self, operation, method):
"""Test that the device can implement all its supported gates."""

dev = qml.device("default.tensor", wires=4, method="mps")
dev = qml.device("default.tensor", wires=4, method=method)

tape = qml.tape.QuantumScript(
[operations_list[operation]],
Expand All @@ -227,10 +235,10 @@ def test_supported_gates_can_be_implemented(self, operation):
assert np.allclose(result, 1.0)

@pytest.mark.parametrize("observable", all_obs)
def test_supported_observables_can_be_implemented(self, observable):
def test_supported_observables_can_be_implemented(self, observable, method):
"""Test that the device can implement all its supported observables."""

dev = qml.device("default.tensor", wires=3, method="mps")
dev = qml.device("default.tensor", wires=3, method=method)

if observable == "Projector":
for o in observables_list[observable]:
Expand All @@ -249,7 +257,7 @@ def test_supported_observables_can_be_implemented(self, observable):
result = dev.execute(circuits=tape)
assert isinstance(result, (float, np.ndarray))

def test_not_implemented_meas(self):
def test_not_implemented_meas(self, method):
"""Tests that support only exists for `qml.expval` and `qml.var` so far."""

op = [qml.Identity(0)]
Expand Down Expand Up @@ -311,12 +319,17 @@ def test_execute_and_compute_vjp(self):
):
dev.execute_and_compute_vjp(circuits=None, cotangents=None)


@pytest.mark.parametrize("method", ["mps", "tns"])
class TestJaxSupport:
"""Test the JAX support for the DefaultTensor device."""

@pytest.mark.jax
def test_jax(self):
def test_jax(self, method):
"""Test the device with JAX."""

jax = pytest.importorskip("jax")
dev = qml.device("default.tensor", wires=1)
dev = qml.device("default.tensor", wires=1, method=method)
ref_dev = qml.device("default.qubit.jax", wires=1)

def circuit(x):
Expand All @@ -331,11 +344,11 @@ def circuit(x):
assert np.allclose(qnode(weights), ref_qnode(weights))

@pytest.mark.jax
def test_jax_jit(self):
def test_jax_jit(self, method):
"""Test the device with JAX's JIT compiler."""

jax = pytest.importorskip("jax")
dev = qml.device("default.tensor", wires=1)
dev = qml.device("default.tensor", wires=1, method=method)

@jax.jit
@qml.qnode(dev, interface="jax")
Expand Down
Loading

0 comments on commit f3ef0ec

Please sign in to comment.