diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 0eb4f800309b..9fcc940c6c2c 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -2188,19 +2188,12 @@ def testHilbert(self, n): self._CompileAndCheck(jsp_fun, args_maker) @jtu.sample_product( - shape=[ - (128, 12), - (128, 64), - (2048, 128), - ], - dtype=[jnp.float32, jnp.float64], + shape=[(5, 1), (10, 4), (128, 12)], + dtype=float_types, symmetrize_output=[True, False], ) @jtu.skip_on_devices("tpu") def testSymmetricProduct(self, shape, dtype, symmetrize_output): - if dtype is jnp.float64 and not config.enable_x64.value: - self.skipTest("Test disabled for x32 mode") - rng = jtu.rand_default(self.rng()) batch_size = 10 atol = 1e-6 if dtype == jnp.float64 else 1e-3 @@ -2209,7 +2202,8 @@ def testSymmetricProduct(self, shape, dtype, symmetrize_output): c_shape = a_matrix.shape[:-1] + (a_matrix.shape[-2],) c_matrix = jnp.zeros(c_shape, dtype) - old_product = jnp.einsum("...ij,...kj->...ik", a_matrix, a_matrix) + old_product = jnp.einsum("...ij,...kj->...ik", a_matrix, a_matrix, + precision=lax.Precision.HIGHEST) new_product = lax_linalg.symmetric_product( a_matrix, c_matrix, symmetrize_output=symmetrize_output) new_product_with_batching = jax.vmap(