Skip to content

Commit

Permalink
Merge branch 'master' into plxpr-capture-measurements
Browse files Browse the repository at this point in the history
  • Loading branch information
albi3ro committed May 14, 2024
2 parents 3c8ed40 + 57e8d93 commit 391d77e
Show file tree
Hide file tree
Showing 25 changed files with 324 additions and 118 deletions.
22 changes: 19 additions & 3 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@

<h3>Improvements 🛠</h3>

* The sorting order of parameter-shift terms is now guaranteed to resolve ties in the absolute value with the sign of the shifts.
[(#5582)](https://github.com/PennyLaneAI/pennylane/pull/5582)

<h4>Mid-circuit measurements and dynamic circuits</h4>

* The `dynamic_one_shot` transform uses a single auxiliary tape with a shot vector and `default.qubit` implements the loop over shots with `jax.vmap`.
[(#5617)](https://github.com/PennyLaneAI/pennylane/pull/5617)

* The `dynamic_one_shot` transform can be compiled with `jax.jit`.
[(#5557)](https://github.com/PennyLaneAI/pennylane/pull/5557)

Expand Down Expand Up @@ -84,7 +90,7 @@

* ``qml.from_qasm_file`` has been removed. The user can open files and load their content using `qml.from_qasm`.
[(#5659)](https://github.com/PennyLaneAI/pennylane/pull/5659)

* ``qml.load`` has been removed in favour of more specific functions, such as ``qml.from_qiskit``, etc.
[(#5654)](https://github.com/PennyLaneAI/pennylane/pull/5654)

Expand All @@ -94,9 +100,19 @@

<h3>Bug fixes 🐛</h3>

* `param_shift`, `finite_diff`, `compile`, `merge_rotations`, and `transpile` now all work
with circuits with non-commuting measurements.
* Use vanilla NumPy arrays in `test_projector_expectation` to avoid differentiating `qml.Projector` with respect to the state attribute.
[(#5683)](https://github.com/PennyLaneAI/pennylane/pull/5683)

* `qml.Projector` is now compatible with jax-jit.
[(#5595)](https://github.com/PennyLaneAI/pennylane/pull/5595)

* Finite shot circuits with a `qml.probs` measurement, both with a `wires` or `op` argument, can now be compiled with `jax.jit`.
[(#5619)](https://github.com/PennyLaneAI/pennylane/pull/5619)

* `param_shift`, `finite_diff`, `compile`, `insert`, `merge_rotations`, and `transpile` now
all work with circuits with non-commuting measurements.
[(#5424)](https://github.com/PennyLaneAI/pennylane/pull/5424)
[(#5681)](https://github.com/PennyLaneAI/pennylane/pull/5681)

* A correction is added to `bravyi_kitaev` to call the correct function for a FermiSentence input.
[(#5671)](https://github.com/PennyLaneAI/pennylane/pull/5671)
Expand Down
15 changes: 13 additions & 2 deletions pennylane/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,19 @@ def execute(self, circuit, **kwargs):
self.check_validity(circuit.operations, circuit.observables)

has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations)
if has_mcm:
kwargs["mid_measurements"] = {}
if has_mcm and "mid_measurements" not in kwargs:
results = []
aux_circ = qml.tape.QuantumScript(
circuit.operations,
circuit.measurements,
shots=[1],
trainable_params=circuit.trainable_params,
)
for _ in circuit.shots:
kwargs["mid_measurements"] = {}
self.reset()
results.append(self.execute(aux_circ, **kwargs))
return tuple(results)
# apply all circuit operations
self.apply(
circuit.operations,
Expand Down
4 changes: 2 additions & 2 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def qfunc(a):
(where ``cls`` indicates the class) if:
* The operator does not accept wires, like :class:`~.SymbolicOp` or :class:`~.CompositeOp`.
* The operator allows metadata to be provided positionally, like :class:`~.PauliRot`.
* The operator needs to enforce a data/ metadata distinction, like :class:`~.PauliRot`.
In such cases, the operator developer can override ``cls._primitive_bind_call``, which
will be called when constructing a new class instance instead of ``type.__call__``. For example,
Expand All @@ -78,7 +78,7 @@ def qfunc(a):
class JustMetadataOp(qml.operation.Operator):
def __init__(self, metadata="X"):
def __init__(self, metadata):
super().__init__(wires=[])
self._metadata = metadata
Expand Down
1 change: 0 additions & 1 deletion pennylane/capture/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from functools import lru_cache
from typing import Callable, Optional

import pennylane as qml

has_jax = True
Expand Down
28 changes: 26 additions & 2 deletions pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,33 @@ def simulate(

has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations)
if circuit.shots and has_mcm:
return simulate_one_shot_native_mcm(
circuit, debugger=debugger, rng=rng, prng_key=prng_key, interface=interface
results = []
aux_circ = qml.tape.QuantumScript(
circuit.operations,
circuit.measurements,
shots=[1],
trainable_params=circuit.trainable_params,
)
keys = jax_random_split(prng_key, num=circuit.shots.total_shots)
if qml.math.get_deep_interface(circuit.data) == "jax":
# pylint: disable=import-outside-toplevel
import jax

def simulate_partial(k):
return simulate_one_shot_native_mcm(
aux_circ, debugger=debugger, rng=rng, prng_key=k, interface=interface
)

results = jax.vmap(simulate_partial, in_axes=(0,))(keys)
results = tuple(zip(*results))
else:
for i in range(circuit.shots.total_shots):
results.append(
simulate_one_shot_native_mcm(
aux_circ, debugger=debugger, rng=rng, prng_key=keys[i], interface=interface
)
)
return tuple(results)

ops_key, meas_key = jax_random_split(prng_key)
state, is_state_batched = get_final_state(
Expand Down
10 changes: 6 additions & 4 deletions pennylane/devices/tests/test_compare_default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests that a device gives the same output as the default device."""
import numpy as np

# pylint: disable=no-self-use,no-member
import pytest
from flaky import flaky
Expand Down Expand Up @@ -90,10 +92,10 @@ def workload():
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
pnp.array([1, 1, 0, 0]) / pnp.sqrt(2),
pnp.array([0, 1, 0, 1]) / pnp.sqrt(2),
pnp.array([1, 1, 1, 0]) / pnp.sqrt(3),
pnp.array([1, 1, 1, 1]) / 2,
np.array([1, 1, 0, 0]) / np.sqrt(2),
np.array([0, 1, 0, 1]) / np.sqrt(2),
np.array([1, 1, 1, 0]) / np.sqrt(3),
np.array([1, 1, 1, 1]) / 2,
],
)
def test_projector_expectation(self, device, state, tol, benchmark):
Expand Down
6 changes: 4 additions & 2 deletions pennylane/gradients/general_shift_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def process_shifts(rule, tol=1e-10, batch_duplicates=True):
- Finally, the terms are sorted according to the absolute value of ``shift``,
This ensures that a zero-shift term, if it exists, is returned first.
For equal absolute values of two shifts, the positive shift is sorted to come first.
"""
# set all small coefficients, multipliers if present, and shifts to zero.
rule[np.abs(rule) < tol] = 0
Expand All @@ -78,8 +79,9 @@ def process_shifts(rule, tol=1e-10, batch_duplicates=True):
coeffs = [np.sum(rule[slc, 0]) for slc in matches.T]
rule = np.hstack([np.stack(coeffs)[:, np.newaxis], unique_mods])

# sort columns according to abs(shift)
return rule[np.argsort(np.abs(rule[:, -1]), kind="stable")]
# sort columns according to abs(shift), ties are resolved with the sign,
# positive shifts being returned before negative shifts.
return rule[np.lexsort((-np.sign(rule[:, -1]), np.abs(rule[:, -1])))]


@functools.lru_cache(maxsize=None)
Expand Down
1 change: 1 addition & 0 deletions pennylane/math/single_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def _take_autograd(tensor, indices, axis=None):
ar.autoray._SUBMODULE_ALIASES["tensorflow", "isclose"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "atleast_1d"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "all"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "vstack"] = "tensorflow.experimental.numpy"

tf_fft_functions = [
"fft",
Expand Down
40 changes: 25 additions & 15 deletions pennylane/measurements/probs.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,23 +263,33 @@ def _count_samples(indices, batch_size, dim):
"""Count the occurrences of sampled indices and convert them to relative
counts in order to estimate their occurrence probability."""
num_bins, bin_size = indices.shape[-2:]
if batch_size is None:
prob = qml.math.zeros((dim, num_bins), dtype="float64")
# count the basis state occurrences, and construct the probability vector for each bin
for b, idx in enumerate(indices):
basis_states, counts = qml.math.unique(idx, return_counts=True)
prob[basis_states, b] = counts / bin_size
interface = qml.math.get_deep_interface(indices)

return prob
if qml.math.is_abstract(indices):

prob = qml.math.zeros((batch_size, dim, num_bins), dtype="float64")
indices = indices.reshape((batch_size, num_bins, bin_size))
def _count_samples_core(indices, dim, interface):
return qml.math.array(
[[qml.math.sum(idx == p) for idx in indices] for p in range(dim)],
like=interface,
)

else:

def _count_samples_core(indices, dim, *_):
probabilities = qml.math.zeros((dim, num_bins), dtype="float64")
for b, idx in enumerate(indices):
basis_states, counts = qml.math.unique(idx, return_counts=True)
probabilities[basis_states, b] = counts
return probabilities

if batch_size is None:
return _count_samples_core(indices, dim, interface) / bin_size

# count the basis state occurrences, and construct the probability vector
# for each bin and broadcasting index
for i, _indices in enumerate(indices): # First iterate over broadcasting dimension
for b, idx in enumerate(_indices): # Then iterate over bins dimension
basis_states, counts = qml.math.unique(idx, return_counts=True)
prob[i, basis_states, b] = counts / bin_size

return prob
indices = indices.reshape((batch_size, num_bins, bin_size))
probabilities = qml.math.array(
[_count_samples_core(_indices, dim, interface) for _indices in indices],
like=interface,
)
return probabilities / bin_size
5 changes: 3 additions & 2 deletions pennylane/ops/op_math/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,9 @@ def eigvals(self):
list(self.wires),
)
)

return self._math_op(math.asarray(eigvals, like=math.get_deep_interface(eigvals)), axis=0)
framework = math.get_deep_interface(eigvals)
eigvals = [math.asarray(ei, like=framework) for ei in eigvals]
return self._math_op(math.vstack(eigvals), axis=0)

@abc.abstractmethod
def matrix(self, wire_order=None):
Expand Down
64 changes: 53 additions & 11 deletions pennylane/ops/qubit/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,14 +435,23 @@ class BasisStateProjector(Projector, Operation):
r"""Observable corresponding to the state projector :math:`P=\ket{\phi}\bra{\phi}`, where
:math:`\phi` denotes a basis state."""

grad_method = None

# The call signature should be the same as Projector.__new__ for the positional
# arguments, but with free key word arguments.
def __init__(self, state, wires, id=None):
wires = Wires(wires)
state = list(qml.math.toarray(state).astype(int))

if not set(state).issubset({0, 1}):
raise ValueError(f"Basis state must only consist of 0s and 1s; got {state}")
if qml.math.get_interface(state) == "jax":
dtype = qml.math.dtype(state)
if not (np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, bool)):
raise ValueError("Basis state must consist of integers or booleans.")
else:
# state is index into data, rather than data, so cast it to built-ins when
# no need for tracing
state = tuple(qml.math.toarray(state).astype(int))
if not set(state).issubset({0, 1}):
raise ValueError(f"Basis state must only consist of 0s and 1s; got {state}")

super().__init__(state, wires=wires, id=id)

Expand Down Expand Up @@ -471,7 +480,7 @@ def label(self, decimals=None, base_label=None, cache=None):

if base_label is not None:
return base_label
basis_string = "".join(str(int(i)) for i in self.parameters[0])
basis_string = "".join(str(int(i)) for i in self.data[0])
return f"|{basis_string}⟩⟨{basis_string}|"

@staticmethod
Expand All @@ -497,7 +506,15 @@ def compute_matrix(basis_state): # pylint: disable=arguments-differ
[0. 0. 0. 0.]
[0. 0. 0. 0.]]
"""
m = np.zeros((2 ** len(basis_state), 2 ** len(basis_state)))
shape = (2 ** len(basis_state), 2 ** len(basis_state))
if qml.math.get_interface(basis_state) == "jax":
idx = 0
for i, m in enumerate(basis_state):
idx = idx + (m << (len(basis_state) - i - 1))
mat = qml.math.zeros(shape, like=basis_state)
return mat.at[idx, idx].set(1.0)

m = np.zeros(shape)
idx = int("".join(str(i) for i in basis_state), 2)
m[idx, idx] = 1
return m
Expand Down Expand Up @@ -528,6 +545,12 @@ def compute_eigvals(basis_state): # pylint: disable=arguments-differ
>>> BasisStateProjector.compute_eigvals([0, 1])
[0. 1. 0. 0.]
"""
if qml.math.get_interface(basis_state) == "jax":
idx = 0
for i, m in enumerate(basis_state):
idx = idx + (m << (len(basis_state) - i - 1))
eigvals = qml.math.zeros(2 ** len(basis_state), like=basis_state)
return eigvals.at[idx].set(1.0)
w = np.zeros(2 ** len(basis_state))
idx = int("".join(str(i) for i in basis_state), 2)
w[idx] = 1
Expand Down Expand Up @@ -566,12 +589,12 @@ class StateVectorProjector(Projector):
r"""Observable corresponding to the state projector :math:`P=\ket{\phi}\bra{\phi}`, where
:math:`\phi` denotes a state."""

grad_method = None

# The call signature should be the same as Projector.__new__ for the positional
# arguments, but with free key word arguments.
def __init__(self, state, wires, id=None):
wires = Wires(wires)
state = list(qml.math.toarray(state))

super().__init__(state, wires=wires, id=id)

def __new__(cls, *_, **__): # pylint: disable=arguments-differ
Expand Down Expand Up @@ -680,9 +703,11 @@ def compute_eigvals(state_vector): # pylint: disable=arguments-differ,arguments
>>> StateVectorProjector.compute_eigvals([0, 0, 1, 0])
array([1, 0, 0, 0])
"""
w = qml.math.zeros_like(state_vector)
precision = qml.math.get_dtype_name(state_vector)[-2:]
dtype = f"float{precision}" if precision in {"32", "64"} else "float64"
w = np.zeros(qml.math.shape(state_vector), dtype=dtype)
w[0] = 1
return w
return qml.math.convert_like(w, state_vector)

@staticmethod
def compute_diagonalizing_gates(
Expand Down Expand Up @@ -715,10 +740,27 @@ def compute_diagonalizing_gates(
# Adapting the approach discussed in the link below to work with arbitrary complex-valued state vectors.
# Alternatively, we could take the adjoint of the Mottonen decomposition for the state vector.
# https://quantumcomputing.stackexchange.com/questions/10239/how-can-i-fill-a-unitary-knowing-only-its-first-column
phase = qml.math.exp(-1j * qml.math.angle(state_vector[0]))

if qml.math.get_interface(state_vector) == "tensorflow":
dtype_name = qml.math.get_dtype_name(state_vector)
if dtype_name == "int32":
state_vector = qml.math.cast(state_vector, np.complex64)
elif dtype_name == "int64":
state_vector = qml.math.cast(state_vector, np.complex128)

angle = qml.math.angle(state_vector[0])
if qml.math.get_interface(angle) == "tensorflow":
if qml.math.get_dtype_name(angle) == "float32":
angle = qml.math.cast(angle, np.complex64)
else:
angle = qml.math.cast(angle, np.complex128)

phase = qml.math.exp(-1.0j * angle)
psi = phase * state_vector
denominator = qml.math.sqrt(2 + 2 * psi[0])
psi = qml.math.set_index(psi, 0, psi[0] + 1) # psi[0] += 1, but JAX-JIT compatible
summed_array = np.zeros(qml.math.shape(psi), dtype=qml.math.get_dtype_name(psi))
summed_array[0] = 1.0
psi = psi + summed_array
psi /= denominator
u = 2 * qml.math.outer(psi, qml.math.conj(psi)) - qml.math.eye(len(psi))
return [QubitUnitary(u, wires=wires)]
Loading

0 comments on commit 391d77e

Please sign in to comment.