Skip to content

Commit

Permalink
Test if the new test cases fails
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaka14732 committed Aug 28, 2024
1 parent b56ed8e commit b3615a9
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,18 @@ def kernel(x_ref, o_ref):
x = jnp.array([0.42, 2.4]).astype(dtype)
np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)

def test_abs_weak_type(self):
# see https://github.com/google/jax/issues/23191
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32),
)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.abs(x_ref[...])

x = jnp.array([-3.2], dtype=jnp.float32)
x.aval.weak_type = True # manually set `weak_type` to `True`
np.testing.assert_allclose(kernel(x), jnp.abs(x), rtol=1e-6)

@parameterized.parameters(
("float32", "int32"),
("float64", "int32"),
Expand Down

0 comments on commit b3615a9

Please sign in to comment.