Reduced duplication between _bcast and _ensure_fa in Pallas Mosaic GPU lowering

PiperOrigin-RevId: 676180945
superbobry authored and Google-ML-Automation committed Sep 18, 2024
1 parent 0c7c71e commit ba06bd5
Showing 1 changed file with 15 additions and 17 deletions.

jax/_src/pallas/mosaic_gpu/lowering.py
@@ -635,15 +635,17 @@ def _broadcast_in_dim_lowering_rule(
 ):
   if broadcast_dimensions:
     raise NotImplementedError
-  return _ensure_fa(x, ctx.avals_in[0]).broadcast(shape)
+  [x_aval] = ctx.avals_in
+  return _ensure_fa(x, x_aval.dtype).broadcast(shape)
 
 
 @register_lowering_rule(lax.convert_element_type_p)
 def _convert_element_type_lowering_rule(
     ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
 ):
   del weak_type, sharding
-  return _ensure_fa(x, *ctx.avals_in).astype(mlir.dtype_to_ir_type(new_dtype))
+  [x_aval] = ctx.avals_in
+  return _ensure_fa(x, x_aval.dtype).astype(mlir.dtype_to_ir_type(new_dtype))
 
 
 def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl):
@@ -661,15 +663,17 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl):

 @register_lowering_rule(lax.integer_pow_p)
 def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
-  x = _ensure_fa(x, *ctx.avals_in)
+  [x_aval] = ctx.avals_in
+  x = _ensure_fa(x, x_aval.dtype)
   if y == 2:
     return x * x
   return NotImplementedError
 
 
 @register_lowering_rule(lax.rsqrt_p)
 def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
-  return _ensure_fa(x, *ctx.avals_in).rsqrt(ctx.module_ctx.approx_math)
+  [x_aval] = ctx.avals_in
+  return _ensure_fa(x, x_aval.dtype).rsqrt(ctx.module_ctx.approx_math)
 
 
 @register_lowering_rule(lax.reduce_sum_p)
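A note on the call-site idiom in the hunks above, with a minimal sketch (avals_in below is a hypothetical stand-in for ctx.avals_in): unpacking with [x_aval] = ctx.avals_in asserts that the rule has exactly one input aval, whereas the old _ensure_fa(x, *ctx.avals_in) would silently forward a second aval as an extra positional argument.

# Minimal sketch of the unpacking idiom; avals_in stands in for ctx.avals_in.
avals_in = ["f32[128]"]

# Destructuring requires exactly one element: a two-element list raises
# "ValueError: too many values to unpack" instead of quietly passing a
# second positional argument the way f(x, *avals_in) can.
[x_aval] = avals_in
print(x_aval)  # f32[128]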
@@ -721,40 +725,34 @@ def _bcast(
     y_aval: jax_core.ShapedArray,
     out_aval: jax_core.ShapedArray,
 ) -> ir.Value:
-  if isinstance(x, (np.ndarray, np.number, int, float)):
+  if not isinstance(x, mgpu.FragmentedArray):
     x_dtype = x_aval.dtype
     if x_aval.weak_type:
       x_dtype = y_aval.dtype
-    x = mgpu.FragmentedArray.splat(
-        _ir_constant(x, mlir.dtype_to_ir_type(x_dtype)), ()
-    )
-  if isinstance(y, (np.ndarray, np.number, int, float)):
+    x = _ensure_fa(x, x_dtype)
+  if not isinstance(y, mgpu.FragmentedArray):
     y_dtype = y_aval.dtype
     if y_aval.weak_type:
       y_dtype = x_aval.dtype
-    y = mgpu.FragmentedArray.splat(
-        _ir_constant(y, mlir.dtype_to_ir_type(y_dtype)), ()
-    )
-  assert isinstance(x, mgpu.FragmentedArray)
-  assert isinstance(y, mgpu.FragmentedArray)
+    y = _ensure_fa(y, y_dtype)
   if x_aval.shape != out_aval.shape:
     x = x.broadcast(out_aval.shape)
   if y_aval.shape != out_aval.shape:
     y = y.broadcast(out_aval.shape)
   return x, y
 
 
-def _ensure_fa(x: object, aval: jax_core.ShapedArray) -> mgpu.FragmentedArray:
+def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray:
   if isinstance(x, mgpu.FragmentedArray):
     return x
   elif isinstance(x, (np.number, np.ndarray, int, float)):
     return mgpu.FragmentedArray.splat(
-        _ir_constant(x, mlir.dtype_to_ir_type(aval.dtype)), ()
+        _ir_constant(x, mlir.dtype_to_ir_type(dtype)), ()
     )
   elif isinstance(x, ir.Value):
     if isinstance(x.type, (ir.IntegerType, ir.FloatType)):
       return mgpu.FragmentedArray.splat(x, ())
-    raise NotImplementedError
+  raise NotImplementedError(f"Unsupported type: {type(x)}")
 
 
 def _ir_constant(v: object, t: ir.Type) -> ir.Value:
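The shape of the refactor, as a self-contained sketch (Splat, ensure_splat, and bcast are hypothetical stand-ins for mgpu.FragmentedArray, _ensure_fa, and _bcast; splatting into a NumPy array stands in for _ir_constant plus mgpu.FragmentedArray.splat): dtype selection, including the weak-type adjustment, stays with the caller, while every conversion into the fragmented-array representation funnels through the single helper.

import numpy as np


class Splat:
  """Toy stand-in for mgpu.FragmentedArray: one value plus a logical shape."""

  def __init__(self, value, shape=()):
    self.value = value
    self.shape = shape

  def broadcast(self, shape):
    return Splat(self.value, shape)


def ensure_splat(x, dtype):
  """Mirrors the new _ensure_fa: takes a plain dtype, not a ShapedArray aval."""
  if isinstance(x, Splat):
    return x
  if isinstance(x, (np.number, np.ndarray, int, float)):
    return Splat(np.asarray(x, dtype))
  raise NotImplementedError(f"Unsupported type: {type(x)}")


def bcast(x, x_dtype, y, y_dtype, out_shape):
  """Mirrors the slimmed-down _bcast: splatting is delegated to the helper,
  so the inline splat calls and the trailing isinstance asserts are gone."""
  x = ensure_splat(x, x_dtype)
  y = ensure_splat(y, y_dtype)
  if x.shape != out_shape:
    x = x.broadcast(out_shape)
  if y.shape != out_shape:
    y = y.broadcast(out_shape)
  return x, y


# A Python scalar is splatted to the chosen dtype and then broadcast to the
# output shape, as in the lowering rules above.
x, y = bcast(2.0, np.float32, Splat(np.float32(1.0), (4,)), np.float32, (4,))
print(x.shape, y.shape)  # (4,) (4,)

Because the helper now takes the dtype rather than the aval, _bcast can pass its weak-type-adjusted x_dtype and y_dtype straight through, which is what removes the duplicated splat logic.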
