Skip to content

Commit

Permalink
Merge pull request #22517 from shuhand0:dev/shuhan/fixConfig
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653734332
  • Loading branch information
jax authors committed Jul 18, 2024
2 parents 6dcc497 + 21951db commit 6cc4298
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions tests/lax_metal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@ def np_fun(x, y):
dtype=float_dtypes,#number_dtypes,
)
@jax.default_matmul_precision("float32")
@jax.numpy_rank_promotion('allow') # adopt PR#22316
def testVecdot(self, lhs_batch, rhs_batch, axis_size, axis, dtype):
# Construct vecdot-compatible shapes.
size = min(len(lhs_batch), len(rhs_batch))
Expand Down

0 comments on commit 6cc4298

Please sign in to comment.