Skip to content

Commit

Permalink
[Pallas TPU] Fix itemsize check for int4 in bitcast lowering.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668637257
  • Loading branch information
bythew3i authored and jax authors committed Aug 28, 2024
1 parent 2785a08 commit 2c11a91
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
4 changes: 3 additions & 1 deletion jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2546,7 +2546,9 @@ def _bitcast_convert_type_lowering_rule(
ctx: LoweringRuleContext, x, *, new_dtype):
(in_aval, ) = ctx.avals_in
(out_aval,) = ctx.avals_out
if in_aval.dtype.itemsize != new_dtype.itemsize:
old_bitwidth = pallas_utils.dtype_bitwidth(in_aval.dtype)
new_bitwidth = pallas_utils.dtype_bitwidth(new_dtype)
if old_bitwidth != new_bitwidth:
raise NotImplementedError("Changing bitwidths not supported.")
return tpu.BitcastOp(aval_to_ir_type(out_aval), x).result
lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule
Expand Down
21 changes: 14 additions & 7 deletions jax/_src/pallas/mosaic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.state import indexing
from jax._src.state import primitives as sp
from jax._src.interpreters import mlir
from jax._src.pallas import core as pl_core
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as sp
from jax._src.typing import DTypeLike
import jax.numpy as jnp

Expand Down Expand Up @@ -65,7 +66,9 @@ def bitcast(x, ty: DTypeLike):
ty = dtypes.canonicalize_dtype(ty)
if len(x.shape) < 2:
raise ValueError("Not implemented: bitcast 1D")
if x.shape[-2] * x.dtype.itemsize % ty.itemsize:
src_bitwidth = pallas_utils.dtype_bitwidth(x.dtype)
dst_bitwidth = pallas_utils.dtype_bitwidth(ty)
if x.shape[-2] * src_bitwidth % dst_bitwidth:
raise ValueError(
"Not implemented: the 2nd minor dim can not be perfectly packed or"
" unpacked"
Expand All @@ -76,19 +79,23 @@ def bitcast(x, ty: DTypeLike):
@bitcast_p.def_abstract_eval
def _bitcast_abstract_eval(x, *, ty):
shape = list(x.shape)
shape[-2] = shape[-2] * x.dtype.itemsize // ty.itemsize
src_bitwidth = pallas_utils.dtype_bitwidth(x.dtype)
dst_bitwidth = pallas_utils.dtype_bitwidth(ty)
shape[-2] = shape[-2] * src_bitwidth // dst_bitwidth
return jax_core.ShapedArray(shape, ty)


def _bitcast_lowering_rule(ctx: mlir.LoweringRuleContext, x, *, ty):
def _bitcast(x):
if x.dtype.itemsize < ty.itemsize:
src_bitwidth = pallas_utils.dtype_bitwidth(x.dtype)
dst_bitwidth = pallas_utils.dtype_bitwidth(ty)
if src_bitwidth < dst_bitwidth:
*leading, m, n = x.shape
packing = ty.itemsize // x.dtype.itemsize
packing = dst_bitwidth // src_bitwidth
x = x.reshape(*leading, m // packing, packing, n)
x = jnp.swapaxes(x, -1, -2)
return jax.lax.bitcast_convert_type(x, ty)
if x.dtype.itemsize > ty.itemsize:
if src_bitwidth > dst_bitwidth:
y = jax.lax.bitcast_convert_type(x, ty)
*leading, m, n, packing = y.shape
return jnp.swapaxes(y, -1, -2).reshape(*leading, m * packing, n)
Expand Down
4 changes: 4 additions & 0 deletions jax/_src/pallas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def next_power_of_2(x: int) -> int:
raise ValueError("`next_power_of_2` requires a non-negative integer.")
return 1 if x == 0 else 2 ** (x - 1).bit_length()

def dtype_bitwidth(dtype: np.dtype | jnp.dtype) -> int:
if isinstance(dtype, jnp.integer):
return jnp.iinfo(dtype).bits
return np.dtype(dtype).itemsize * 8

def pattern_match_scan_to_fori_loop(
jaxpr: jax_core.Jaxpr, num_consts: int, num_carry: int
Expand Down

0 comments on commit 2c11a91

Please sign in to comment.