From 21951db84ae6cccf80b684dc93682d39c5c4fcb1 Mon Sep 17 00:00:00 2001 From: Shuhan Ding Date: Thu, 18 Jul 2024 10:57:43 -0700 Subject: [PATCH] adopt numpy rank promotion config in metal test --- tests/lax_metal_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index dab26d86c0a2..02fecb7b3f1a 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -578,6 +578,7 @@ def np_fun(x, y): dtype=float_dtypes,#number_dtypes, ) @jax.default_matmul_precision("float32") + @jax.numpy_rank_promotion('allow') # adopt PR#22316 def testVecdot(self, lhs_batch, rhs_batch, axis_size, axis, dtype): # Construct vecdot-compatible shapes. size = min(len(lhs_batch), len(rhs_batch))