diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index bd4e5f0ffc3..9d88744ff53 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -144,6 +144,9 @@
Bug fixes 🐛
+* The decomposition of `StronglyEntanglingLayers` is now compatible with broadcasting.
+ [(#5716)](https://github.com/PennyLaneAI/pennylane/pull/5716)
+
* `qml.cond` can now be applied to `ControlledOp` operations when deferring measurements.
[(#5725)](https://github.com/PennyLaneAI/pennylane/pull/5725)
diff --git a/pennylane/templates/layers/strongly_entangling.py b/pennylane/templates/layers/strongly_entangling.py
index fd07f05ac23..554f100e8f5 100644
--- a/pennylane/templates/layers/strongly_entangling.py
+++ b/pennylane/templates/layers/strongly_entangling.py
@@ -200,7 +200,7 @@ def compute_decomposition(
CNOT(wires=['a', 'a']),
CNOT(wires=['b', 'b'])]
"""
- n_layers = qml.math.shape(weights)[0]
+ n_layers = qml.math.shape(weights)[-3]
wires = qml.wires.Wires(wires)
op_list = []
diff --git a/tests/templates/test_layers/test_strongly_entangling.py b/tests/templates/test_layers/test_strongly_entangling.py
index d02bbe67e5c..ed9839b840b 100644
--- a/tests/templates/test_layers/test_strongly_entangling.py
+++ b/tests/templates/test_layers/test_strongly_entangling.py
@@ -36,6 +36,7 @@ def test_standard_validity():
qml.ops.functions.assert_valid(op)
+@pytest.mark.parametrize("batch_dim", [None, 1, 3])
class TestDecomposition:
"""Tests that the template defines the correct decomposition."""
@@ -57,23 +58,39 @@ class TestDecomposition:
]
@pytest.mark.parametrize("n_wires, weight_shape, expected_names, expected_wires", QUEUES)
- def test_expansion(self, n_wires, weight_shape, expected_names, expected_wires):
+ def test_expansion(self, n_wires, weight_shape, expected_names, expected_wires, batch_dim):
"""Checks the queue for the default settings."""
+ # pylint: disable=too-many-arguments
+ if batch_dim is not None:
+ weight_shape = (batch_dim,) + weight_shape
weights = np.random.random(size=weight_shape)
op = qml.StronglyEntanglingLayers(weights, wires=range(n_wires))
tape = op.expand()
+ if batch_dim is None:
+ param_sets = iter(weights.reshape((-1, 3)))
+ else:
+ param_sets = iter(weights.reshape((batch_dim, -1, 3)).transpose([1, 2, 0]))
+
for i, gate in enumerate(tape.operations):
assert gate.name == expected_names[i]
+ if gate.name == "Rot":
+ assert gate.batch_size == batch_dim
+ assert qml.math.allclose(gate.data, next(param_sets))
+ else:
+ assert gate.batch_size is None
assert gate.wires.labels == tuple(expected_wires[i])
@pytest.mark.parametrize("n_layers, n_wires", [(2, 2), (1, 3), (2, 4)])
- def test_uses_correct_imprimitive(self, n_layers, n_wires):
+ def test_uses_correct_imprimitive(self, n_layers, n_wires, batch_dim):
"""Test that correct number of entanglers are used in the circuit."""
- weights = np.random.randn(n_layers, n_wires, 3)
+ shape = (n_layers, n_wires, 3)
+ if batch_dim is not None:
+ shape = (batch_dim,) + shape
+ weights = np.random.randn(*shape)
op = qml.StronglyEntanglingLayers(weights=weights, wires=range(n_wires), imprimitive=qml.CZ)
ops = op.expand().operations
@@ -81,9 +98,10 @@ def test_uses_correct_imprimitive(self, n_layers, n_wires):
gate_names = [gate.name for gate in ops]
assert gate_names.count("CZ") == n_wires * n_layers
- def test_custom_wire_labels(self, tol):
+ def test_custom_wire_labels(self, tol, batch_dim):
"""Test that template can deal with non-numeric, nonconsecutive wire labels."""
- weights = np.random.random(size=(1, 3, 3))
+ shape = (1, 3, 3) if batch_dim is None else (batch_dim, 1, 3, 3)
+ weights = np.random.random(size=shape)
dev = qml.device("default.qubit", wires=3)
dev2 = qml.device("default.qubit", wires=["z", "a", "k"])
@@ -107,10 +125,13 @@ def circuit2():
@pytest.mark.parametrize(
"n_layers, n_wires, ranges", [(2, 2, [1, 1]), (1, 3, [2]), (4, 4, [2, 3, 1, 3])]
)
- def test_custom_range_sequence(self, n_layers, n_wires, ranges):
+ def test_custom_range_sequence(self, n_layers, n_wires, ranges, batch_dim):
"""Test that correct sequence of custom ranges are used in the circuit."""
- weights = np.random.randn(n_layers, n_wires, 3)
+ shape = (n_layers, n_wires, 3)
+ if batch_dim is not None:
+ shape = (batch_dim,) + shape
+ weights = np.random.randn(*shape)
op = qml.StronglyEntanglingLayers(weights=weights, wires=range(n_wires), ranges=ranges)
ops = op.expand().operations