
Commit

re-applied bug fixes to working older version, updated backward pass, passing test

Signed-off-by: Alp Dener <adener@nvidia.com>
denera committed Nov 15, 2024
1 parent 52af237 commit b989641
Showing 2 changed files with 174 additions and 179 deletions.
93 changes: 38 additions & 55 deletions transformer_engine/jax/cpp_extensions/gemm.py
@@ -1,7 +1,6 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for cuBlasLt GEMM"""
import warnings
import operator
from functools import reduce
@@ -39,6 +38,10 @@
]


def sanitize_dims(dim, ndims):
return (ndims + dim) if dim < 0 else dim
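
For context, a minimal standalone sketch (not part of the commit) of what this new helper does, and why it replaces the lambda removed elsewhere in this file, which computed ndims - dim and so mapped a dim of -1 to ndims + 1 instead of ndims - 1:

# Illustrative only: mirrors the sanitize_dims helper added above.
def sanitize_dims(dim, ndims):
    return (ndims + dim) if dim < 0 else dim

# A 3D LHS and a 2D RHS contracting on their last dimensions:
contracting_dims = (-1, -1)
assert sanitize_dims(contracting_dims[0], 3) == 2   # last dim of the LHS
assert sanitize_dims(contracting_dims[1], 2) == 1   # last dim of the RHS
assert sanitize_dims(1, 3) == 1                     # positive dims pass through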


def get_cublas_workspace_size_bytes() -> int:
"""Return 32 MiB if using hopper, 4 MiB for all other architectures."""
if tex.get_device_compute_capability() >= 90:
@@ -98,11 +101,8 @@ def abstract(
), "Missing RHS operand scale inverse in FP8 GEMM."

# Validate operand layouts
lhs_inner_dim, rhs_inner_dim = map(
lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,
contracting_dims,
(lhs_aval.ndim, rhs_aval.ndim),
)
lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims,
(lhs_aval.ndim, rhs_aval.ndim))
assert (
lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim]
), f"Incompatible operand sizes: {lhs_aval.shape} x {rhs_aval.shape}."
@@ -134,23 +134,31 @@ def abstract(
out_amax_updated_dtype = jnp.float32
out_scale_updated_dtype = jnp.float32

# Infer output shape
# Make sure leading dimensions of RHS are broadcast-compatible with LHS
lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2
rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1

lhs_bdims = [
dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]
]
lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims]
lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1)
rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1
rhs_bdims = [
dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim]
]
rhs_batch_size = reduce(operator.mul, rhs_bdims, 1)
assert (
lhs_batch_size == rhs_batch_size
), "LHS and RHS operands must have the same batched sizes."
out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim])
if rhs_aval.ndim > 2:
rhs_bdims = [
dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim]
]
rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims]
rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1)
if rhs_batch_size > 1:
assert (
lhs_batch_size == rhs_batch_size
), (
f"Leading dimensins of RHS ({rhs_batch_shape=}) is not broadcast-compatible "
+ f"with the leading dimensions of LHS ({lhs_batch_shape=})."
)

# Infer output shape
out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim])
# Validate bias/bias_grad shape against inferred output
bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype
if fuse_bias:
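
A hedged sketch of the output-shape inference performed above (illustrative helper name, not the module's API; the RHS broadcast check is omitted for brevity):

def infer_out_shape(lhs_shape, rhs_shape, lhs_trans, rhs_trans):
    # Non-contracted (outer) dims follow the same rule as the abstract rule above.
    lhs_ndim, rhs_ndim = len(lhs_shape), len(rhs_shape)
    lhs_outer = lhs_ndim - 1 if lhs_trans else lhs_ndim - 2
    lhs_inner = lhs_ndim - 2 if lhs_trans else lhs_ndim - 1
    rhs_outer = rhs_ndim - 2 if rhs_trans else rhs_ndim - 1
    lhs_batch = [lhs_shape[d] for d in range(lhs_ndim) if d not in (lhs_outer, lhs_inner)]
    return (*lhs_batch, lhs_shape[lhs_outer], rhs_shape[rhs_outer])

# Batched (B, M, K) LHS against a transposed (N, K) RHS -> (B, M, N)
assert infer_out_shape((4, 128, 64), (256, 64), False, True) == (4, 128, 256)
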
@@ -227,11 +235,8 @@ def lowering(
cuBlasLt GEMM lowering rules
"""
lhs_aval, _, rhs_aval, _, bias_aval, *_ = ctx.avals_in
lhs_inner_dim, rhs_inner_dim = map(
lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,
contracting_dims,
(lhs_aval.ndim, rhs_aval.ndim),
)
lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims,
(lhs_aval.ndim, rhs_aval.ndim))
lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1
rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1
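
A small illustration (assumed helper names, not the commit's code) of how these transpose flags fall out of the sanitized contracting dimensions:

def sanitize_dims(dim, ndims):
    return (ndims + dim) if dim < 0 else dim

def transpose_flags(lhs_ndim, rhs_ndim, contracting_dims):
    lhs_inner = sanitize_dims(contracting_dims[0], lhs_ndim)
    rhs_inner = sanitize_dims(contracting_dims[1], rhs_ndim)
    # cuBLAS sees the LHS as transposed when it does NOT contract on its last
    # dimension, and the RHS as transposed when it does.
    return lhs_inner != lhs_ndim - 1, rhs_inner == rhs_ndim - 1

# NT layout (both operands contract on their last dim), as in the FP8 path below:
assert transpose_flags(3, 2, (-1, -1)) == (False, True)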

@@ -376,19 +381,8 @@ def batcher(
check_valid_batch_dims(batch_dims)
lhs_bdims, *_, bias_bdims, gelu_input_bdims, out_amax_bdims, out_scale_bdims = batch_dims

# FP8 GEMM only supports non-transposed LHS and transposed RHS
lhs, _, rhs, *_ = batched_args
lhs_trans = contracting_dims[0] != lhs.ndim - 1
rhs_trans = contracting_dims[1] == rhs.ndim - 1
lhs = jnp.matrix_transpose(lhs) if lhs_trans and jax_dtype_is_fp8(lhs.dtype) else lhs
rhs = jnp.matrix_transpose(rhs) if not rhs_trans and jax_dtype_is_fp8(rhs.dtype) else rhs
contracting_dims = (1, 1)

return CollectiveGemmPrimitive.outer_primitive.bind(
lhs,
batched_args[1],
rhs,
*batched_args[3:],
*batched_args,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
fuse_gelu=fuse_gelu,
@@ -415,11 +409,7 @@ def infer_sharding_from_operands(
lhs, _, rhs, *_ = arg_infos
lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs])

lhs_inner_dim, rhs_inner_dim = map(
lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,
contracting_dims,
(lhs.ndim, rhs.ndim),
)
lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim))
if lhs_spec[lhs_inner_dim] != rhs_spec[rhs_inner_dim] and not grad:
warnings.warn(
"Forcing the inner dimension of LHS to match the sharding of inner "
@@ -471,11 +461,7 @@ def partition(
lhs, _, rhs, *_ = arg_infos
lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs])

lhs_inner_dim, rhs_inner_dim = map(
lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,
contracting_dims,
(lhs.ndim, rhs.ndim),
)
lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim))

lhs_trans = lhs_inner_dim != lhs.ndim - 1
rhs_trans = rhs_inner_dim == rhs.ndim - 1
@@ -578,14 +564,13 @@ def sharded_impl(
def fp8_gemm_impl(
lhs: ArrayLike,
lhs_scale_inv: ArrayLike,
rhs: ArrayLike,
rhs_t: ArrayLike,
rhs_scale_inv: ArrayLike,
bias: Optional[ArrayLike] = None,
gelu_input: Optional[ArrayLike] = None,
out_amax: Optional[ArrayLike] = None,
out_scale: Optional[ArrayLike] = None,
out_dtype: jnp.dtype = jnp.bfloat16,
contracting_dims: Tuple[int, int] = (1, 1),
fuse_gelu: bool = False,
fuse_bias: bool = False,
accumulate: bool = False,
@@ -606,22 +591,20 @@ def fp8_gemm_impl(
if not fuse_gelu:
gelu_input = jnp.zeros(0, dtype=bias.dtype)
elif gelu_input is None:
lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2
rhs_outer_dim = rhs.ndim - 2 if contracting_dims[1] == 0 else rhs.ndim - 1
out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim])
out_shape = (*lhs.shape[:-2], lhs.shape[-2], rhs_t.shape[-2])
gelu_input = jnp.zeros(out_shape, dtype=bias.dtype)

out, out_amax, out_scale, pre_gelu_out, _ = CollectiveGemmPrimitive.outer_primitive.bind(
rhs,
rhs_scale_inv,
lhs,
lhs_scale_inv,
rhs_t,
rhs_scale_inv,
bias,
gelu_input,
out_amax,
out_scale,
out_dtype=out_dtype,
contracting_dims=tuple(reversed(contracting_dims)),
contracting_dims=(-1, -1),
fuse_gelu=fuse_gelu,
fuse_bias=fuse_bias,
grad=False,
@@ -645,10 +628,9 @@ def gemm_impl(
use_split_accumulator: bool = False,
) -> Tuple[ArrayLike, ...]:
"""Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op."""
dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32)

lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2
rhs_outer_dim = rhs.ndim - 2 if contracting_dims[1] == 0 else rhs.ndim - 1
lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim))
lhs_outer_dim = lhs.ndim - 1 if lhs_inner_dim == lhs.ndim - 2 else lhs.ndim - 2
rhs_outer_dim = rhs.ndim - 2 if rhs_inner_dim == rhs.ndim - 1 else rhs.ndim - 1
out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim])
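
A worked example of the outer-dimension selection above (hypothetical shapes, same sanitize_dims convention):

def sanitize_dims(dim, ndims):
    return (ndims + dim) if dim < 0 else dim

lhs_shape, rhs_shape = (2, 512, 1024), (1024, 768)   # hypothetical operands
contracting_dims = (-1, 0)                           # contract K on both sides

lhs_inner = sanitize_dims(contracting_dims[0], 3)    # 2
rhs_inner = sanitize_dims(contracting_dims[1], 2)    # 0
lhs_outer = 3 - 1 if lhs_inner == 3 - 2 else 3 - 2   # 1 -> M = 512
rhs_outer = 2 - 2 if rhs_inner == 2 - 1 else 2 - 1   # 1 -> N = 768
assert (*lhs_shape[:-2], lhs_shape[lhs_outer], rhs_shape[rhs_outer]) == (2, 512, 768)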

if not fuse_bias:
@@ -667,6 +649,7 @@ def gemm_impl(
elif gelu_input is None:
gelu_input = jnp.zeros(out_shape, dtype=lhs.dtype)

dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32)
out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind(
lhs,
dummy_fp8_meta,
