diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py
index 2df05d6df4..ee4c38d076 100644
--- a/transformer_engine/jax/cpp_extensions/gemm.py
+++ b/transformer_engine/jax/cpp_extensions/gemm.py
@@ -1,7 +1,6 @@
 # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
-"""JAX/TE custom ops for cuBlasLt GEMM"""
 import warnings
 import operator
 from functools import reduce
@@ -39,6 +38,10 @@
 ]
 
 
+def sanitize_dims(dim, ndims):
+    return (ndims + dim) if dim < 0 else dim
+
+
 def get_cublas_workspace_size_bytes() -> None:
     """Return 32 MiB if using hopper, 4 MiB for all other architectures."""
     if tex.get_device_compute_capability() >= 90:
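Note on `sanitize_dims`: it centralizes the negative-index wrapping that the primitive's methods below previously duplicated as an inline lambda, and that lambda computed `ndims - dim` for negative `dim`, which maps -1 to `ndims + 1` instead of `ndims - 1`. A minimal standalone sketch of the convention and of how the transpose flags are derived from it (plain Python, shapes illustrative):

    def sanitize_dims(dim, ndims):
        # Wrap negative indices: -1 -> ndims - 1; non-negative indices pass through.
        return (ndims + dim) if dim < 0 else dim

    # LHS ([B], M, K) x RHS (K, N) with contracting_dims=(-1, 0):
    lhs_ndim, rhs_ndim = 3, 2
    lhs_inner_dim = sanitize_dims(-1, lhs_ndim)  # 2
    rhs_inner_dim = sanitize_dims(0, rhs_ndim)   # 0
    lhs_trans = lhs_inner_dim != lhs_ndim - 1    # False: K is already the last LHS dim
    rhs_trans = rhs_inner_dim == rhs_ndim - 1    # False: K is the first RHS dim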
@@ -98,11 +101,8 @@ def abstract(
         ), "Missing RHS operand scale inverse in FP8 GEMM."
 
         # Validate operand layouts
-        lhs_inner_dim, rhs_inner_dim = map(
-            lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,
-            contracting_dims,
-            (lhs_aval.ndim, rhs_aval.ndim),
-        )
+        lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims,
+                                           (lhs_aval.ndim, rhs_aval.ndim))
         assert (
             lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim]
         ), f"Incompatible operand sizes: {lhs_aval.shape} x {rhs_aval.shape}."
@@ -134,23 +134,31 @@ def abstract(
             out_amax_updated_dtype = jnp.float32
             out_scale_updated_dtype = jnp.float32
 
-        # Infer output shape
+        # Make sure leading dimensions of RHS are broadcast-compatible with LHS
         lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2
+        rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1
+
         lhs_bdims = [
             dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]
         ]
         lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims]
         lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1)
-        rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1
-        rhs_bdims = [
-            dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim]
-        ]
-        rhs_batch_size = reduce(operator.mul, rhs_bdims, 1)
-        assert (
-            lhs_batch_size == rhs_batch_size
-        ), "LHS and RHS operands must have the same batched sizes."
-        out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim])
+        if rhs_aval.ndim > 2:
+            rhs_bdims = [
+                dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim]
+            ]
+            rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims]
+            rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1)
+            if rhs_batch_size > 1:
+                assert (
+                    lhs_batch_size == rhs_batch_size
+                ), (
+                    f"Leading dimensions of RHS ({rhs_batch_shape=}) are not broadcast-compatible "
+                    + f"with the leading dimensions of LHS ({lhs_batch_shape=})."
+                )
+        # Infer output shape
+        out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim])
 
         # Validate bias/bias_grad shape against inferred output
         bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype
         if fuse_bias:
@@ -227,11 +235,8 @@ def lowering(
         Fused attention fwd lowering rules
         """
         lhs_aval, _, rhs_aval, _, bias_aval, *_ = ctx.avals_in
-        lhs_inner_dim, rhs_inner_dim = map(
-            lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,
-            contracting_dims,
-            (lhs_aval.ndim, rhs_aval.ndim),
-        )
+        lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims,
+                                           (lhs_aval.ndim, rhs_aval.ndim))
 
         lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1
         rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1
@@ -376,19 +381,8 @@ def batcher(
         check_valid_batch_dims(batch_dims)
         lhs_bdims, *_, bias_bdims, gelu_input_bdims, out_amax_bdims, out_scale_bdims = batch_dims
 
-        # FP8 GEMM only supports non-transposed LHS and transposed RHS
-        lhs, _, rhs, *_ = batched_args
-        lhs_trans = contracting_dims[0] != lhs.ndim - 1
-        rhs_trans = contracting_dims[1] == rhs.ndim - 1
-        lhs = jnp.matrix_transpose(lhs) if lhs_trans and jax_dtype_is_fp8(lhs.dtype) else lhs
-        rhs = jnp.matrix_transpose(rhs) if not rhs_trans and jax_dtype_is_fp8(rhs.dtype) else rhs
-        contracting_dims = (1, 1)
-
         return CollectiveGemmPrimitive.outer_primitive.bind(
-            lhs,
-            batched_args[1],
-            rhs,
-            *batched_args[3:],
+            *batched_args,
             out_dtype=out_dtype,
             contracting_dims=contracting_dims,
             fuse_gelu=fuse_gelu,
@@ -415,11 +409,7 @@ def infer_sharding_from_operands(
         lhs, _, rhs, *_ = arg_infos
         lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs])
 
-        lhs_inner_dim, rhs_inner_dim = map(
-            lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,
-            contracting_dims,
-            (lhs.ndim, rhs.ndim),
-        )
+        lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim))
         if lhs_spec[lhs_inner_dim] != rhs_spec[rhs_inner_dim] and not grad:
             warnings.warn(
                 "Forcing the inner dimension of LHS to match the sharding of inner "
@@ -471,11 +461,7 @@ def partition(
         lhs, _, rhs, *_ = arg_infos
         lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs])
 
-        lhs_inner_dim, rhs_inner_dim = map(
-            lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,
-            contracting_dims,
-            (lhs.ndim, rhs.ndim),
-        )
+        lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim))
 
         lhs_trans = lhs_inner_dim != lhs.ndim - 1
         rhs_trans = rhs_inner_dim == rhs.ndim - 1
@@ -578,14 +564,13 @@ def sharded_impl(
 def fp8_gemm_impl(
     lhs: ArrayLike,
     lhs_scale_inv: ArrayLike,
-    rhs: ArrayLike,
+    rhs_t: ArrayLike,
     rhs_scale_inv: ArrayLike,
     bias: Optional[ArrayLike] = None,
     gelu_input: Optional[ArrayLike] = None,
     out_amax: Optional[ArrayLike] = None,
     out_scale: Optional[ArrayLike] = None,
     out_dtype: jnp.dtype = jnp.bfloat16,
-    contracting_dims: Tuple[int, int] = (1, 1),
     fuse_gelu: bool = False,
     fuse_bias: bool = False,
     accumulate: bool = False,
@@ -606,22 +591,20 @@
     if not fuse_gelu:
         gelu_input = jnp.zeros(0, dtype=bias.dtype)
     elif gelu_input is None:
-        lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2
-        rhs_outer_dim = rhs.ndim - 2 if contracting_dims[1] == 0 else rhs.ndim - 1
-        out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim])
+        out_shape = (*lhs.shape[:-2], lhs.shape[-2], rhs_t.shape[-2])
         gelu_input = jnp.zeros(out_shape, dtype=bias.dtype)
 
     out, out_amax, out_scale, pre_gelu_out, _ = CollectiveGemmPrimitive.outer_primitive.bind(
-        rhs,
-        rhs_scale_inv,
         lhs,
         lhs_scale_inv,
+        rhs_t,
+        rhs_scale_inv,
         bias,
         gelu_input,
         out_amax,
         out_scale,
         out_dtype=out_dtype,
-        contracting_dims=tuple(reversed(contracting_dims)),
+        contracting_dims=(-1, -1),
         fuse_gelu=fuse_gelu,
         fuse_bias=fuse_bias,
         grad=False,
@@ -645,10 +628,9 @@ def gemm_impl(
     use_split_accumulator: bool = False,
 ) -> Tuple[ArrayLike, ...]:
     """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op."""
-    dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32)
-
-    lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2
-    rhs_outer_dim = rhs.ndim - 2 if contracting_dims[1] == 0 else rhs.ndim - 1
+    lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim))
+    lhs_outer_dim = lhs.ndim - 1 if lhs_inner_dim == lhs.ndim - 2 else lhs.ndim - 2
+    rhs_outer_dim = rhs.ndim - 2 if rhs_inner_dim == rhs.ndim - 1 else rhs.ndim - 1
     out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim])
 
     if not fuse_bias:
@@ -667,6 +649,7 @@ def gemm_impl(
     elif gelu_input is None:
         gelu_input = jnp.zeros(out_shape, dtype=lhs.dtypes)
 
+    dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32)
     out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind(
         lhs,
         dummy_fp8_meta,
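With `contracting_dims` removed from `fp8_gemm_impl`, the FP8 path now has a single fixed operand layout: a non-transposed LHS and a pre-transposed RHS (`rhs_t`), contracted over the last dimension of each. A reference computation for that fixed (-1, -1) contraction, using plain `jnp.einsum` as a stand-in for the underlying cuBLASLt call (shapes illustrative; scale factors omitted):

    import jax.numpy as jnp

    lhs = jnp.zeros((8, 128, 256), dtype=jnp.float8_e4m3fn)  # ([B], M, K)
    rhs_t = jnp.zeros((512, 256), dtype=jnp.float8_e4m3fn)   # (N, K), already transposed

    # fp8_gemm_impl(lhs, lhs_scale_inv, rhs_t, rhs_scale_inv, ...) contracts K with K:
    out = jnp.einsum("bmk,nk->bmn", lhs.astype(jnp.float32), rhs_t.astype(jnp.float32))
    assert out.shape == (8, 128, 512)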
diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py
index 79499725b7..e9e046d182 100644
--- a/transformer_engine/jax/gemm.py
+++ b/transformer_engine/jax/gemm.py
@@ -1,7 +1,8 @@
 # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
-from functools import partial
+import operator
+from functools import partial, reduce
 from typing import Optional, Tuple, Union
 
 import jax
@@ -19,6 +20,7 @@
     dbias_cast_transpose,
     dact_lu_dbias_cast_transpose,
 )
+from .cpp_extensions.gemm import sanitize_dims
 
 
 __all__ = [
@@ -98,27 +100,48 @@ def _gemm_bwd_rule(
     grad,
 ):
     x, kernel, pre_gelu_out, fuse_bias = ctx
+    x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim))
 
-    x_t_contracting = 0 if contracting_dims[0] == 1 else 1
-    wgrad, dgelu, bgrad = gemm_impl(
-        x,
+
+    kernel_t_contracting = (
+        kernel.ndim - 2 if kernel_inner_dim == kernel.ndim - 1 else kernel.ndim - 1
+    )
+    # DGRAD: ([B], M, N) x (K, N)^T = ([B], M, K)
+    dgrad, dgelu, _ = gemm_impl(
         grad,
+        kernel,
         gelu_input=pre_gelu_out,
-        contracting_dims=(x_t_contracting, 0),
+        contracting_dims=(-1, kernel_t_contracting),
         fuse_gelu=fuse_gelu,
-        fuse_bias=fuse_bias,
+        fuse_bias=False,
         grad=True,
        accumulate=accumulate,
         use_split_accumulator=use_split_accumulator,
     )
 
-    kernel_t_contracting = 1 if contracting_dims[1] == 0 else 0
-    dgrad, *_ = gemm_impl(
-        dgelu if fuse_gelu else grad,
-        kernel,
+    # Collapse batch x sequence dimensions for WGRAD
+    x_outer_dim = x.ndim - 2 if x_inner_dim == x.ndim - 1 else x.ndim - 1
+    wgrad_rhs = dgelu if fuse_gelu else grad
+    if x.ndim > 2:
+        batch_size = reduce(operator.mul, x.shape[:-2], 1)
+        x = jax.lax.reshape(
+            jax.lax.transpose(x, (*list(range(x.ndim - 2)), x_outer_dim, x_inner_dim)),
+            (batch_size * x.shape[x_outer_dim], x.shape[x_inner_dim]),
+        )
+        wgrad_rhs = jnp.reshape(
+            wgrad_rhs, shape=(batch_size * wgrad_rhs.shape[-2], wgrad_rhs.shape[-1])
+        )
+        x_t_contracting = 0
+    else:
+        x_t_contracting = x_outer_dim
+
+    # WGRAD: ([B], M, K)^T x ([B], M, N) = ([B], K, N)
+    wgrad, _, bgrad = gemm_impl(
+        x,
+        wgrad_rhs,
         gelu_input=pre_gelu_out,
-        contracting_dims=(1, kernel_t_contracting),
-        fuse_gelu=fuse_gelu,
+        contracting_dims=(x_t_contracting, wgrad_rhs.ndim - 2),
+        fuse_gelu=False,
         fuse_bias=fuse_bias,
         grad=True,
         accumulate=accumulate,
@@ -140,7 +163,6 @@ def fp8_gemm(
     fp8_meta: FP8MetaPackage,
     bias: Optional[ArrayLike] = None,
     out_dtype: jnp.dtype = jnp.bfloat16,
-    contracting_dims: Tuple[int, int] = (1, 1),
     fuse_gelu: bool = False,
     accumulate: bool = False,
     use_split_accumulator: bool = False,
@@ -152,7 +174,6 @@ def fp8_gemm(
         fp8_meta.amax_list,
         fp8_meta.scale_list,
         out_dtype,
-        contracting_dims,
         fuse_gelu,
         accumulate,
         use_split_accumulator,
@@ -162,12 +183,11 @@
 @partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
 def _fp8_gemm(
     x: ArrayLike,
-    kernel: ArrayLike,
+    kernel_t: ArrayLike,
     bias: ArrayLike,
     amax_list: ArrayLike,
     scale_list: ArrayLike,
     out_dtype: jnp.dtype,
-    contracting_dims: Tuple[int, int],
     fuse_gelu: bool,
     accumulate: bool,
     use_split_accumulator: bool,
@@ -175,12 +195,11 @@ def _fp8_gemm(
     """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions."""
     out, _ = _fp8_gemm_fwd_rule(
         x,
-        kernel,
+        kernel_t,
         bias,
         amax_list,
         scale_list,
         out_dtype,
-        contracting_dims,
         fuse_gelu,
         accumulate,
         use_split_accumulator,
@@ -190,12 +209,11 @@
 
 def _fp8_gemm_fwd_rule(
     x: ArrayLike,
-    kernel: ArrayLike,
+    kernel_t: ArrayLike,
     bias: ArrayLike,
     amax_list: ArrayLike,
     scale_list: ArrayLike,
     out_dtype: jnp.dtype,
-    contracting_dims: Tuple[int, int],
     fuse_gelu: bool,
     accumulate: bool,
     use_split_accumulator: bool,
@@ -221,54 +239,36 @@
     x_scale = scale_list[FP8MetaPackage.INPUT_IDX]
     x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
     if x.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]:
-        if contracting_dims[0] == 0:
-            _, casted_x, updated_x_amax = cast_transpose(
-                x,
-                x_amax,
-                x_scale,
-                x_scale_inv,
-                fwd_dtype,
-                static_axis_boundary=-1,
-                transpose_axis_boundary=-1,
-            )
-        else:
-            casted_x, updated_x_amax = cast_fp8(x, x_amax, x_scale, x_scale_inv, fwd_dtype)
+        casted_x, casted_x_t, updated_x_amax = cast_transpose(
+            x,
+            x_amax,
+            x_scale,
+            x_scale_inv,
+            fwd_dtype,
+            static_axis_boundary=-1,
+            transpose_axis_boundary=-1,
+        )
     else:
-        if contracting_dims[0] == 0:
-            casted_x_t = x
-            casted_x = casted_x_t.transpose()
-        else:
-            casted_x = x
+        casted_x = x
+        casted_x_t = jnp.matrix_transpose(x)
         updated_x_amax = x_amax
 
     kernel_amax = amax_list[FP8MetaPackage.WEIGHT_IDX][0:1]
     kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX]
     kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
-    if kernel.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]:
-        if contracting_dims[1] == 0:  # need to transpose the kernel for FP8 GEMM
-            _, casted_kernel_t, updated_kernel_amax = cast_transpose(
-                kernel,
-                kernel_amax,
-                kernel_scale,
-                kernel_scale_inv,
-                fwd_dtype,
-                static_axis_boundary=-1,
-                transpose_axis_boundary=-1,
-            )
-        else:
-            casted_kernel_t, updated_kernel_amax = cast_fp8(
-                kernel,
-                kernel_amax,
-                kernel_scale,
-                kernel_scale_inv,
-                fwd_dtype,
-            )
+    if kernel_t.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]:
+        casted_kernel_t, casted_kernel, updated_kernel_amax = cast_transpose(
+            kernel_t,
+            kernel_amax,
+            kernel_scale,
+            kernel_scale_inv,
+            fwd_dtype,
+            static_axis_boundary=-1,
+            transpose_axis_boundary=-1,
+        )
     else:
-        if contracting_dims[1] == 0:
-            casted_kernel = kernel
-            casted_kernel_t = casted_kernel.transpose()
-        else:
-            casted_kernel_t = kernel
+        casted_kernel = jnp.matrix_transpose(kernel_t)
+        casted_kernel_t = kernel_t
         updated_kernel_amax = kernel_amax
 
     out_amax = (
@@ -300,24 +300,24 @@
     updated_out_scale = None
 
     ctx = (
-        casted_x,
-        casted_kernel_t,
+        casted_x_t,
+        casted_kernel,
         amax_list,
         scale_list,
         scale_inv_list,
         updated_x_amax,
         updated_kernel_amax,
+        updated_out_amax,
         pre_gelu_out if fuse_gelu else None,
         fuse_bias,
         maybe_fp32_to_fm32,
     )
 
-    return (out, updated_out_amax, updated_out_scale), ctx
+    return (out, updated_out_scale), ctx
 
 
 def _fp8_gemm_bwd_rule(
     out_dtype,
-    contracting_dims,
     fuse_gelu,
     accumulate,
     use_split_accumulator,
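The forward rule now stashes `casted_x_t` and the non-transposed `casted_kernel` as residuals so that both backward GEMMs can run through `fp8_gemm_impl` in the fixed (-1, -1) layout. A toy `jax.custom_vjp` with the same residual pattern — `toy_gemm` is a hypothetical stand-in that omits the scale/amax bookkeeping and the bias/GELU fusions the real rules thread through:

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def toy_gemm(x, kernel_t):
        # x: (M, K), kernel_t: (N, K); contract over K, i.e. contracting_dims=(-1, -1).
        return jax.lax.dot_general(x, kernel_t, (((1,), (1,)), ((), ())))

    def toy_gemm_fwd(x, kernel_t):
        # Residuals mirror the new ctx: x^T for WGRAD, non-transposed kernel for DGRAD.
        return toy_gemm(x, kernel_t), (x.T, kernel_t.T)

    def toy_gemm_bwd(res, grad):
        x_t, kernel = res                # x_t: (K, M), kernel: (K, N)
        dgrad = jnp.dot(grad, kernel.T)  # (M, N) x (N, K) = (M, K)
        wgrad = jnp.dot(x_t, grad)       # (K, M) x (M, N) = (K, N)
        return dgrad, wgrad.T            # cotangent of kernel_t must be (N, K)

    toy_gemm.defvjp(toy_gemm_fwd, toy_gemm_bwd)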
@@ -325,83 +325,84 @@
     ctx,
     grad,
 ):
     (
-        casted_x,
-        casted_kernel_t,
+        casted_x_t,
+        casted_kernel,
         amax_list,
         scale_list,
         scale_inv_list,
         updated_x_amax,
         updated_kernel_amax,
+        updated_out_amax,
         pre_gelu_out,
         fuse_bias,
         maybe_fp32_to_fm32,
     ) = ctx
 
-    fwd_dtype = FP8Helper.FWD_DTYPE
     bwd_dtype = FP8Helper.BWD_DTYPE
 
     grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1]
     grad_scale = scale_list[FP8MetaPackage.GRAD_IDX]
     grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_IDX]
-    if fuse_bias and not fuse_gelu:
-        # Since there is no GELU fusion, we need to fuse dbias into this cast_transpose.
-        _, casted_grad_t, bgrad, updated_grad_amax = dbias_cast_transpose(
-            grad,
-            grad_amax,
-            grad_scale,
-            grad_scale_inv,
-            bwd_dtype,
-            static_axis_boundary=-1,
-            transpose_axis_boundary=-1,
-        )
+    if fuse_gelu:
+        if fuse_bias:
+            # Fuse dbias into this dGELU.
+            casted_grad, casted_grad_t, bgrad, updated_grad_amax = dact_lu_dbias_cast_transpose(
+                grad,
+                pre_gelu_out,
+                grad_amax,
+                grad_scale,
+                grad_scale_inv,
+                bwd_dtype,
+                static_axis_boundary=-1,
+                transpose_axis_boundary=-1,
+                activation_type=("gelu",),
+            )
+        else:
+            # No bias to fuse so we just do dGELU.
+            casted_grad, casted_grad_t, updated_grad_amax = dact_lu(grad, pre_gelu_out, ("gelu",))
+            bgrad = None
     else:
-        # If both bias and GELU is fused into the forward pass, we will fuse dbias later with
-        # dGELU. No need to do it here.
-        _, casted_grad_t, updated_grad_amax = cast_transpose(
-            grad,
-            grad_amax,
-            grad_scale,
-            grad_scale_inv,
-            bwd_dtype,
-            static_axis_boundary=-1,
-            transpose_axis_boundary=-1,
-        )
-        bgrad = None
+        if fuse_bias:
+            # Since there is no GELU fusion, we need to fuse dbias into this cast_transpose.
+            casted_grad, casted_grad_t, bgrad, updated_grad_amax = dbias_cast_transpose(
+                grad,
+                grad_amax,
+                grad_scale,
+                grad_scale_inv,
+                bwd_dtype,
+                static_axis_boundary=-1,
+                transpose_axis_boundary=-1,
+            )
+        else:
+            # Neither GELU nor bias is fused in the forward pass, so there is nothing to
+            # fuse into this cast_transpose of the incoming gradient.
+            casted_grad, casted_grad_t, updated_grad_amax = cast_transpose(
+                grad,
+                grad_amax,
+                grad_scale,
+                grad_scale_inv,
+                bwd_dtype,
+                static_axis_boundary=-1,
+                transpose_axis_boundary=-1,
+            )
+            bgrad = None
 
-    x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
-    wgrad, *_ = fp8_gemm_impl(
-        casted_x,
-        x_scale_inv,
-        casted_grad_t,
+    kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
+    dgrad, *_ = fp8_gemm_impl(
+        casted_grad,
         grad_scale_inv,
+        casted_kernel,
+        kernel_scale_inv,
         accumulate=accumulate,
         use_split_accumulator=use_split_accumulator,
     )
 
-    if fuse_gelu and fuse_bias:
-        # Fuse dbias into this dGELU.
-        casted_dgelu, casted_dgelu_t, bgrad, updated_dgelu_amax = dact_lu_dbias_cast_transpose(
-            grad,
-            pre_gelu_out,
-            grad_amax,
-            grad_scale,
-            grad_scale_inv,
-            bwd_dtype,
-            static_axis_boundary=-1,
-            transpose_axis_boundary=-1,
-            activation_type=("gelu",),
-        )
-    elif fuse_gelu:
-        # No bias to fuse so we just do dGELU.
-        casted_dgelu, casted_dgelu_t, updated_dgelu_amax = dact_lu(grad, pre_gelu_out, ("gelu",))
-        bgrad = None
-
-    kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
-    dgrad, *_ = gemm_impl(
-        casted_dgelu if fuse_gelu else grad,
+    x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
+    wgrad, *_ = fp8_gemm_impl(
+        casted_x_t,
+        x_scale_inv,
+        casted_grad_t,
         grad_scale_inv,
-        casted_kernel_t,
-        kernel_scale_inv,
         accumulate=accumulate,
         use_split_accumulator=use_split_accumulator,
     )
@@ -412,6 +413,13 @@ def _fp8_gemm_bwd_rule(
     amax_list[FP8MetaPackage.WEIGHT_IDX] = (
         amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax[0])
     )
+    amax_list[FP8MetaPackage.GRAD_IDX] = (
+        amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])
+    )
+    if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]:
+        amax_list[FP8MetaPackage.OUTPUT_IDX] = (
+            amax_list[FP8MetaPackage.OUTPUT_IDX].at[0].set(updated_out_amax[0])
+        )
 
     amax_list = maybe_fp32_to_fm32(*amax_list)
     scale_list = maybe_fp32_to_fm32(*scale_list)
@@ -433,20 +441,24 @@ def type_safe_gemm(
     accumulate: bool = False,
     use_split_accumulator: bool = False,
 ) -> ArrayLike:
-    if x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] or kernel.dtype in [
-        jnp.float8_e4m3fn,
-        jnp.float8_e5m2,
-    ]:
+    if (x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]
+            or kernel.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]):
         assert fp8_meta is not None, "GEMM operands have FP8 dtypes but FP8MetaPackage is None."
 
     if fp8_meta is not None:
+        x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim))
+        assert (
+            x_inner_dim == x.ndim - 1 and kernel_inner_dim == kernel.ndim - 2
+        ), (
+            "FP8 GEMM requires non-transposed X (LHS) and transposed kernel (RHS), "
+            + "i.e. contracting_dims=(-1, -1)."
+        )
         return fp8_gemm(
             x,
             kernel,
            bias,
             fp8_meta,
             out_dtype,
-            contracting_dims,
             fuse_gelu,
             accumulate,
             use_split_accumulator,
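Example of the updated user-facing convention: the non-FP8 path keeps an arbitrary (now negative-index-friendly) `contracting_dims`, while the FP8 path no longer accepts the argument and instead validates the operand layout up front. A sketch of a non-FP8 call, assuming the unchanged parts of the `type_safe_gemm` signature keep `bias` and `fp8_meta` as optional arguments (shapes illustrative):

    import jax.numpy as jnp
    from transformer_engine.jax.gemm import type_safe_gemm

    x = jnp.zeros((32, 128, 256), dtype=jnp.bfloat16)   # ([B], M, K)
    kernel = jnp.zeros((256, 512), dtype=jnp.bfloat16)  # (K, N)

    # Negative contracting dims are wrapped by sanitize_dims internally, so
    # (-1, 0) contracts K in both operands; the batch dim of x broadcasts.
    out = type_safe_gemm(x, kernel, contracting_dims=(-1, 0))  # ([B], M, N)

For the FP8 path, an `FP8MetaPackage` is required and the kernel must already be laid out for cuBLASLt's TN GEMM; `type_safe_gemm` now asserts this layout instead of transposing operands on the fly the way the removed batcher logic did.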