diff --git a/doc/_static/templates/qrom/qrom_thumbnail.png b/doc/_static/templates/qrom/qrom_thumbnail.png
new file mode 100644
index 00000000000..977465fc9c9
Binary files /dev/null and b/doc/_static/templates/qrom/qrom_thumbnail.png differ
diff --git a/doc/introduction/templates.rst b/doc/introduction/templates.rst
index 1cda874fcbc..2969f5168b2 100644
--- a/doc/introduction/templates.rst
+++ b/doc/introduction/templates.rst
@@ -303,6 +303,11 @@ Other useful templates which do not belong to the previous categories can be fou
:description: :doc:`Qubitization <../code/api/pennylane.Qubitization>`
:figure: _static/templates/qubitization/thumbnail_qubitization.png
+.. gallery-item::
+ :description: :doc:`QROM <../code/api/pennylane.QROM>`
+ :figure: _static/templates/qrom/qrom_thumbnail.png
+
+
.. raw:: html
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 6596be92a6d..e57878e54e9 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -4,6 +4,33 @@
New features since last release
+* QROM template is added. This template allows you to enter classic data in the form of bitstrings.
+ [(#5688)](https://github.com/PennyLaneAI/pennylane/pull/5688)
+
+ ```python
+ # a list of bitstrings is defined
+ bitstrings = ["010", "111", "110", "000"]
+
+ dev = qml.device("default.qubit", shots = 1)
+
+ @qml.qnode(dev)
+ def circuit():
+
+ # the third index is encoded in the control wires [0, 1]
+ qml.BasisEmbedding(2, wires = [0,1])
+
+ qml.QROM(bitstrings = bitstrings,
+ control_wires = [0,1],
+ target_wires = [2,3,4],
+ work_wires = [5,6,7])
+
+ return qml.sample(wires = [2,3,4])
+ ```
+ ```pycon
+ >>> print(circuit())
+ [1 1 0]
+ ```
+
* `qml.QNode` and `qml.qnode` now accept two new keyword arguments: `postselect_mode` and `mcm_method`.
These keyword arguments can be used to configure how the device should behave when running circuits with
mid-circuit measurements.
@@ -20,6 +47,7 @@
* The `default.tensor` device is introduced to perform tensor network simulation of a quantum circuit.
[(#5699)](https://github.com/PennyLaneAI/pennylane/pull/5699)
+
Improvements ðŸ›
* The wires for the `default.tensor` device are selected at runtime if they are not provided by user.
@@ -308,6 +336,7 @@
This release contains contributions from (in alphabetical order):
+Guillermo Alonso-Linaje,
Lillian M. A. Frederiksen,
Gabriel Bottrill,
Astral Cai,
diff --git a/pennylane/templates/subroutines/__init__.py b/pennylane/templates/subroutines/__init__.py
index 0e5a8b2caac..b0acee37f5f 100644
--- a/pennylane/templates/subroutines/__init__.py
+++ b/pennylane/templates/subroutines/__init__.py
@@ -43,3 +43,4 @@
from .reflection import Reflection
from .amplitude_amplification import AmplitudeAmplification
from .qubitization import Qubitization
+from .qrom import QROM
diff --git a/pennylane/templates/subroutines/qrom.py b/pennylane/templates/subroutines/qrom.py
new file mode 100644
index 00000000000..38c0f99ddc0
--- /dev/null
+++ b/pennylane/templates/subroutines/qrom.py
@@ -0,0 +1,295 @@
+# Copyright 2018-2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This submodule contains the template for QROM.
+"""
+
+import math
+
+import numpy as np
+
+import pennylane as qml
+from pennylane.operation import Operation
+
+
+def _multi_swap(wires1, wires2):
+ """Apply a series of SWAP gates between two sets of wires."""
+ for wire1, wire2 in zip(wires1, wires2):
+ qml.SWAP(wires=[wire1, wire2])
+
+
+class QROM(Operation):
+ r"""Applies the QROM operator.
+
+ This operator encodes bitstrings associated with indexes:
+
+ .. math::
+ \text{QROM}|i\rangle|0\rangle = |i\rangle |b_i\rangle,
+
+ where :math:`b_i` is the bitstring associated with index :math:`i`.
+
+ Args:
+ bitstrings (list[str]): the bitstrings to be encoded
+ control_wires (Sequence[int]): the wires where the indexes are specified
+ target_wires (Sequence[int]): the wires where the bitstring is loaded
+ work_wires (Sequence[int]): the auxiliary wires used for the computation
+ clean (bool): if True, the work wires are not altered by operator, default is ``True``
+
+ **Example**
+
+ In this example, the QROM operator is applied to encode the third bitstring, associated with index 2, in the target wires.
+
+ .. code-block::
+
+ # a list of bitstrings is defined
+ bitstrings = ["010", "111", "110", "000"]
+
+ dev = qml.device("default.qubit", shots = 1)
+
+ @qml.qnode(dev)
+ def circuit():
+
+ # the third index is encoded in the control wires [0, 1]
+ qml.BasisEmbedding(2, wires = [0,1])
+
+ qml.QROM(bitstrings = bitstrings,
+ control_wires = [0,1],
+ target_wires = [2,3,4],
+ work_wires = [5,6,7])
+
+ return qml.sample(wires = [2,3,4])
+
+ .. code-block:: pycon
+
+ >>> print(circuit())
+ [1 1 0]
+
+
+ .. details::
+ :title: Usage Details
+
+ This template takes as input three different sets of wires. The first one is ``control_wires`` which is used
+ to encode the desired index. Therefore, if we have :math:`m` bitstrings, we need
+ at least :math:`\lceil \log_2(m)\rceil` control wires.
+
+ The second set of wires is ``target_wires`` which stores the bitstrings.
+ For instance, if the bitstring is "0110", we will need four target wires. Internally, the bitstrings are
+ encoded using the :class:`~.BasisEmbedding` template.
+
+
+ The ``work_wires`` are the auxiliary qubits used by the template to reduce the number of gates required.
+ Let :math:`k` be the number of work wires. If :math:`k = 0`, the template is equivalent to executing :class:`~.Select`.
+ Following the idea in [`arXiv:1812.00954 `__], auxiliary qubits can be used to
+ load more than one bitstring in parallel . Let :math:`\lambda` be
+ the number of bitstrings we want to store in parallel, assumed to be a power of :math:`2`.
+ Then, :math:`k = l \cdot (\lambda-1)` work wires are needed,
+ where :math:`l` is the length of the bitstrings.
+
+ The QROM template has two variants. The first one (``clean = False``) is based on [`arXiv:1812.00954 `__] that alterates the state in the ``work_wires``.
+ The second one (``clean = True``), based on [`arXiv:1902.02134 `__], solves that issue by
+ returning ``work_wires`` to their initial state. This technique can be applied when the ``work_wires`` are not
+ initialized to zero.
+
+ """
+
+ def __init__(
+ self, bitstrings, control_wires, target_wires, work_wires, clean=True, id=None
+ ): # pylint: disable=too-many-arguments
+
+ control_wires = qml.wires.Wires(control_wires)
+ target_wires = qml.wires.Wires(target_wires)
+
+ work_wires = qml.wires.Wires(work_wires) if work_wires else qml.wires.Wires([])
+
+ self.hyperparameters["bitstrings"] = tuple(bitstrings)
+ self.hyperparameters["control_wires"] = control_wires
+ self.hyperparameters["target_wires"] = target_wires
+ self.hyperparameters["work_wires"] = work_wires
+ self.hyperparameters["clean"] = clean
+
+ if work_wires:
+ if any(wire in work_wires for wire in control_wires):
+ raise ValueError("Control wires should be different from work wires.")
+
+ if any(wire in work_wires for wire in target_wires):
+ raise ValueError("Target wires should be different from work wires.")
+
+ if any(wire in control_wires for wire in target_wires):
+ raise ValueError("Target wires should be different from control wires.")
+
+ if 2 ** len(control_wires) < len(bitstrings):
+ raise ValueError(
+ f"Not enough control wires ({len(control_wires)}) for the desired number of "
+ + f"bitstrings ({len(bitstrings)}). At least {int(math.ceil(math.log2(len(bitstrings))))} control "
+ + "wires are required."
+ )
+
+ if len(bitstrings[0]) != len(target_wires):
+ raise ValueError("Bitstring length must match the number of target wires.")
+
+ all_wires = target_wires + control_wires + work_wires
+ super().__init__(wires=all_wires, id=id)
+
+ def _flatten(self):
+ metadata = tuple((key, value) for key, value in self.hyperparameters.items())
+ return tuple(), metadata
+
+ @classmethod
+ def _unflatten(cls, data, metadata):
+ hyperparams_dict = dict(metadata)
+ return cls(**hyperparams_dict)
+
+ def __repr__(self):
+ return f"QROM(control_wires={self.control_wires}, target_wires={self.target_wires}, work_wires={self.work_wires}, clean={self.clean})"
+
+ def map_wires(self, wire_map: dict):
+ new_dict = {
+ key: [wire_map.get(w, w) for w in self.hyperparameters[key]]
+ for key in ["target_wires", "control_wires", "work_wires"]
+ }
+
+ return QROM(
+ self.bitstrings,
+ new_dict["control_wires"],
+ new_dict["target_wires"],
+ new_dict["work_wires"],
+ self.clean,
+ )
+
+ def __copy__(self):
+ """Copy this op"""
+ cls = self.__class__
+ copied_op = cls.__new__(cls)
+
+ for attr, value in vars(self).items():
+ setattr(copied_op, attr, value)
+
+ return copied_op
+
+ def decomposition(self): # pylint: disable=arguments-differ
+
+ return self.compute_decomposition(
+ self.bitstrings,
+ control_wires=self.control_wires,
+ target_wires=self.target_wires,
+ work_wires=self.work_wires,
+ clean=self.clean,
+ )
+
+ @staticmethod
+ def compute_decomposition(
+ bitstrings, control_wires, target_wires, work_wires, clean
+ ): # pylint: disable=arguments-differ
+ with qml.QueuingManager.stop_recording():
+
+ swap_wires = target_wires + work_wires
+
+ # number of operators we store per column (power of 2)
+ depth = len(swap_wires) // len(target_wires)
+ depth = int(2 ** np.floor(np.log2(depth)))
+
+ ops = [qml.BasisEmbedding(int(bits, 2), wires=target_wires) for bits in bitstrings]
+ ops_identity = ops + [qml.I(target_wires)] * int(2 ** len(control_wires) - len(ops))
+
+ n_columns = len(ops) // depth if len(ops) % depth == 0 else len(ops) // depth + 1
+ new_ops = []
+ for i in range(n_columns):
+ column_ops = []
+ for j in range(depth):
+ dic_map = {
+ ops_identity[i * depth + j].wires[l]: swap_wires[j * len(target_wires) + l]
+ for l in range(len(target_wires))
+ }
+ column_ops.append(qml.map_wires(ops_identity[i * depth + j], dic_map))
+ new_ops.append(qml.prod(*column_ops))
+
+ # Select block
+ n_control_select_wires = int(math.ceil(math.log2(2 ** len(control_wires) / depth)))
+ control_select_wires = control_wires[:n_control_select_wires]
+
+ select_ops = []
+ if control_select_wires:
+ select_ops += [qml.Select(new_ops, control=control_select_wires)]
+ else:
+ select_ops = new_ops
+
+ # Swap block
+ control_swap_wires = control_wires[n_control_select_wires:]
+ swap_ops = []
+ for ind in range(len(control_swap_wires)):
+ for j in range(2**ind):
+ new_op = qml.prod(_multi_swap)(
+ swap_wires[(j) * len(target_wires) : (j + 1) * len(target_wires)],
+ swap_wires[
+ (j + 2**ind)
+ * len(target_wires) : (j + 2 ** (ind + 1))
+ * len(target_wires)
+ ],
+ )
+ swap_ops.insert(0, qml.ctrl(new_op, control=control_swap_wires[-ind - 1]))
+
+ if not clean:
+ # Based on this paper (Fig 1.c): https://arxiv.org/abs/1812.00954
+ decomp_ops = select_ops + swap_ops
+
+ else:
+ # Based on this paper (Fig 4): https://arxiv.org/abs/1902.02134
+ adjoint_swap_ops = swap_ops[::-1]
+ hadamard_ops = [qml.Hadamard(wires=w) for w in target_wires]
+
+ decomp_ops = 2 * (hadamard_ops + adjoint_swap_ops + select_ops + swap_ops)
+
+ if qml.QueuingManager.recording():
+ for op in decomp_ops:
+ qml.apply(op)
+
+ return decomp_ops
+
+ @classmethod
+ def _primitive_bind_call(cls, *args, **kwargs):
+ return cls._primitive.bind(*args, **kwargs)
+
+ @property
+ def bitstrings(self):
+ """bitstrings to be added."""
+ return self.hyperparameters["bitstrings"]
+
+ @property
+ def control_wires(self):
+ """The control wires."""
+ return self.hyperparameters["control_wires"]
+
+ @property
+ def target_wires(self):
+ """The wires where the bitstring is loaded."""
+ return self.hyperparameters["target_wires"]
+
+ @property
+ def work_wires(self):
+ """The wires where the index is specified."""
+ return self.hyperparameters["work_wires"]
+
+ @property
+ def wires(self):
+ """All wires involved in the operation."""
+ return (
+ self.hyperparameters["control_wires"]
+ + self.hyperparameters["target_wires"]
+ + self.hyperparameters["work_wires"]
+ )
+
+ @property
+ def clean(self):
+ """Boolean to select the version of QROM."""
+ return self.hyperparameters["clean"]
diff --git a/tests/capture/test_templates.py b/tests/capture/test_templates.py
index 06a8f3139ed..b7fb8c60e00 100644
--- a/tests/capture/test_templates.py
+++ b/tests/capture/test_templates.py
@@ -258,6 +258,7 @@ def fn(*args):
qml.MERA,
qml.MPS,
qml.TTN,
+ qml.QROM,
]
@@ -652,6 +653,41 @@ def qfunc():
assert len(q) == 1
assert qml.equal(q.queue[0], qml.Qubitization(**kwargs))
+ @pytest.mark.usefixtures("new_opmath_only")
+ def test_qrom(self):
+ """Test the primitive bind call of QROM."""
+
+ kwargs = {
+ "bitstrings": ["0", "1"],
+ "control_wires": [0],
+ "target_wires": [1],
+ "work_wires": None,
+ }
+
+ def qfunc():
+ qml.QROM(**kwargs)
+
+ # Validate inputs
+ qfunc()
+
+ # Actually test primitive bind
+ jaxpr = jax.make_jaxpr(qfunc)()
+
+ assert len(jaxpr.eqns) == 1
+
+ eqn = jaxpr.eqns[0]
+ assert eqn.primitive == qml.QROM._primitive
+ assert eqn.invars == jaxpr.jaxpr.invars
+ assert eqn.params == kwargs
+ assert len(eqn.outvars) == 1
+ assert isinstance(eqn.outvars[0], jax.core.DropVar)
+
+ with qml.queuing.AnnotatedQueue() as q:
+ jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts)
+
+ assert len(q) == 1
+ assert qml.equal(q.queue[0], qml.QROM(**kwargs))
+
@pytest.mark.parametrize(
"template, kwargs",
[
diff --git a/tests/templates/test_subroutines/test_qrom.py b/tests/templates/test_subroutines/test_qrom.py
new file mode 100644
index 00000000000..21d60f28aaa
--- /dev/null
+++ b/tests/templates/test_subroutines/test_qrom.py
@@ -0,0 +1,253 @@
+# Copyright 2018-2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests for the QROM template.
+"""
+
+import pytest
+
+import pennylane as qml
+from pennylane import numpy as np
+
+
+def test_assert_valid_qrom():
+ """Run standard validity tests."""
+ bitstrings = ["000", "001", "111", "011", "000", "101", "110", "111"]
+
+ op = qml.QROM(bitstrings, control_wires=[0, 1, 2], target_wires=[3, 4, 5], work_wires=[6, 7, 8])
+ qml.ops.functions.assert_valid(op)
+
+
+class TestQROM:
+ """Test the qml.QROM template."""
+
+ @pytest.mark.parametrize(
+ ("bitstrings", "target_wires", "control_wires", "work_wires", "clean"),
+ [
+ (
+ ["11", "01", "00", "10"],
+ [0, 1],
+ [2, 3],
+ [4, 5],
+ True,
+ ),
+ (
+ ["11", "01", "00", "10"],
+ [0, 1],
+ [2, 3],
+ [4, 5, 6, 7, 8, 9],
+ True,
+ ),
+ (
+ ["01", "01", "00", "00"],
+ ["a", "b"],
+ [2, 3],
+ [4, 5, 6],
+ False,
+ ),
+ (
+ ["111", "001", "000", "100"],
+ [0, 1, "b"],
+ [2, 3],
+ None,
+ False,
+ ),
+ (
+ ["1111", "0101", "0100", "1010"],
+ [0, 1, "b", "d"],
+ [2, 3],
+ ["a", 5, 6, 7],
+ True,
+ ),
+ ],
+ )
+ def test_operation_result(
+ self, bitstrings, target_wires, control_wires, work_wires, clean
+ ): # pylint: disable=too-many-arguments
+ """Test the correctness of the QROM template output."""
+ dev = qml.device("default.qubit", shots=1)
+
+ @qml.qnode(dev)
+ def circuit(j):
+ qml.BasisEmbedding(j, wires=control_wires)
+
+ qml.QROM(bitstrings, control_wires, target_wires, work_wires, clean)
+ return qml.sample(wires=target_wires)
+
+ for j in range(2 ** len(control_wires)):
+ assert np.allclose(circuit(j), [int(bit) for bit in bitstrings[j]])
+
+ @pytest.mark.parametrize(
+ ("bitstrings", "target_wires", "control_wires", "work_wires"),
+ [
+ (
+ ["11", "01", "00", "10"],
+ [0, 1],
+ [2, 3],
+ [4, 5],
+ ),
+ (
+ ["01", "01", "00", "00"],
+ ["a", "b"],
+ [2, 3],
+ [4, 5, 6],
+ ),
+ (
+ ["111", "001", "000", "100"],
+ [0, 1, "b"],
+ [2, 3],
+ ["a", 5, 6],
+ ),
+ (
+ ["1111", "0101", "0100", "1010"],
+ [0, 1, "b", "d"],
+ [2, 3],
+ ["a", 5, 6, 7],
+ ),
+ ],
+ )
+ def test_work_wires_output(self, bitstrings, target_wires, control_wires, work_wires):
+ """Tests that the ``clean = True`` version don't modify the initial state in work_wires."""
+ dev = qml.device("default.qubit", shots=1)
+
+ @qml.qnode(dev)
+ def circuit():
+
+ # Initialize the work wires to a non-zero state
+ for ind, wire in enumerate(work_wires):
+ qml.RX(ind, wires=wire)
+
+ for wire in control_wires:
+ qml.Hadamard(wires=wire)
+
+ qml.QROM(bitstrings, control_wires, target_wires, work_wires)
+
+ for ind, wire in enumerate(work_wires):
+ qml.RX(-ind, wires=wire)
+
+ return qml.probs(wires=work_wires)
+
+ assert np.isclose(circuit()[0], 1.0)
+
+ def test_decomposition(self):
+ """Test that compute_decomposition and decomposition work as expected."""
+ qrom_decomposition = qml.QROM(
+ ["1", "0", "0", "1"], control_wires=[0, 1], target_wires=[2], work_wires=[3], clean=True
+ ).decomposition()
+
+ expected_gates = [
+ qml.Hadamard(wires=[2]),
+ qml.CSWAP(wires=[1, 2, 3]),
+ qml.Select(
+ ops=(
+ qml.BasisEmbedding(1, wires=[2]) @ qml.BasisEmbedding(0, wires=[3]),
+ qml.BasisEmbedding(0, wires=[2]) @ qml.BasisEmbedding(1, wires=[3]),
+ ),
+ control=[0],
+ ),
+ qml.CSWAP(wires=[1, 2, 3]),
+ qml.Hadamard(wires=[2]),
+ qml.CSWAP(wires=[1, 2, 3]),
+ qml.Select(
+ ops=(
+ qml.BasisEmbedding(1, wires=[2]) @ qml.BasisEmbedding(0, wires=[3]),
+ qml.BasisEmbedding(0, wires=[2]) @ qml.BasisEmbedding(1, wires=[3]),
+ ),
+ control=0,
+ ),
+ qml.CSWAP(wires=[1, 2, 3]),
+ ]
+
+ assert all(qml.equal(op1, op2) for op1, op2 in zip(qrom_decomposition, expected_gates))
+
+ @pytest.mark.jax
+ def test_jit_compatible(self):
+ """Test that the template is compatible with the JIT compiler."""
+
+ import jax
+
+ jax.config.update("jax_enable_x64", True)
+
+ dev = qml.device("default.qubit", wires=4)
+
+ @jax.jit
+ @qml.qnode(dev)
+ def circuit():
+ qml.QROM(["1", "0", "0", "1"], control_wires=[0, 1], target_wires=[2], work_wires=[3])
+ return qml.probs(wires=3)
+
+ assert jax.numpy.allclose(circuit(), jax.numpy.array([1.0, 0.0]))
+
+
+@pytest.mark.parametrize(
+ ("control_wires", "target_wires", "work_wires", "msg_match"),
+ [
+ (
+ [0, 1, 2],
+ [0, 3],
+ [4, 5],
+ "Target wires should be different from control wires.",
+ ),
+ (
+ [0, 1, 2],
+ [4],
+ [2, 5],
+ "Control wires should be different from work wires.",
+ ),
+ (
+ [0, 1, 2],
+ [4],
+ [4],
+ "Target wires should be different from work wires.",
+ ),
+ ],
+)
+def test_wires_error(control_wires, target_wires, work_wires, msg_match):
+ """Test an error is raised when a control wire is in one of the ops"""
+ with pytest.raises(ValueError, match=msg_match):
+ qml.QROM(["1"] * 8, control_wires, target_wires, work_wires)
+
+
+def test_repr():
+ """Test that the __repr__ method works as expected."""
+
+ op = qml.QROM(
+ ["1", "0", "0", "1"], control_wires=[0, 1], target_wires=[2], work_wires=[3], clean=True
+ )
+ res = op.__repr__()
+ expected = "QROM(control_wires=, target_wires=, work_wires=, clean=True)"
+ assert res == expected
+
+
+@pytest.mark.parametrize(
+ ("bitstrings", "control_wires", "target_wires", "msg_match"),
+ [
+ (
+ ["1", "0", "0", "1"],
+ [0],
+ [2],
+ r"Not enough control wires \(1\) for the desired number of bitstrings \(4\). At least 2 control wires are required.",
+ ),
+ (
+ ["1", "0", "0", "1"],
+ [0, 1],
+ [2, 3],
+ r"Bitstring length must match the number of target wires.",
+ ),
+ ],
+)
+def test_wrong_wires_error(bitstrings, control_wires, target_wires, msg_match):
+ """Test that error is raised if more ops are requested than can fit in control wires"""
+ with pytest.raises(ValueError, match=msg_match):
+ qml.QROM(bitstrings, control_wires, target_wires, work_wires=None)