Skip to content

Commit

Permalink
Create projectors in any basis (#4192)
Browse files Browse the repository at this point in the history
* First version of StateVectorProjector

* Outer product fix

* Support projector in default qubit

* change inheritance

* consistent diagonalizing gates

* Update docstring

* Proper bra-ket in label

* Tests for StateVectorProjector

* fix tests

* Projector prototype

* Explicit signature and pow method

* discard hacky prototype

* New hack (thanks Tymmy <3)

* Exception tests

* Docstring for __new__

* Label method adds matrix to cache

* tests for label method

* Update docstrings

* Update changelog

* enhanced docstring for label method

* `Projector.__new__` docstring rephrased

* Fix docstring

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>

* Remove boolean kwarg

* Remove `basis_representation` from changelog

* Fix wire length issue

* Update Projector example in docstring

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>

* Update changelog description

Co-authored-by: Tom Bromley <49409390+trbromley@users.noreply.github.com>

* Update projector description

Co-authored-by: Tom Bromley <49409390+trbromley@users.noreply.github.com>

* change argument names to be `state`

* Remove shape indication in docstring

Co-authored-by: Tom Bromley <49409390+trbromley@users.noreply.github.com>

* fix expval with state vector projector

* Code example and remove hidden class docstring

* minor docstring corrections

* add expval test (it was failing)

* projector bind new parameters dispatcher

* Update error string

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>

* fix projector and qubit device

* extensive projector testing

* remove outdated test

* fix tests typo

* Additional indications on input shape

Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>

* revision feedback

* update copy

---------

Co-authored-by: = <=>
Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
Co-authored-by: Tom Bromley <49409390+trbromley@users.noreply.github.com>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
  • Loading branch information
4 people authored Jun 16, 2023
1 parent f42fb92 commit e102fe2
Show file tree
Hide file tree
Showing 12 changed files with 722 additions and 192 deletions.
20 changes: 20 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,26 @@

<h3>Improvements 🛠</h3>

* `Projector` now accepts a state vector representation, which enables the creation of projectors
in any basis.
[(#4192)](https://github.com/PennyLaneAI/pennylane/pull/4192)

```python
dev = qml.device("default.qubit", wires=2)
@qml.qnode(dev)
def circuit(state):
return qml.expval(qml.Projector(state, wires=[0, 1]))
zero_state = [0, 0]
plusplus_state = np.array([1, 1, 1, 1]) / 2
```
```pycon
>>> circuit(zero_state)
1.
>>>
>>> circuit(plusplus_state)
0.25
```

* The pulse differentiation methods, `pulse_generator` and `stoch_pulse_grad` now raise an error when they
are applied to a `QNode` directly. Instead, use differentiation via a JAX entry point (`jax.grad`, `jax.jacobian`, ...).
[(4241)](https://github.com/PennyLaneAI/pennylane/pull/4241)
Expand Down
14 changes: 9 additions & 5 deletions pennylane/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def execute(self, circuit, **kwargs):
* :meth:`~.probability`
Additional keyword arguments may be passed to the this method
Additional keyword arguments may be passed to this method
that can be utilised by :meth:`apply`. An example would be passing
the ``QNode`` hash that can be used later for parametric compilation.
Expand Down Expand Up @@ -1649,8 +1649,10 @@ def marginal_prob(self, prob, wires=None):
return self._reshape(prob, flat_shape)

def expval(self, observable, shot_range=None, bin_size=None):
if observable.name == "Projector":
# branch specifically to handle the projector observable
if observable.name == "Projector" and len(observable.parameters[0]) == len(
observable.wires
):
# branch specifically to handle the basis state projector observable
idx = int("".join(str(i) for i in observable.parameters[0]), 2)
probs = self.probability(
wires=observable.wires, shot_range=shot_range, bin_size=bin_size
Expand Down Expand Up @@ -1679,8 +1681,10 @@ def expval(self, observable, shot_range=None, bin_size=None):
return np.squeeze(np.mean(samples, axis=axis))

def var(self, observable, shot_range=None, bin_size=None):
if observable.name == "Projector":
# branch specifically to handle the projector observable
if observable.name == "Projector" and len(observable.parameters[0]) == len(
observable.wires
):
# branch specifically to handle the basis state projector observable
idx = int("".join(str(i) for i in observable.parameters[0]), 2)
probs = self.probability(
wires=observable.wires, shot_range=shot_range, bin_size=bin_size
Expand Down
22 changes: 19 additions & 3 deletions pennylane/devices/tests/test_compare_default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,23 @@ def circuit(theta, phi):
assert np.allclose(qnode(theta, phi), qnode_def(theta, phi), atol=tol(dev.shots))
assert np.allclose(grad(theta, phi), grad_def(theta, phi), atol=tol(dev.shots))

@pytest.mark.parametrize("state", [[0, 0], [0, 1], [1, 0], [1, 1]])
@pytest.mark.parametrize(
"state",
[
[0, 0],
[0, 1],
[1, 0],
[1, 1],
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
np.array([1, 1, 0, 0]) / np.sqrt(2),
np.array([0, 1, 0, 1]) / np.sqrt(2),
np.array([1, 1, 1, 0]) / np.sqrt(3),
np.array([1, 1, 1, 1]) / 2,
],
)
def test_projector_expectation(self, device, state, tol):
"""Test that arbitrary multi-mode Projector expectation values are correct"""
n_wires = 2
Expand All @@ -88,11 +104,11 @@ def test_projector_expectation(self, device, state, tol):
theta = 0.432
phi = 0.123

def circuit(theta, phi, basis_state):
def circuit(theta, phi, state):
qml.RX(theta, wires=[0])
qml.RX(phi, wires=[1])
qml.CNOT(wires=[0, 1])
return qml.expval(qml.Projector(basis_state, wires=[0, 1]))
return qml.expval(qml.Projector(state, wires=[0, 1]))

qnode_def = qml.QNode(circuit, dev_def)
qnode = qml.QNode(circuit, dev)
Expand Down
Loading

0 comments on commit e102fe2

Please sign in to comment.