Commit

removing amplitudeembedding
KetpuntoG committed Jul 24, 2024
1 parent ce3b4b6 commit 238af73
Showing 3 changed files with 106 additions and 114 deletions.
89 changes: 76 additions & 13 deletions pennylane/ops/qubit/state_preparation.py
@@ -25,6 +25,8 @@

state_prep_ops = {"BasisState", "StatePrep", "QubitDensityMatrix"}

TOLERANCE = 1e-10


class BasisState(StatePrepBase):
r"""BasisState(n, wires)
@@ -157,6 +159,7 @@ class StatePrep(StatePrepBase):
>>> @qml.qnode(dev)
... def example_circuit():
... qml.StatePrep(np.array([1, 0, 0, 0]), wires=range(2))
... qml.StatePrep(np.array([1, 0, 0, 0]), wires=range(2))
... return qml.state()
>>> print(example_circuit())
[1.+0.j 0.+0.j 0.+0.j 0.+0.j]
@@ -169,20 +172,13 @@ class StatePrep(StatePrepBase):
ndim_params = (1,)
"""int: Number of dimensions per trainable parameter of the operator."""

-    def __init__(self, state, wires, id=None):
-        super().__init__(state, wires=wires, id=id)
-        state = self.parameters[0]
+    def __init__(self, state, wires, pad_with=None, normalize=False, id=None):

[CodeFactor notice on pennylane/ops/qubit/state_preparation.py#L175: Too many arguments (6/5) (too-many-arguments)]

-        if len(state.shape) == 1:
-            state = math.reshape(state, (1, state.shape[0]))
-        if state.shape[1] != 2 ** len(self.wires):
-            raise ValueError("State vector must have shape (2**wires,) or (batch_size, 2**wires).")
+        self.pad_with = pad_with
+        self.normalize = normalize
+        state = self._preprocess(state, wires, pad_with, normalize)

-        param = math.cast(state, np.complex128)
-        if not math.is_abstract(param):
-            norm = math.linalg.norm(param, axis=-1, ord=2)
-            if not math.allclose(norm, 1.0, atol=1e-10):
-                raise ValueError("Sum of amplitudes-squared does not equal one.")
+        super().__init__(state, wires=wires, id=id)
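
With the new keyword arguments, the padding and normalization that previously required AmplitudeEmbedding can be requested directly on StatePrep. A minimal usage sketch, assuming the pad_with/normalize behaviour shown in the diff above:

```python
import numpy as np
import pennylane as qml

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit():
    # A length-2 input is padded with zeros to length 4 and then renormalized.
    qml.StatePrep(np.array([1.0, 1.0]), wires=[0, 1], pad_with=0.0, normalize=True)
    return qml.state()

print(circuit())  # expected: [0.707..., 0.707..., 0., 0.]
```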

@staticmethod
def compute_decomposition(state, wires):
@@ -209,7 +205,8 @@ def compute_decomposition(state, wires):
return [MottonenStatePreparation(state, wires)]

def state_vector(self, wire_order=None):
-        num_op_wires = len(self.wires)
+
+        num_op_wires = len(Wires(self.wires))
op_vector_shape = (-1,) + (2,) * num_op_wires if self.batch_size else (2,) * num_op_wires
op_vector = math.reshape(self.parameters[0], op_vector_shape)

@@ -232,6 +229,72 @@ def state_vector(self, wire_order=None):
transpose_axes = [0] + [a + 1 for a in transpose_axes]
return math.transpose(op_vector, transpose_axes)

@staticmethod
def _preprocess(state, wires, pad_with, normalize):

[CodeFactor notice on pennylane/ops/qubit/state_preparation.py#L233: Too many branches (13/12) (too-many-branches)]
"""Validate and pre-process inputs as follows:
* If state is batched, the processing that follows is applied to each state set in the batch.
* Check that the state tensor is one-dimensional.
* If pad_with is None, check that the last dimension of the state tensor
has length :math:`2^n` where :math:`n` is the number of qubits. Else check that the
last dimension of the state tensor is not larger than :math:`2^n` and pad state
with value if necessary.
* If normalize is false, check that last dimension of state is normalised to one. Else, normalise the
state tensor.
"""
if isinstance(state, (list, tuple)):
state = math.array(state)
shape = math.shape(state)

# check shape
if len(shape) not in (1, 2):
raise ValueError(
f"State must be a one-dimensional tensor, or two-dimensional with batching; got shape {shape}."
)

n_states = shape[-1]
dim = 2 ** len(Wires(wires))
if pad_with is None and n_states != dim:
raise ValueError(
f"State must be of length {dim}; got length {n_states}. "
f"Use the 'pad_with' argument for automated padding."
)

if pad_with is not None:
if n_states > dim:
raise ValueError(
f"Input state must be of length {dim} or "
f"smaller to be padded; got length {n_states}."
)

# pad
if n_states < dim:
padding = [pad_with] * (dim - n_states)
if len(shape) > 1:
padding = [padding] * shape[0]
padding = math.convert_like(padding, state)
state = math.hstack([state, padding])

# normalize
if "int" in str(state.dtype):
state = math.cast_like(state, 0.0)
norm = math.linalg.norm(state, axis=-1)

if math.is_abstract(norm):
if normalize or pad_with:
state = state / math.reshape(norm, (*shape[:-1], 1))

elif not math.allclose(norm, 1.0, atol=TOLERANCE):
if normalize or pad_with:
state = state / math.reshape(norm, (*shape[:-1], 1))
else:
raise ValueError(
f"The state must be a vector of norm 1.0; got norm {norm}. "
"Use 'normalize=True' to automatically normalize."
)

return state

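
To make the padding and normalization branches of _preprocess concrete, here is a plain-NumPy sketch of the same arithmetic (assumed example values; the real method goes through qml.math so it also handles batched and abstract inputs):

```python
import numpy as np

pad_with, dim = 0.5, 4                               # 2 wires -> dim = 2**2
state = np.array([1.0, 1.0])                         # length 2 < dim, so padding applies
padded = np.hstack([state, [pad_with] * (dim - 2)])  # [1.0, 1.0, 0.5, 0.5]
norm = np.linalg.norm(padded)                        # sqrt(2.5)
state = padded / norm                                # pad_with is set, so renormalize
print(state, np.linalg.norm(state))                  # unit-norm vector
```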

# pylint: disable=missing-class-docstring
class QubitStateVector(StatePrep):
100 changes: 1 addition & 99 deletions pennylane/templates/embeddings/amplitude.py
@@ -118,102 +118,4 @@ def circuit(f=None):
"""

def __init__(self, features, wires, pad_with=None, normalize=False, id=None):
# pylint:disable=bad-super-call
wires = Wires(wires)
self.pad_with = pad_with
self.normalize = normalize
features = self._preprocess(features, wires, pad_with, normalize)
super(StatePrep, self).__init__(features, wires=wires, id=id)

@staticmethod
def compute_decomposition(
features, wires
): # pylint: disable=arguments-differ,arguments-renamed
r"""Representation of the operator as a product of other operators.
.. math:: O = O_1 O_2 \dots O_n.
.. seealso:: :meth:`~.AmplitudeEmbedding.decomposition`.
Args:
features (tensor_like): input tensor of dimension ``(2^len(wires),)``
wires (Any or Iterable[Any]): wires that the operator acts on
Returns:
list[.Operator]: decomposition of the operator
**Example**
>>> features = torch.tensor([1., 0., 0., 0.])
>>> qml.AmplitudeEmbedding.compute_decomposition(features, wires=["a", "b"])
[StatePrep(tensor([1., 0., 0., 0.]), wires=['a', 'b'])]
"""
return [StatePrep(features, wires=wires)]

@staticmethod
def _preprocess(features, wires, pad_with, normalize):
"""Validate and pre-process inputs as follows:
* If features is batched, the processing that follows is applied to each feature set in the batch.
* Check that the features tensor is one-dimensional.
* If pad_with is None, check that the last dimension of the features tensor
has length :math:`2^n` where :math:`n` is the number of qubits. Else check that the
last dimension of the features tensor is not larger than :math:`2^n` and pad features
with value if necessary.
* If normalize is false, check that last dimension of features is normalised to one. Else, normalise the
features tensor.
"""
if isinstance(features, (list, tuple)):
features = qml.math.array(features)
shape = qml.math.shape(features)

# check shape
if len(shape) not in (1, 2):
raise ValueError(
f"Features must be a one-dimensional tensor, or two-dimensional with batching; got shape {shape}."
)

n_features = shape[-1]
dim = 2 ** len(wires)
if pad_with is None and n_features != dim:
raise ValueError(
f"Features must be of length {dim}; got length {n_features}. "
f"Use the 'pad_with' argument for automated padding."
)

if pad_with is not None:
if n_features > dim:
raise ValueError(
f"Features must be of length {dim} or "
f"smaller to be padded; got length {n_features}."
)

# pad
if n_features < dim:
padding = [pad_with] * (dim - n_features)
if len(shape) > 1:
padding = [padding] * shape[0]
padding = qml.math.convert_like(padding, features)
features = qml.math.hstack([features, padding])

# normalize
if "int" in str(features.dtype):
features = qml.math.cast_like(features, 0.0)
norm = qml.math.linalg.norm(features, axis=-1)

if qml.math.is_abstract(norm):
if normalize or pad_with:
features = features / qml.math.reshape(norm, (*shape[:-1], 1))

elif not qml.math.allclose(norm, 1.0, atol=TOLERANCE):
if normalize or pad_with:
features = features / qml.math.reshape(norm, (*shape[:-1], 1))
else:
raise ValueError(
f"Features must be a vector of norm 1.0; got norm {norm}. "
"Use 'normalize=True' to automatically normalize."
)

return features
super(StatePrep, self).__init__(features, wires=wires, pad_with=pad_with, normalize=normalize, id=id)
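
The intent of this refactor is for AmplitudeEmbedding to keep its public interface while reusing StatePrep's new preprocessing. Assuming that behaviour, the two circuits below should prepare the same padded, normalized state (a sketch, not part of the diff):

```python
import numpy as np
import pennylane as qml

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def via_embedding():
    qml.AmplitudeEmbedding([1.0, 1.0], wires=[0, 1], pad_with=0.0, normalize=True)
    return qml.state()

@qml.qnode(dev)
def via_state_prep():
    qml.StatePrep(np.array([1.0, 1.0]), wires=[0, 1], pad_with=0.0, normalize=True)
    return qml.state()

assert np.allclose(via_embedding(), via_state_prep())
```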
31 changes: 29 additions & 2 deletions tests/ops/qubit/test_state_prep.py
@@ -82,6 +82,32 @@ def test_StatePrep_decomposition(self):
assert isinstance(ops1[0], qml.MottonenStatePreparation)
assert isinstance(ops2[0], qml.MottonenStatePreparation)

def test_StatePrep_padding(self):
"""Test that StatePrep pads the input state correctly."""

state = np.array([1, 0])
wires = (0, 1)

@qml.qnode(qml.device("default.qubit", wires=2))
def circuit():
qml.StatePrep(state, pad_with=0, wires=wires)
return qml.state()

assert np.allclose(circuit(), np.array([1, 0, 0, 0]))

def test_StatePrep_normalize(self):
"""Test that StatePrep normalizes the input state correctly."""

state = np.array([1, 1, 1, 1])
wires = (0, 1)

@qml.qnode(qml.device("default.qubit", wires=2))
def circuit():
qml.StatePrep(state, normalize=True, wires=wires)
return qml.state()

assert np.allclose(circuit(), np.array([1, 1, 1, 1]) / 2)
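
A combined case (a sketch following the same pattern as the tests above, not part of the diff) exercises padding and normalization together:

```python
def test_StatePrep_pad_and_normalize_sketch():
    """Hypothetical check: pad with ones, then normalize."""
    state = np.array([1, 1, 1])

    @qml.qnode(qml.device("default.qubit", wires=2))
    def circuit():
        qml.StatePrep(state, pad_with=1, normalize=True, wires=(0, 1))
        return qml.state()

    # [1, 1, 1] -> padded to [1, 1, 1, 1] -> norm 2 -> all amplitudes 0.5
    assert np.allclose(circuit(), np.array([1, 1, 1, 1]) / 2)
```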

def test_StatePrep_broadcasting(self):
"""Test broadcasting for StatePrep."""

@@ -182,12 +208,13 @@ def test_StatePrep_state_vector_bad_wire_order(self):
@pytest.mark.parametrize("vec", [[0] * 4, [1] * 4])
def test_StatePrep_state_norm_not_one_fails(self, vec):
"""Tests that the state-vector provided must have norm equal to 1."""
with pytest.raises(ValueError, match="Sum of amplitudes-squared does not equal one."):

with pytest.raises(ValueError, match="The state must be a vector of norm 1"):
_ = qml.StatePrep(vec, wires=[0, 1])

def test_StatePrep_wrong_param_size_fails(self):
"""Tests that the parameter must be of shape (2**num_wires,)."""
with pytest.raises(ValueError, match="State vector must have shape"):
with pytest.raises(ValueError, match="State must be of length"):
_ = qml.StatePrep([0, 1], wires=[0, 1])

@pytest.mark.torch
