Skip to content

Commit

Permalink
Fix tolerances for failing linalg tests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 671881600
  • Loading branch information
dfm authored and jax authors committed Sep 6, 2024
1 parent 7266e33 commit 2ce0fc2
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2188,19 +2188,12 @@ def testHilbert(self, n):
self._CompileAndCheck(jsp_fun, args_maker)

@jtu.sample_product(
shape=[
(128, 12),
(128, 64),
(2048, 128),
],
dtype=[jnp.float32, jnp.float64],
shape=[(5, 1), (10, 4), (128, 12)],
dtype=float_types,
symmetrize_output=[True, False],
)
@jtu.skip_on_devices("tpu")
def testSymmetricProduct(self, shape, dtype, symmetrize_output):
if dtype is jnp.float64 and not config.enable_x64.value:
self.skipTest("Test disabled for x32 mode")

rng = jtu.rand_default(self.rng())
batch_size = 10
atol = 1e-6 if dtype == jnp.float64 else 1e-3
Expand All @@ -2209,7 +2202,8 @@ def testSymmetricProduct(self, shape, dtype, symmetrize_output):
c_shape = a_matrix.shape[:-1] + (a_matrix.shape[-2],)
c_matrix = jnp.zeros(c_shape, dtype)

old_product = jnp.einsum("...ij,...kj->...ik", a_matrix, a_matrix)
old_product = jnp.einsum("...ij,...kj->...ik", a_matrix, a_matrix,
precision=lax.Precision.HIGHEST)
new_product = lax_linalg.symmetric_product(
a_matrix, c_matrix, symmetrize_output=symmetrize_output)
new_product_with_batching = jax.vmap(
Expand Down

0 comments on commit 2ce0fc2

Please sign in to comment.