From 2c1127781d8fba6dfeacd8f8014b96ca631072c2 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 29 Aug 2024 14:41:38 -0700 Subject: [PATCH] Improved extern selection in Pallas GPU Previously, * weakly typed avals matched the wrong externs; * this was addressed by #23193, which disallowed weakly typed avals entirely. Here we check if a weakly typed aval can be casted to the extern input dtype when selecting an extern. PiperOrigin-RevId: 669067725 --- jax/_src/pallas/triton/BUILD | 1 + jax/_src/pallas/triton/lowering.py | 321 +++++++++++++++-------------- 2 files changed, 166 insertions(+), 156 deletions(-) diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index c40fb19ec808..ff8db320a68b 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -54,6 +54,7 @@ pytype_strict_library( "//jax:api_util", "//jax:config", "//jax:core", + "//jax:dtypes", "//jax:mlir", "//jax:partial_eval", "//jax:source_info_util", diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 4057b125fcdc..f2a4229223bd 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -571,7 +571,7 @@ def _not_lowering_rule(ctx: LoweringRuleContext, x): @dataclasses.dataclass(frozen=True) class _Extern: - arg_types: Sequence[str] + arg_types: Sequence[jax.typing.DTypeLike] symbol: str result_type: str @@ -579,7 +579,8 @@ def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool: if len(avals) != len(self.arg_types): return False return all( - aval.dtype.name == arg_type + aval.dtype == jnp.dtype(arg_type) + or (aval.weak_type and aval.dtype.kind == jnp.dtype(arg_type).kind) for aval, arg_type in zip(avals, self.arg_types) ) @@ -600,7 +601,7 @@ def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]): @dataclasses.dataclass(frozen=True) class _Fallback: - arg_types: Sequence[str] + arg_types: Sequence[jax.typing.DTypeLike] lower: Callable[..., ir.Value] matches = _Extern.matches @@ -614,7 +615,7 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: table = tables[ctx.context.platform] h = next((e for e in table if e.matches(ctx.avals_in)), None) if h is None: - arg_aval_dtypes = tuple(aval.dtype.name for aval in ctx.avals_in) + arg_aval_dtypes = tuple(aval.dtype for aval in ctx.avals_in) raise NotImplementedError( f"unsupported types for {name}: {arg_aval_dtypes}" ) @@ -623,7 +624,7 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: bcast_args = [] for aval, arg, arg_type in zip(ctx.avals_in, args, h.arg_types): bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape) - if aval.weak_type and aval.dtype.name != arg_type: + if aval.weak_type and aval.dtype != jnp.dtype(arg_type): bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type)) bcast_args.append(bcast_arg) return h.lower(ctx, *bcast_args) @@ -634,16 +635,16 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: _abs_dispatch_table = _make_dispatch_table( "abs", cuda=[ - _Extern(["int32"], "__nv_abs", "int32"), - _Extern(["int64"], "__nv_llabs", "int64"), - _Extern(["float32"], "__nv_fabsf", "float32"), - _Extern(["float64"], "__nv_fabs", "float64"), + _Extern([jnp.int32], "__nv_abs", jnp.int32), + _Extern([jnp.int64], "__nv_llabs", jnp.int64), + _Extern([jnp.float32], "__nv_fabsf", jnp.float32), + _Extern([jnp.float64], "__nv_fabs", jnp.float64), ], rocm=[ - _Fallback(["int32"], lambda ctx, x: math_dialect.absi(x)), - _Fallback(["int64"], lambda ctx, x: math_dialect.absi(x)), - _Fallback(["float32"], lambda ctx, x: math_dialect.absf(x)), - _Fallback(["float64"], lambda ctx, x: math_dialect.absf(x)), + _Fallback([jnp.int32], lambda ctx, x: math_dialect.absi(x)), + _Fallback([jnp.int64], lambda ctx, x: math_dialect.absi(x)), + _Fallback([jnp.float32], lambda ctx, x: math_dialect.absf(x)), + _Fallback([jnp.float64], lambda ctx, x: math_dialect.absf(x)), ], ) @@ -667,337 +668,345 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): lax.ceil_p: _make_dispatch_table( "ceil", cuda=[ - _Extern(["float32"], "__nv_ceilf", "float32"), - _Extern(["float64"], "__nv_ceil", "float64"), + _Extern([jnp.float32], "__nv_ceilf", jnp.float32), + _Extern([jnp.float64], "__nv_ceil", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_ceil_f32", "float32"), - _Extern(["float64"], "__ocml_ceil_f64", "float64"), + _Extern([jnp.float32], "__ocml_ceil_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_ceil_f64", jnp.float64), ], ), lax.floor_p: _make_dispatch_table( "floor", cuda=[ - _Extern(["float32"], "__nv_floorf", "float32"), - _Extern(["float64"], "__nv_floor", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)), + _Extern([jnp.float32], "__nv_floorf", jnp.float32), + _Extern([jnp.float64], "__nv_floor", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.floor(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.floor(x)), ], rocm=[ - _Extern(["float32"], "__ocml_floor_f32", "float32"), - _Extern(["float64"], "__ocml_floor_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)), + _Extern([jnp.float32], "__ocml_floor_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_floor_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.floor(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.floor(x)), ], ), lax.exp_p: _make_dispatch_table( "exp", cuda=[ - _Extern(["float32"], "__nv_expf", "float32"), - _Extern(["float64"], "__nv_exp", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)), + _Extern([jnp.float32], "__nv_expf", jnp.float32), + _Extern([jnp.float64], "__nv_exp", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp(x)), ], rocm=[ - _Fallback(["float32"], lambda ctx, x: math_dialect.exp(x)), - _Fallback(["float64"], lambda ctx, x: math_dialect.exp(x)), - _Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)), + _Fallback([jnp.float32], lambda ctx, x: math_dialect.exp(x)), + _Fallback([jnp.float64], lambda ctx, x: math_dialect.exp(x)), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp(x)), ], ), lax.exp2_p: _make_dispatch_table( "exp2", cuda=[ - _Extern(["float32"], "__nv_exp2f", "float32"), - _Extern(["float64"], "__nv_exp2", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)), + _Extern([jnp.float32], "__nv_exp2f", jnp.float32), + _Extern([jnp.float64], "__nv_exp2", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp2(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp2(x)), ], rocm=[ - _Extern(["float32"], "__ocml_exp2_f32", "float32"), - _Extern(["float64"], "__ocml_exp2_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)), + _Extern([jnp.float32], "__ocml_exp2_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_exp2_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp2(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp2(x)), ], ), lax.expm1_p: _make_dispatch_table( "expm1", cuda=[ - _Extern(["float32"], "__nv_expm1f", "float32"), - _Extern(["float64"], "__nv_expm1", "float64"), + _Extern([jnp.float32], "__nv_expm1f", jnp.float32), + _Extern([jnp.float64], "__nv_expm1", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_expm1_f32", "float32"), - _Extern(["float64"], "__ocml_expm1_f64", "float64"), + _Extern([jnp.float32], "__ocml_expm1_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_expm1_f64", jnp.float64), ], ), lax.log_p: _make_dispatch_table( "log", cuda=[ - _Extern(["float32"], "__nv_logf", "float32"), - _Extern(["float64"], "__nv_log", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.log(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)), + _Extern([jnp.float32], "__nv_logf", jnp.float32), + _Extern([jnp.float64], "__nv_log", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.log(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.log(x)), ], rocm=[ - _Extern(["float32"], "__ocml_log_f32", "float32"), - _Extern(["float64"], "__ocml_log_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.log(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)), + _Extern([jnp.float32], "__ocml_log_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_log_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.log(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.log(x)), ], ), lax.log1p_p: _make_dispatch_table( "log1p", cuda=[ - _Extern(["float32"], "__nv_log1pf", "float32"), - _Extern(["float64"], "__nv_log1p", "float64"), + _Extern([jnp.float32], "__nv_log1pf", jnp.float32), + _Extern([jnp.float64], "__nv_log1p", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_log1p_f32", "float32"), - _Extern(["float64"], "__ocml_log1p_f64", "float64"), + _Extern([jnp.float32], "__ocml_log1p_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_log1p_f64", jnp.float64), ], ), lax.sqrt_p: _make_dispatch_table( "sqrt", cuda=[ - _Extern(["float32"], "__nv_sqrtf", "float32"), - _Extern(["float64"], "__nv_sqrt", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)), + _Extern([jnp.float32], "__nv_sqrtf", jnp.float32), + _Extern([jnp.float64], "__nv_sqrt", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.sqrt(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)), ], rocm=[ - _Extern(["float32"], "__ocml_sqrt_f32", "float32"), - _Extern(["float64"], "__ocml_sqrt_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)), + _Extern([jnp.float32], "__ocml_sqrt_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_sqrt_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.sqrt(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)), ], ), lax.pow_p: _make_dispatch_table( "pow", cuda=[ - _Extern(["float32", "int32"], "__nv_powif", "float32"), - _Extern(["float64", "int32"], "__nv_powi", "float64"), - _Extern(["float32", "float32"], "__nv_powf", "float32"), - _Extern(["float64", "float64"], "__nv_pow", "float64"), + _Extern([jnp.float32, jnp.int32], "__nv_powif", jnp.float32), + _Extern([jnp.float64, jnp.int32], "__nv_powi", jnp.float64), + _Extern([jnp.float32, jnp.float32], "__nv_powf", jnp.float32), + _Extern([jnp.float64, jnp.float64], "__nv_pow", jnp.float64), ], rocm=[ - _Extern(["float32", "int32"], "__ocml_pown_f32", "float32"), - _Extern(["float64", "int32"], "__ocml_pown_f64", "float64"), - _Extern(["float32", "float32"], "__ocml_pow_f32", "float32"), - _Extern(["float64", "float64"], "__ocml_pow_f64", "float64"), + _Extern([jnp.float32, jnp.int32], "__ocml_pown_f32", jnp.float32), + _Extern([jnp.float64, jnp.int32], "__ocml_pown_f64", jnp.float64), + _Extern([jnp.float32, jnp.float32], "__ocml_pow_f32", jnp.float32), + _Extern([jnp.float64, jnp.float64], "__ocml_pow_f64", jnp.float64), ], ), lax.cbrt_p: _make_dispatch_table( "cbrt", cuda=[ - _Extern(["float32"], "__nv_cbrtf", "float32"), - _Extern(["float64"], "__nv_cbrt", "float64"), + _Extern([jnp.float32], "__nv_cbrtf", jnp.float32), + _Extern([jnp.float64], "__nv_cbrt", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_cbrt_f32", "float32"), - _Extern(["float64"], "__ocml_cbrt_f64", "float64"), + _Extern([jnp.float32], "__ocml_cbrt_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_cbrt_f64", jnp.float64), ], ), lax.rsqrt_p: _make_dispatch_table( "rsqrt", cuda=[ - _Extern(["float32"], "__nv_rsqrtf", "float32"), - _Extern(["float64"], "__nv_rsqrt", "float64"), + _Extern([jnp.float32], "__nv_rsqrtf", jnp.float32), + _Extern([jnp.float64], "__nv_rsqrt", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_rsqrt_f32", "float32"), - _Extern(["float64"], "__ocml_rsqrt_f64", "float64"), + _Extern([jnp.float32], "__ocml_rsqrt_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_rsqrt_f64", jnp.float64), ], ), lax.sin_p: _make_dispatch_table( "sin", cuda=[ - _Extern(["float32"], "__nv_sinf", "float32"), - _Extern(["float64"], "__nv_sin", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)), + _Extern([jnp.float32], "__nv_sinf", jnp.float32), + _Extern([jnp.float64], "__nv_sin", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.sin(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sin(x)), ], rocm=[ - _Extern(["float32"], "__ocml_sin_f32", "float32"), - _Extern(["float64"], "__ocml_sin_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)), + _Extern([jnp.float32], "__ocml_sin_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_sin_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.sin(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sin(x)), ], ), lax.cos_p: _make_dispatch_table( "cos", cuda=[ - _Extern(["float32"], "__nv_cosf", "float32"), - _Extern(["float64"], "__nv_cos", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)), + _Extern([jnp.float32], "__nv_cosf", jnp.float32), + _Extern([jnp.float64], "__nv_cos", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.cos(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.cos(x)), ], rocm=[ - _Extern(["float32"], "__ocml_cos_f32", "float32"), - _Extern(["float64"], "__ocml_cos_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)), + _Extern([jnp.float32], "__ocml_cos_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_cos_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.cos(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.cos(x)), ], ), lax.tan_p: _make_dispatch_table( "tan", cuda=[ - _Extern(["float32"], "__nv_tanf", "float32"), - _Extern(["float64"], "__nv_tan", "float64"), + _Extern([jnp.float32], "__nv_tanf", jnp.float32), + _Extern([jnp.float64], "__nv_tan", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_tan_f32", "float32"), - _Extern(["float64"], "__ocml_tan_f64", "float64"), + _Extern([jnp.float32], "__ocml_tan_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_tan_f64", jnp.float64), ], ), lax.asin_p: _make_dispatch_table( "asin", cuda=[ - _Extern(["float32"], "__nv_asinf", "float32"), - _Extern(["float64"], "__nv_asin", "float64"), + _Extern([jnp.float32], "__nv_asinf", jnp.float32), + _Extern([jnp.float64], "__nv_asin", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_asin_f32", "float32"), - _Extern(["float64"], "__ocml_asin_f64", "float64"), + _Extern([jnp.float32], "__ocml_asin_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_asin_f64", jnp.float64), ], ), lax.acos_p: _make_dispatch_table( "acos", cuda=[ - _Extern(["float32"], "__nv_acosf", "float32"), - _Extern(["float64"], "__nv_acos", "float64"), + _Extern([jnp.float32], "__nv_acosf", jnp.float32), + _Extern([jnp.float64], "__nv_acos", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_acos_f32", "float32"), - _Extern(["float64"], "__ocml_acos_f64", "float64"), + _Extern([jnp.float32], "__ocml_acos_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_acos_f64", jnp.float64), ], ), lax.atan_p: _make_dispatch_table( "atan", cuda=[ - _Extern(["float32"], "__nv_atanf", "float32"), - _Extern(["float64"], "__nv_atan", "float64"), + _Extern([jnp.float32], "__nv_atanf", jnp.float32), + _Extern([jnp.float64], "__nv_atan", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_atan_f32", "float32"), - _Extern(["float64"], "__ocml_atan_f64", "float64"), + _Extern([jnp.float32], "__ocml_atan_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_atan_f64", jnp.float64), ], ), lax.atan2_p: _make_dispatch_table( "atan2", cuda=[ - _Extern(["float32", "float32"], "__nv_atan2f", "float32"), - _Extern(["float64", "float64"], "__nv_atan2", "float64"), + _Extern([jnp.float32, jnp.float32], "__nv_atan2f", jnp.float32), + _Extern([jnp.float64, jnp.float64], "__nv_atan2", jnp.float64), ], rocm=[ - _Extern(["float32", "float32"], "__ocml_atan2_f32", "float32"), - _Extern(["float64", "float64"], "__ocml_atan2_f64", "float64"), + _Extern( + [jnp.float32, jnp.float32], "__ocml_atan2_f32", jnp.float32 + ), + _Extern( + [jnp.float64, jnp.float64], "__ocml_atan2_f64", jnp.float64 + ), ], ), lax.sinh_p: _make_dispatch_table( "sinh", cuda=[ - _Extern(["float32"], "__nv_sinhf", "float32"), - _Extern(["float64"], "__nv_sinh", "float64"), + _Extern([jnp.float32], "__nv_sinhf", jnp.float32), + _Extern([jnp.float64], "__nv_sinh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_sinh_f32", "float32"), - _Extern(["float64"], "__ocml_sinh_f64", "float64"), + _Extern([jnp.float32], "__ocml_sinh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_sinh_f64", jnp.float64), ], ), lax.cosh_p: _make_dispatch_table( "cosh", cuda=[ - _Extern(["float32"], "__nv_coshf", "float32"), - _Extern(["float64"], "__nv_cosh", "float64"), + _Extern([jnp.float32], "__nv_coshf", jnp.float32), + _Extern([jnp.float64], "__nv_cosh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_cosh_f32", "float32"), - _Extern(["float64"], "__ocml_cosh_f64", "float64"), + _Extern([jnp.float32], "__ocml_cosh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_cosh_f64", jnp.float64), ], ), lax.tanh_p: _make_dispatch_table( "tanh", cuda=[ - _Extern(["float32"], "__nv_tanhf", "float32"), - _Extern(["float64"], "__nv_tanh", "float64"), + _Extern([jnp.float32], "__nv_tanhf", jnp.float32), + _Extern([jnp.float64], "__nv_tanh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_tanh_f32", "float32"), - _Extern(["float64"], "__ocml_tanh_f64", "float64"), + _Extern([jnp.float32], "__ocml_tanh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_tanh_f64", jnp.float64), ], ), lax.asinh_p: _make_dispatch_table( "asinh", cuda=[ - _Extern(["float32"], "__nv_asinhf", "float32"), - _Extern(["float64"], "__nv_asinh", "float64"), + _Extern([jnp.float32], "__nv_asinhf", jnp.float32), + _Extern([jnp.float64], "__nv_asinh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_asinh_f32", "float32"), - _Extern(["float64"], "__ocml_asinh_f64", "float64"), + _Extern([jnp.float32], "__ocml_asinh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_asinh_f64", jnp.float64), ], ), lax.acosh_p: _make_dispatch_table( "acosh", cuda=[ - _Extern(["float32"], "__nv_acoshf", "float32"), - _Extern(["float64"], "__nv_acosh", "float64"), + _Extern([jnp.float32], "__nv_acoshf", jnp.float32), + _Extern([jnp.float64], "__nv_acosh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_acosh_f32", "float32"), - _Extern(["float64"], "__ocml_acosh_f64", "float64"), + _Extern([jnp.float32], "__ocml_acosh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_acosh_f64", jnp.float64), ], ), lax.atanh_p: _make_dispatch_table( "atanh", cuda=[ - _Extern(["float32"], "__nv_atanhf", "float32"), - _Extern(["float64"], "__nv_atanh", "float64"), + _Extern([jnp.float32], "__nv_atanhf", jnp.float32), + _Extern([jnp.float64], "__nv_atanh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_atanh_f32", "float32"), - _Extern(["float64"], "__ocml_atanh_f64", "float64"), + _Extern([jnp.float32], "__ocml_atanh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_atanh_f64", jnp.float64), ], ), lax.population_count_p: _make_dispatch_table( "population_count", cuda=[ - _Extern(["int32"], "__nv_popc", "int32"), - _Extern(["int64"], "__nv_popcll", "int32"), + _Extern([jnp.int32], "__nv_popc", jnp.int32), + _Extern([jnp.int64], "__nv_popcll", jnp.int32), ], rocm=[ - _Fallback(["int32"], lambda ctx, x: math_dialect.ctpop(x)), - _Fallback(["int64"], lambda ctx, x: math_dialect.ctpop(x)), + _Fallback([jnp.int32], lambda ctx, x: math_dialect.ctpop(x)), + _Fallback([jnp.int64], lambda ctx, x: math_dialect.ctpop(x)), ], ), lax.clz_p: _make_dispatch_table( "clz", cuda=[ - _Extern(["int32"], "__nv_clz", "int32"), - _Extern(["int64"], "__nv_clzll", "int32"), + _Extern([jnp.int32], "__nv_clz", jnp.int32), + _Extern([jnp.int64], "__nv_clzll", jnp.int32), ], rocm=[ - _Fallback(["int32"], lambda ctx, x: math_dialect.ctlz(x)), - _Fallback(["int64"], lambda ctx, x: math_dialect.ctlz(x)), + _Fallback([jnp.int32], lambda ctx, x: math_dialect.ctlz(x)), + _Fallback([jnp.int64], lambda ctx, x: math_dialect.ctlz(x)), ], ), lax.nextafter_p: _make_dispatch_table( "nextafter", cuda=[ - _Extern(["float32", "float32"], "__nv_nextafterf", "float32"), - _Extern(["float64", "float64"], "__nv_nextafter", "float64"), + _Extern([jnp.float32, jnp.float32], "__nv_nextafterf", jnp.float32 ), + _Extern([jnp.float64, jnp.float64], "__nv_nextafter", jnp.float64), ], rocm=[ - _Extern(["float32", "float32"], "__ocml_nextafter_f32", "float32"), - _Extern(["float64", "float64"], "__ocml_nextafter_f64", "float64"), + _Extern( + [jnp.float32, jnp.float32], "__ocml_nextafter_f32", jnp.float32 + ), + _Extern( + [jnp.float64, jnp.float64], "__ocml_nextafter_f64", jnp.float64 + ), ], ), lax.erf_inv_p: _make_dispatch_table( "erf_inv", cuda=[ _Fallback( - ["float32"], + [jnp.float32], lower_fun( pallas_utils.erf_inv_32_lowering_helper, multiple_results=False, @@ -1006,7 +1015,7 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): ], rocm=[ _Fallback( - ["float32"], + [jnp.float32], lower_fun( pallas_utils.erf_inv_32_lowering_helper, multiple_results=False, @@ -2215,7 +2224,7 @@ def _argreduce_lowering( if i != axis: index = _expand_dims(index, i) index = _bcast_to(index, a_aval.shape) - ctx = ctx.replace(avals_in=[a_aval, a_aval.update(dtype=jnp.dtype("int32"))]) + ctx = ctx.replace(avals_in=[a_aval, a_aval.update(dtype=jnp.dtype(jnp.int32))]) _, indices = _reduction_lowering(body, ctx, (a, index), axes=axes) return indices