Add very simple batching support for ragged_dot.
PiperOrigin-RevId: 675407251
Google-ML-Automation committed Sep 26, 2024
1 parent 9f4e8d0 commit 13f1597
Showing 1 changed file with 152 additions and 17 deletions.
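As a quick illustration of what this commit enables (a minimal sketch with made-up shapes, not taken from the diff below): jax.vmap can now map jax.lax.ragged_dot over a shared leading batch dimension.

import jax
import jax.numpy as jnp

# Illustrative shapes: b=4, m=8, k=16, n=32, num_groups=2.
lhs = jnp.ones((4, 8, 16), dtype=jnp.float32)            # (b, m, k)
rhs = jnp.ones((4, 2, 16, 32), dtype=jnp.float32)        # (b, num_groups, k, n)
group_sizes = jnp.array([[3, 5]] * 4, dtype=jnp.int32)   # (b, num_groups), each row sums to m

out = jax.vmap(jax.lax.ragged_dot)(lhs, rhs, group_sizes)
assert out.shape == (4, 8, 32)                           # (b, m, n)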
169 changes: 152 additions & 17 deletions jax/_src/lax/lax.py
@@ -3244,13 +3244,25 @@ def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
  y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
  return y_bar

def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
                            precision,
                            preferred_element_type: DTypeLike | None,
                            algorithm: _DotAlgorithmLike = None,
                            transpose_algorithm: DotTransposeAlgorithm | None = None):
  lhs, rhs = batched_args
  lbd, rbd = batch_dims

def _dot_batch_rule(
    unpack_args,
    unpack_dims,
    invoke_prim,
    batched_args,
    batch_dims,
    *,
    dimension_numbers,
    precision,
    preferred_element_type: DTypeLike | None,
    algorithm: _DotAlgorithmLike = None,
    transpose_algorithm: DotTransposeAlgorithm | None = None,
    **_,
):
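  # Shared batching rule for dot-like primitives: `unpack_args`/`unpack_dims`
  # pull the (lhs, rhs) operands and their batch dims out of the primitive's
  # full argument list, and `invoke_prim` rebuilds the underlying call
  # (plain dot_general, or ragged_dot with group_sizes closed over).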

  lhs, rhs = unpack_args(batched_args)
  lbd, rbd = unpack_dims(batch_dims)

  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
  right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
@@ -3272,16 +3284,21 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
    rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
  else:
    rhs_shape = np.shape(rhs)
  batched_out = dot_general(lhs, rhs, new_dimension_numbers,
                            precision=precision,
                            preferred_element_type=preferred_element_type,
                            algorithm=algorithm,
                            transpose_algorithm=transpose_algorithm)
  batched_out = invoke_prim(
      lhs,
      rhs,
      new_dimension_numbers,
      precision=precision,
      preferred_element_type=preferred_element_type,
      algorithm=algorithm,
      transpose_algorithm=transpose_algorithm,
  )
  result_batch_dim = batching.shape_as_bdim(
      result_stack_dim,
      _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers))
  return batched_out, result_batch_dim


def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
  # There are three kinds of dimensions in a dot_general:
  # - contraction dimensions appear in lhs and rhs but not the result
@@ -3356,8 +3373,35 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc:

dot_general_p = standard_primitive(_dot_general_shape_rule,
                                   _dot_general_dtype_rule, 'dot_general')


def _dot_general_batch_unpack_args(batch_args):
  lhs, rhs = batch_args
  return (lhs, rhs)


def _dot_general_batch_unpack_dims(batch_dims):
  lbd, rbd = batch_dims
  return (lbd, rbd)

# DotDimensionNumbers used in the dot_general call for ragged_dot().
_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
    ([2, 0], [1, 0]),
    ([], []),
)
_RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
    ([3, 1], [2, 1]),
    ([0], [0]),
)
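# In the dense form produced by _ragged_to_dense (see _ragged_dot_impl below),
# lhs is (group_count, m, k) and rhs is (group_count, k, n), so the first set
# of numbers contracts over k and group_count. In the batched form lhs is
# (b, group_count, m, k) and rhs is (b, group_count, k, n); the second set
# contracts over the same two dimensions and treats b as a batch dimension.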

ad.defbilinear(dot_general_p,
               _dot_general_transpose_lhs, _dot_general_transpose_rhs)
_dot_general_batch_rule = functools.partial(
    _dot_batch_rule,
    _dot_general_batch_unpack_args,
    _dot_general_batch_unpack_dims,
    dot_general,
)
batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
pe.padding_rules[dot_general_p] = _dot_general_padding_rule
core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule
@@ -3461,6 +3505,34 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):


def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape:
  if len(lhs.shape) == 3:
    # Batched case
    b, m, k = lhs.shape
    b2, group_count, rk, n = rhs.shape
    b3 = group_sizes.shape[0]
    if b != b2:
      raise TypeError(
          f'ragged_dot requires that lhs.shape[0] == rhs.shape[0]: got {b} and'
          f' {b2}.'
      )
    if b3 != b:
      raise TypeError(
          'ragged_dot requires that group_sizes.shape[0] == lhs.shape[0]: got'
          f' {b3} and {b}.'
      )
    if k != rk:
      raise TypeError(
          f'ragged_dot requires that lhs.shape[2] == rhs.shape[2]: got {k} and'
          f' {rk}.'
      )
    num_groups = group_sizes.shape[1]
    if group_count != num_groups:
      raise TypeError(
          'ragged_dot requires that rhs.shape[1] == group_sizes.shape[1]: got'
          f' {group_count} and {num_groups}.'
      )
    return (b, m, n)

  m, k = lhs.shape
  group_count, rk, n = rhs.shape
  if k != rk:
@@ -3470,9 +3542,6 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
    raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.")
  return (m, n)

# DotDimensionNumbers used in the dot_general call for ragged_dot().
_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (([2, 0], [1, 0]), ([], []))

def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
                           precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype:
  if not dtypes.issubdtype(group_sizes.dtype, np.integer):
@@ -3584,11 +3653,68 @@ def _ragged_dot_transpose_rule(
  return grad_x, grad_y, None


def _ragged_dot_batch_unpack_args(batched_args):
  lhs, rhs, _ = batched_args
  return (lhs, rhs)


def _ragged_dot_batch_unpack_dims(batch_dims):
  if not all(dim == 0 for dim in batch_dims):
    raise NotImplementedError('ragged_dot vmap over any dim but 0 - NYI')
  lbd, rbd, _ = batch_dims
  return (lbd, rbd)


def _ragged_dot_invoke_prim(
    group_sizes,
    lhs,
    rhs,
    new_dimension_numbers,
    precision,
    preferred_element_type,
    algorithm,
    transpose_algorithm,
):
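  # group_sizes is bound up front (via functools.partial in
  # _ragged_dot_batch_rule), so the remaining parameters match the invoke_prim
  # signature that _dot_batch_rule expects; ragged_dot accepts no algorithm or
  # transpose_algorithm, hence the asserts below.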
  assert algorithm is None
  assert transpose_algorithm is None

  return ragged_dot(
      lhs,
      rhs,
      group_sizes,
      precision=precision,
      preferred_element_type=preferred_element_type,
  )


def _ragged_dot_batch_rule(
    batched_args,
    batch_dims,
    *,
    precision,
    preferred_element_type: DTypeLike | None,
    **_,
):
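  # The dimension numbers passed below describe the dense dot_general that
  # ragged_dot ultimately performs (see _RAGGED_DOT_DOT_DIMENSION_NUMBERS);
  # _ragged_dot_invoke_prim ignores the recomputed numbers, which are only
  # used to locate the result batch dimension.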
  invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2])

  return _dot_batch_rule(
      _ragged_dot_batch_unpack_args,
      _ragged_dot_batch_unpack_dims,
      invoke,
      batched_args,
      batch_dims,
      dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
      precision=precision,
      preferred_element_type=preferred_element_type,
  )


ragged_dot_p = standard_primitive(_ragged_dot_shape_rule,
                                  _ragged_dot_dtype_rule, 'ragged_dot')
ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p))
ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule
ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule
batching.primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule

def _ragged_dot_impl(
    lhs: Array,
@@ -3600,11 +3726,20 @@ def _ragged_dot_impl(
) -> Array:
  if group_offset is not None:
    raise NotImplementedError("Unimplemented group_offset support.")
  lhs = _ragged_to_dense(lhs, rhs, group_sizes=group_sizes)

  if len(lhs.shape) == 3:
    ragged_dot_dims = _RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS
    ragged_to_dense = api.vmap(_ragged_to_dense, in_axes=(0, 0, 0))
  else:
    ragged_dot_dims = _RAGGED_DOT_DOT_DIMENSION_NUMBERS
    ragged_to_dense = _ragged_to_dense

  lhs = ragged_to_dense(lhs, rhs, group_sizes)
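  # lhs is now dense: (group_count, m, k) in the unbatched case, or
  # (b, group_count, m, k) when _ragged_to_dense is vmapped over the leading
  # batch dimension.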

  return dot_general(
      lhs,
      rhs,
      dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
      dimension_numbers=ragged_dot_dims,
      precision=precision,
      preferred_element_type=preferred_element_type,
  )
