From 08d286f8a80c6d2d9d5c291c289dd521adb4f036 Mon Sep 17 00:00:00 2001 From: Guillermo Alonso-Linaje <65235481+KetpuntoG@users.noreply.github.com> Date: Fri, 19 Apr 2024 13:30:34 -0400 Subject: [PATCH] Fixing kron compatibility `numpy @ torch` (#5540) Fixes https://github.com/PennyLaneAI/pennylane/issues/5542 **Relevant Shortcut Stories:** [sc-61616] --------- Co-authored-by: Jay Soni --- doc/releases/changelog-dev.md | 3 +++ pennylane/math/multi_dispatch.py | 8 ++++++++ tests/math/test_multi_dispatch.py | 11 +++++++++++ 3 files changed, 22 insertions(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 780c64a2a25..b993ec77087 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -461,6 +461,9 @@ * Fixes a bug in `hamiltonian_expand` that produces incorrect output dimensions when shot vectors are combined with parameter broadcasting. [(#5494)](https://github.com/PennyLaneAI/pennylane/pull/5494) +* Fixes a bug in `qml.math.kron` that makes torch incompatible with numpy. + [(#5540)](https://github.com/PennyLaneAI/pennylane/pull/5540) +

Contributors ✍️

This release contains contributions from (in alphabetical order): diff --git a/pennylane/math/multi_dispatch.py b/pennylane/math/multi_dispatch.py index 2c16ebe98e5..acf89bb4152 100644 --- a/pennylane/math/multi_dispatch.py +++ b/pennylane/math/multi_dispatch.py @@ -160,6 +160,14 @@ def kron(*args, like=None, **kwargs): """The kronecker/tensor product of args.""" if like == "scipy": return onp.kron(*args, **kwargs) # Dispatch scipy kron to numpy backed specifically. + + if like == "torch": + mats = [ + ar.numpy.asarray(arg, like="torch") if isinstance(arg, onp.ndarray) else arg + for arg in args + ] + return ar.numpy.kron(*mats) + return ar.numpy.kron(*args, like=like, **kwargs) diff --git a/tests/math/test_multi_dispatch.py b/tests/math/test_multi_dispatch.py index ac253f7c206..3c56c21f8a5 100644 --- a/tests/math/test_multi_dispatch.py +++ b/tests/math/test_multi_dispatch.py @@ -197,6 +197,17 @@ def test_dot_autograd(): assert fn.allclose(qml_grad(fn.dot)(x, y), x) +def test_kron(): + """Test the kronecker product function.""" + x = torch.tensor([[1, 2], [3, 4]]) + y = np.array([[0, 5], [6, 7]]) + + res = fn.kron(x, y) + expected = torch.tensor([[0, 5, 0, 10], [6, 7, 12, 14], [0, 15, 0, 20], [18, 21, 24, 28]]) + + assert fn.allclose(res, expected) + + class TestMatmul: @pytest.mark.torch def test_matmul_torch(self):