Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade and generalise basis state preparation #6021

Merged
merged 64 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
76203d4
some changes
KetpuntoG Jul 22, 2024
d18d9dd
black
KetpuntoG Jul 22, 2024
63d219b
remove patata.py
KetpuntoG Jul 22, 2024
16d039e
updating some tests
KetpuntoG Jul 22, 2024
8127a7a
some changes
KetpuntoG Jul 22, 2024
049f3a3
check tets
KetpuntoG Jul 23, 2024
db9e453
Update test_default_qubit_legacy.py
KetpuntoG Jul 23, 2024
773e459
Update test_state_prep.py
KetpuntoG Jul 23, 2024
e60d3f0
solving bugs
KetpuntoG Jul 23, 2024
12f93bf
Merge branch 'master' into clean_state_prep
KetpuntoG Jul 23, 2024
8ebf755
spsa from master
KetpuntoG Jul 23, 2024
3841fdd
Merge branch 'master' into clean_state_prep
KetpuntoG Jul 23, 2024
3c4465a
[skip-ci]
KetpuntoG Jul 23, 2024
733c4a6
Merge branch 'master' into clean_state_prep
KetpuntoG Jul 23, 2024
989aa11
fix pylint
KetpuntoG Jul 24, 2024
99603ec
Update basis.py
KetpuntoG Jul 24, 2024
fe36eb2
Update basis.py
KetpuntoG Jul 24, 2024
5f8ba56
deprecating basisstateprep
KetpuntoG Jul 24, 2024
94ca8bf
Update basis.py
KetpuntoG Jul 24, 2024
3e5e250
warning deprecated
KetpuntoG Jul 25, 2024
9ee26bc
Update test_templates.py
KetpuntoG Jul 25, 2024
ddefdf4
Merge branch 'master' into clean_state_prep
KetpuntoG Jul 25, 2024
7ed2fd9
test correct
KetpuntoG Jul 25, 2024
7298fb2
Merge branch 'clean_state_prep' of https://github.com/PennyLaneAI/pen…
KetpuntoG Jul 25, 2024
cc3e25b
Update test_batch_input.py
KetpuntoG Jul 25, 2024
f2f3efc
commit
KetpuntoG Jul 25, 2024
a7ad726
Update test_templates.py
KetpuntoG Jul 25, 2024
2b6b670
Update doc/development/deprecations.rst
KetpuntoG Aug 12, 2024
7f8635d
Merge branch 'master' into clean_state_prep
KetpuntoG Aug 12, 2024
461454a
Update test_basis_state_prep.py
KetpuntoG Aug 12, 2024
50d8085
docs
KetpuntoG Aug 12, 2024
d1787c1
black
KetpuntoG Aug 12, 2024
032e0ae
isort
KetpuntoG Aug 12, 2024
bb65b87
Delete patata.py
KetpuntoG Aug 12, 2024
b650cbf
autouse False
KetpuntoG Aug 13, 2024
448e369
only test_templates
KetpuntoG Aug 13, 2024
08936fd
Merge branch 'master' into clean_state_prep
KetpuntoG Aug 13, 2024
0918ac7
Update deprecations.rst
KetpuntoG Aug 19, 2024
7165004
Update changelog-dev.md
KetpuntoG Aug 19, 2024
8d7c5e0
Update test_templates.py
KetpuntoG Aug 19, 2024
3b278be
Update basis.py
KetpuntoG Aug 19, 2024
4fc8b58
Update basis.py
KetpuntoG Aug 19, 2024
5cdffee
Update test_templates.py
KetpuntoG Aug 19, 2024
a3c6e15
Update test_state_prep.py
KetpuntoG Aug 19, 2024
cfdd487
Update test_state_prep.py
KetpuntoG Aug 19, 2024
cb6ba65
Update test_basis_state_prep.py
KetpuntoG Aug 19, 2024
6531ce6
Update test_batch_input.py
KetpuntoG Aug 19, 2024
51a79de
Update test_batch_params.py
KetpuntoG Aug 19, 2024
4c1e2d5
Update test_defer_measurements.py
KetpuntoG Aug 19, 2024
e810ca7
Merge branch 'master' into clean_state_prep
KetpuntoG Aug 19, 2024
cc7801b
removing BSP modifications
KetpuntoG Aug 19, 2024
e9931bc
Update state_preparation.py
KetpuntoG Aug 20, 2024
cd64dff
Update test_state_prep.py
KetpuntoG Aug 20, 2024
c9e66e9
code review comments
KetpuntoG Aug 21, 2024
05823d3
Update state_preparation.py
KetpuntoG Aug 21, 2024
bf12289
updating test msgs
KetpuntoG Aug 21, 2024
55f9941
unify changelog
KetpuntoG Aug 21, 2024
c7afd28
Merge branch 'master' into clean_state_prep
KetpuntoG Aug 21, 2024
5d52014
Update test_tape.py
KetpuntoG Aug 21, 2024
5430b27
Merge branch 'clean_state_prep' of https://github.com/PennyLaneAI/pen…
KetpuntoG Aug 21, 2024
c4cbc98
Apply suggestions from code review
KetpuntoG Aug 21, 2024
53d41db
Merge branch 'master' into clean_state_prep
KetpuntoG Aug 21, 2024
f79135a
Update test_tape.py
KetpuntoG Aug 21, 2024
e78efb4
Merge branch 'master' into clean_state_prep
KetpuntoG Aug 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/development/deprecations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ Pending deprecations
- Deprecated in v0.37
- Will be removed in v0.39

* The ``BasisStatePreparation`` template is deprecated.
Instead, use ``BasisState``.

- Deprecated in v0.38
- Will be removed in v0.39
KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved

New operator arithmetic deprecations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
15 changes: 14 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@
Instead, use `pennylane.gradients.classical_fisher` and `pennylane.gradients.quantum_fisher`.
[(#5985)](https://github.com/PennyLaneAI/pennylane/pull/5985)

* The ``BasisStatePreparation`` template is deprecated.
Instead, ``BasisState`` can be called on the constructed operator.
[(#6021)](https://github.com/PennyLaneAI/pennylane/pull/6021)

* The legacy devices `default.qubit.{autograd,torch,tf,jax,legacy}` are deprecated.
Instead, use `default.qubit` as it now supports backpropagation through the several backends.
[(#5997)](https://github.com/PennyLaneAI/pennylane/pull/5997)
Expand Down Expand Up @@ -310,7 +314,16 @@
[(#5978)](https://github.com/PennyLaneAI/pennylane/pull/5978)

* `qml.AmplitudeEmbedding` has better support for features using low precision integer data types.
[(#5969)](https://github.com/PennyLaneAI/pennylane/pull/5969)
[(#5969)](https://github.com/PennyLaneAI/pennylane/pull/5969)

* `qml.BasisState` now works with jax.jit.
[(#6021)](https://github.com/PennyLaneAI/pennylane/pull/6021)

* `qml.BasisEmbedding` now gives the correct decomposition.
[(#6021)](https://github.com/PennyLaneAI/pennylane/pull/6021)

* `qml.BasisEmbedding` now works with lightning.qubit and jax.jit.
[(#6021)](https://github.com/PennyLaneAI/pennylane/pull/6021)
soranjh marked this conversation as resolved.
Show resolved Hide resolved
KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved

* Jacobian shape is fixed for measurements with dimension in `qml.gradients.vjp.compute_vjp_single`.
[(5986)](https://github.com/PennyLaneAI/pennylane/pull/5986)
Expand Down
9 changes: 9 additions & 0 deletions pennylane/devices/tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,15 @@ def circuit():
[math.fidelity_statevector(circuit(), exp_state)], [1.0], atol=tol(dev.shots)
)

@pytest.fixture(scope="function", autouse=True)
def capture_warnings(self, recwarn):
"""Capture warnings."""
yield
if len(recwarn) > 0:
for w in recwarn:
assert isinstance(w.message, qml.PennyLaneDeprecationWarning)
assert "BasisStatePreparation is deprecated" in str(w.message)
KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved

def test_BasisStatePreparation(self, device, tol):
KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved
"""Test the BasisStatePreparation template."""
dev = device(4)
Expand Down
99 changes: 75 additions & 24 deletions pennylane/ops/qubit/state_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,23 @@
# pylint:disable=abstract-method,arguments-differ,protected-access,no-member
import numpy as np

import pennylane as qml
from pennylane import math
from pennylane.operation import AnyWires, Operation, StatePrepBase
from pennylane.templates.state_preparations import BasisStatePreparation, MottonenStatePreparation
soranjh marked this conversation as resolved.
Show resolved Hide resolved
from pennylane.templates.state_preparations import MottonenStatePreparation
from pennylane.wires import WireError, Wires

state_prep_ops = {"BasisState", "StatePrep", "QubitDensityMatrix"}


class BasisState(StatePrepBase):
r"""BasisState(n, wires)
r"""BasisState(features, wires)
KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved
Prepares a single computational basis state.

**Details:**

* Number of wires: Any (the operation can act on any number of wires)
* Number of parameters: 1
* Gradient recipe: None (integer parameters not supported)
soranjh marked this conversation as resolved.
Show resolved Hide resolved

.. note::

Expand All @@ -48,9 +48,8 @@ class BasisState(StatePrepBase):
as :math:`U|0\rangle = |\psi\rangle`

Args:
n (array): prepares the basis state :math:`\ket{n}`, where ``n`` is an
array of integers from the set :math:`\{0, 1\}`, i.e.,
if ``n = np.array([0, 1, 0])``, prepares the state :math:`|010\rangle`.
features (tensor_like): binary input of shape ``(len(wires), )``. For example, for ``features=np.array([0, 1, 0])`` or ``features=2`` (binary 010), the quantum system will be prepared in state :math:`|010 \rangle`.
KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved

wires (Sequence[int] or int): the wire(s) the operation acts on
id (str): custom label given to an operator instance,
can be useful for some applications where the instance has to be identified.
Expand All @@ -66,15 +65,51 @@ class BasisState(StatePrepBase):
[0.+0.j 0.+0.j 0.+0.j 1.+0.j]
"""

num_wires = AnyWires
num_params = 1
"""int: Number of trainable parameters that the operator depends on."""
def __init__(self, features, wires, id=None):

ndim_params = (1,)
"""int: Number of dimensions per trainable parameter of the operator."""
if isinstance(features, list):
features = qml.math.stack(features)

tracing = qml.math.is_abstract(features)

if qml.math.shape(features) == ():
KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved
if not tracing and features >= 2 ** len(wires):
raise ValueError(
f"Features must be of length {len(wires)}, got features={features} which is >= {2 ** len(wires)}"
KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved
)
bin = 2 ** math.arange(len(wires))[::-1]
features = qml.math.where((features & bin) > 0, 1, 0)

wires = Wires(wires)
shape = qml.math.shape(features)

if len(shape) != 1:
raise ValueError(f"Features must be one-dimensional; got shape {shape}.")

n_features = shape[0]
if n_features != len(wires):
raise ValueError(
f"Features must be of length {len(wires)}; got length {n_features} (features={features})."
)

if not tracing:
features_list = list(qml.math.toarray(features))
KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved
if not set(features_list).issubset({0, 1}):
raise ValueError(f"Basis state must only consist of 0s and 1s; got {features_list}")

super().__init__(features, wires=wires, id=id)

def _flatten(self):
features = self.parameters[0]
features = tuple(features) if isinstance(features, list) else features
return (features,), (self.wires,)

@classmethod
def _unflatten(cls, data, metadata) -> "BasisState":
return cls(data[0], wires=metadata[0])

@staticmethod
def compute_decomposition(n, wires):
def compute_decomposition(features, wires):
r"""Representation of the operator as a product of other operators (static method). :

.. math:: O = O_1 O_2 \dots O_n.
Expand All @@ -93,33 +128,49 @@ def compute_decomposition(n, wires):
**Example:**

>>> qml.BasisState.compute_decomposition([1,0], wires=(0,1))
[BasisStatePreparation([1, 0], wires=[0, 1])]
[X(0)]

"""
return [BasisStatePreparation(n, wires)]

if not qml.math.is_abstract(features):
op_list = []
for wire, state in zip(wires, features):
if state == 1:
op_list.append(qml.X(wire))
KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved
return op_list

op_list = []
for wire, state in zip(wires, features):
op_list.append(qml.PhaseShift(state * np.pi / 2, wire))
op_list.append(qml.RX(state * np.pi, wire))
op_list.append(qml.PhaseShift(state * np.pi / 2, wire))

return op_list

def state_vector(self, wire_order=None):
"""Returns a statevector of shape ``(2,) * num_wires``."""
prep_vals = self.parameters[0]
if any(i not in [0, 1] for i in prep_vals):
raise ValueError("BasisState parameter must consist of 0 or 1 integers.")

if (num_wires := len(self.wires)) != len(prep_vals):
raise ValueError("BasisState parameter and wires must be of equal length.")
soranjh marked this conversation as resolved.
Show resolved Hide resolved
prep_vals_int = math.cast(prep_vals, int)
KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved

prep_vals = math.cast(prep_vals, int)
if wire_order is None:
indices = prep_vals
indices = prep_vals_int
num_wires = len(indices)
else:
if not Wires(wire_order).contains_wires(self.wires):
raise WireError("Custom wire_order must contain all BasisState wires")
num_wires = len(wire_order)
indices = [0] * num_wires
for base_wire_label, value in zip(self.wires, prep_vals):
for base_wire_label, value in zip(self.wires, prep_vals_int):
indices[wire_order.index(base_wire_label)] = value

ket = np.zeros((2,) * num_wires)
ket[tuple(indices)] = 1
if qml.math.get_interface(prep_vals_int) == "jax":
ket = math.array(math.zeros((2,) * num_wires), like="jax")
ket = ket.at[tuple(indices)].set(1)

else:
ket = math.zeros((2,) * num_wires)
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
ket[tuple(indices)] = 1

return math.convert_like(ket, prep_vals)


Expand Down
105 changes: 7 additions & 98 deletions pennylane/templates/embeddings/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,15 @@
Contains the BasisEmbedding template.
"""
# pylint: disable-msg=too-many-branches,too-many-arguments,protected-access
import numpy as np

import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.wires import Wires
from pennylane.ops.qubit.state_preparation import BasisState


class BasisEmbedding(Operation):
# pylint: disable=missing-class-docstring
KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved
class BasisEmbedding(BasisState):
KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved
r"""Encodes :math:`n` binary features into a basis state of :math:`n` qubits.

For example, for ``features=np.array([0, 1, 0])`` or ``features=2`` (binary 10), the
For example, for ``features=np.array([0, 1, 0])`` or ``features=2`` (binary 010), the
quantum system will be prepared in state :math:`|010 \rangle`.

.. warning::
Expand All @@ -35,8 +33,9 @@ class BasisEmbedding(Operation):
gradients with respect to the argument cannot be computed by PennyLane.

Args:
features (tensor_like): binary input of shape ``(len(wires), )``
wires (Any or Iterable[Any]): wires that the template acts on
features (tensor_like or int): binary input of shape ``(len(wires), )`` or integer
that represents the binary input.
wires (Any or Iterable[Any]): wires that the template acts on.

Example:

Expand Down Expand Up @@ -68,93 +67,3 @@ def circuit(feature_vector):
Thus, ``[1,1,1]`` is mapped to :math:`|111 \rangle`.

"""

num_wires = AnyWires
grad_method = None

def _flatten(self):
basis_state = self.hyperparameters["basis_state"]
basis_state = tuple(basis_state) if isinstance(basis_state, list) else basis_state
return tuple(), (self.wires, basis_state)

@classmethod
def _unflatten(cls, _, metadata) -> "BasisEmbedding":
return cls(features=metadata[1], wires=metadata[0])

def __init__(self, features, wires, id=None):
if isinstance(features, list):
features = qml.math.stack(features)

tracing = qml.math.is_abstract(features)

if qml.math.shape(features) == ():
if not tracing and features >= 2 ** len(wires):
raise ValueError(
f"Features must be of length {len(wires)}, got features={features} which is >= {2 ** len(wires)}"
)
bin = 2 ** np.arange(len(wires))[::-1]
features = qml.math.where((features & bin) > 0, 1, 0)

wires = Wires(wires)
shape = qml.math.shape(features)

if len(shape) != 1:
raise ValueError(f"Features must be one-dimensional; got shape {shape}.")

n_features = shape[0]
if n_features != len(wires):
raise ValueError(
f"Features must be of length {len(wires)}; got length {n_features} (features={features})."
)

if not tracing:
features = list(qml.math.toarray(features))
if not set(features).issubset({0, 1}):
raise ValueError(f"Basis state must only consist of 0s and 1s; got {features}")

self._hyperparameters = {"basis_state": features}

super().__init__(wires=wires, id=id)

@property
def num_params(self):
return 0

@staticmethod
def compute_decomposition(wires, basis_state): # pylint: disable=arguments-differ
r"""Representation of the operator as a product of other operators.

.. math:: O = O_1 O_2 \dots O_n.



.. seealso:: :meth:`~.BasisEmbedding.decomposition`.

Args:
features (tensor-like): binary input of shape ``(len(wires), )``
wires (Any or Iterable[Any]): wires that the operator acts on

Returns:
list[.Operator]: decomposition of the operator

**Example**

>>> features = torch.tensor([1, 0, 1])
>>> qml.BasisEmbedding.compute_decomposition(features, wires=["a", "b", "c"])
[X('a'),
X('c')]
"""
if not qml.math.is_abstract(basis_state):
ops_list = []
for wire, bit in zip(wires, basis_state):
if bit == 1:
ops_list.append(qml.X(wire))
return ops_list

ops_list = []
for wire, state in zip(wires, basis_state):
ops_list.append(qml.PhaseShift(state * np.pi / 2, wire))
ops_list.append(qml.RX(state * np.pi, wire))
ops_list.append(qml.PhaseShift(state * np.pi / 2, wire))

return ops_list
11 changes: 11 additions & 0 deletions pennylane/templates/state_preparations/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
Contains the BasisStatePreparation template.
"""

import warnings

import numpy as np

import pennylane as qml
Expand All @@ -30,6 +32,8 @@ class BasisStatePreparation(Operation):
``basis_state`` influences the circuit architecture and is therefore incompatible with
gradient computations.

``BasisStatePreparation`` is deprecated and will be removed in version 0.39. Instead, please use ``BasisState``.

Args:
basis_state (array): Input array of shape ``(n,)``, where n is the number of wires
the state preparation acts on.
Expand Down Expand Up @@ -59,6 +63,13 @@ def circuit(basis_state):
ndim_params = (1,)

def __init__(self, basis_state, wires, id=None):

warnings.warn(
"BasisStatePreparation is deprecated and will be removed in version 0.39. "
"Instead, please use BasisState.",
qml.PennyLaneDeprecationWarning,
)

basis_state = qml.math.stack(basis_state)

# check if the `basis_state` param is batched
Expand Down
9 changes: 9 additions & 0 deletions tests/capture/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,15 @@ def enable_disable_plxpr():
]


@pytest.fixture(scope="function", autouse=True)
def capture_warnings(recwarn):
yield
if len(recwarn) > 0:
for w in recwarn:
assert isinstance(w.message, qml.PennyLaneDeprecationWarning)
assert "BasisStatePreparation is deprecated" in str(w.message)


@pytest.mark.parametrize("template, args, kwargs", unmodified_templates_cases)
def test_unmodified_templates(template, args, kwargs):
"""Test that templates with unmodified primitive binds are captured as expected."""
Expand Down
Loading
Loading