diff --git a/strawberryfields/utils/random_numbers_matrices.py b/strawberryfields/utils/random_numbers_matrices.py index 5062f2780..6b37f07af 100644 --- a/strawberryfields/utils/random_numbers_matrices.py +++ b/strawberryfields/utils/random_numbers_matrices.py @@ -110,12 +110,10 @@ def random_interferometer(N, real=False): Returns: array: random :math:`N\times N` unitary distributed with the Haar measure """ + if N == 1: + if real: + return np.array([[2 * (np.random.binomial(1, 0.5) - 0.5)]]) + return np.array([[np.exp(1j * 2 * np.pi * np.random.rand())]]) if real: - z = np.random.randn(N, N) - else: - z = randnc(N, N) / np.sqrt(2.0) - q, r = sp.linalg.qr(z) - d = np.diagonal(r) - ph = d / np.abs(d) - U = np.multiply(q, ph, q) - return U + return sp.stats.ortho_group.rvs(N) + return sp.stats.unitary_group.rvs(N) diff --git a/tests/frontend/test_utils.py b/tests/frontend/test_utils.py index 05562f02c..06c0dd575 100644 --- a/tests/frontend/test_utils.py +++ b/tests/frontend/test_utils.py @@ -267,14 +267,10 @@ def test_odd_cat_state(self, a, cutoff, tol): # =================================================================================== +@pytest.mark.parametrize("modes", [1, 2, 3]) class TestRandomMatrices: """Unit tests for random matrices""" - @pytest.fixture - def modes(self): - """Number of modes to use when creating matrices""" - return 3 - @pytest.mark.parametrize("pure_state", [True, False]) @pytest.mark.parametrize("block_diag", [True, False]) def test_random_covariance_square(self, modes, hbar, pure_state, block_diag):