Skip to content

Commit

Permalink
Migrated pl.BlockSpec uses in jax_triton to the new argument order
Browse files Browse the repository at this point in the history
See jax-ml/jax#22209.

PiperOrigin-RevId: 648671020
  • Loading branch information
superbobry authored and The jax_triton Authors committed Jul 2, 2024
1 parent 893b538 commit 7965cfb
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/block_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def matmul(x, y, *, block_shape):
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]),
x.dtype),
in_specs=[
pl.BlockSpec(lambda i, j: (i, 0), (l, x.shape[1])),
pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], r))
pl.BlockSpec((l, x.shape[1]), lambda i, j: (i, 0)),
pl.BlockSpec((y.shape[0], r), lambda i, j: (0, j))
],
out_specs=pl.BlockSpec(lambda i, j: (i, j), (l, r)),
out_specs=pl.BlockSpec((l, r), lambda i, j: (i, j)),
grid=(x.shape[0] // l, y.shape[1] // r),
debug=True)(x, y)

Expand Down

0 comments on commit 7965cfb

Please sign in to comment.