diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 1d76dc8405d5..3afc62aebfcf 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -635,7 +635,8 @@ def _broadcast_in_dim_lowering_rule(
 ):
   if broadcast_dimensions:
     raise NotImplementedError
-  return _ensure_fa(x, ctx.avals_in[0]).broadcast(shape)
+  [x_aval] = ctx.avals_in
+  return _ensure_fa(x, x_aval.dtype).broadcast(shape)
 
 
 @register_lowering_rule(lax.convert_element_type_p)
@@ -643,7 +644,8 @@ def _convert_element_type_lowering_rule(
     ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
 ):
   del weak_type, sharding
-  return _ensure_fa(x, *ctx.avals_in).astype(mlir.dtype_to_ir_type(new_dtype))
+  [x_aval] = ctx.avals_in
+  return _ensure_fa(x, x_aval.dtype).astype(mlir.dtype_to_ir_type(new_dtype))
 
 
 def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl):
@@ -661,7 +663,8 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl):
 
 @register_lowering_rule(lax.integer_pow_p)
 def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
-  x = _ensure_fa(x, *ctx.avals_in)
+  [x_aval] = ctx.avals_in
+  x = _ensure_fa(x, x_aval.dtype)
   if y == 2:
     return x * x
   return NotImplementedError
@@ -669,7 +672,8 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
 
 @register_lowering_rule(lax.rsqrt_p)
 def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
-  return _ensure_fa(x, *ctx.avals_in).rsqrt(ctx.module_ctx.approx_math)
+  [x_aval] = ctx.avals_in
+  return _ensure_fa(x, x_aval.dtype).rsqrt(ctx.module_ctx.approx_math)
 
 
 @register_lowering_rule(lax.reduce_sum_p)
@@ -721,22 +725,16 @@ def _bcast(
     y_aval: jax_core.ShapedArray,
     out_aval: jax_core.ShapedArray,
 ) -> ir.Value:
-  if isinstance(x, (np.ndarray, np.number, int, float)):
+  if not isinstance(x, mgpu.FragmentedArray):
     x_dtype = x_aval.dtype
     if x_aval.weak_type:
       x_dtype = y_aval.dtype
-    x = mgpu.FragmentedArray.splat(
-        _ir_constant(x, mlir.dtype_to_ir_type(x_dtype)), ()
-    )
-  if isinstance(y, (np.ndarray, np.number, int, float)):
+    x = _ensure_fa(x, x_dtype)
+  if not isinstance(y, mgpu.FragmentedArray):
     y_dtype = y_aval.dtype
     if y_aval.weak_type:
       y_dtype = x_aval.dtype
-    y = mgpu.FragmentedArray.splat(
-        _ir_constant(y, mlir.dtype_to_ir_type(y_dtype)), ()
-    )
-  assert isinstance(x, mgpu.FragmentedArray)
-  assert isinstance(y, mgpu.FragmentedArray)
+    y = _ensure_fa(y, y_dtype)
   if x_aval.shape != out_aval.shape:
     x = x.broadcast(out_aval.shape)
   if y_aval.shape != out_aval.shape:
@@ -744,17 +742,17 @@ def _bcast(
   return x, y
 
 
-def _ensure_fa(x: object, aval: jax_core.ShapedArray) -> mgpu.FragmentedArray:
+def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray:
   if isinstance(x, mgpu.FragmentedArray):
     return x
   elif isinstance(x, (np.number, np.ndarray, int, float)):
     return mgpu.FragmentedArray.splat(
-        _ir_constant(x, mlir.dtype_to_ir_type(aval.dtype)), ()
+        _ir_constant(x, mlir.dtype_to_ir_type(dtype)), ()
     )
   elif isinstance(x, ir.Value):
     if isinstance(x.type, (ir.IntegerType, ir.FloatType)):
       return mgpu.FragmentedArray.splat(x, ())
-  raise NotImplementedError
+  raise NotImplementedError(f"Unsupported type: {type(x)}")
 
 
 def _ir_constant(v: object, t: ir.Type) -> ir.Value: