Skip to content

Commit

Permalink
Test if the new test cases fails
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaka14732 committed Aug 28, 2024
1 parent b56ed8e commit 252fc07
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,17 @@ def kernel(x_ref, o_ref):
x = jnp.array([0.42, 2.4]).astype(dtype)
np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)

def test_abs_weak_type(self):
# see https://github.com/google/jax/issues/23191
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32),
)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.abs(x_ref[...])

x = jnp.broadcast_to(-3.2, (4, 4)) # sets `weak_type` to `True`
np.testing.assert_allclose(kernel(x), jnp.abs(x), rtol=1e-6)

@parameterized.parameters(
("float32", "int32"),
("float64", "int32"),
Expand Down

0 comments on commit 252fc07

Please sign in to comment.