diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index f51f0436b7a9..039d0865d096 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -3244,13 +3244,25 @@ def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
     y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
   return y_bar
 
-def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
-                            precision,
-                            preferred_element_type: DTypeLike | None,
-                            algorithm: _DotAlgorithmLike = None,
-                            transpose_algorithm: DotTransposeAlgorithm | None = None):
-  lhs, rhs = batched_args
-  lbd, rbd = batch_dims
+
+def _dot_batch_rule(
+    unpack_args,
+    unpack_dims,
+    invoke_prim,
+    batched_args,
+    batch_dims,
+    *,
+    dimension_numbers,
+    precision,
+    preferred_element_type: DTypeLike | None,
+    algorithm: _DotAlgorithmLike = None,
+    transpose_algorithm: DotTransposeAlgorithm | None = None,
+    **_,
+):
+
+  lhs, rhs = unpack_args(batched_args)
+  lbd, rbd = unpack_dims(batch_dims)
+
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
   right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
@@ -3272,16 +3284,21 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
     rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
   else:
     rhs_shape = np.shape(rhs)
-  batched_out = dot_general(lhs, rhs, new_dimension_numbers,
-                            precision=precision,
-                            preferred_element_type=preferred_element_type,
-                            algorithm=algorithm,
-                            transpose_algorithm=transpose_algorithm)
+  batched_out = invoke_prim(
+      lhs,
+      rhs,
+      new_dimension_numbers,
+      precision=precision,
+      preferred_element_type=preferred_element_type,
+      algorithm=algorithm,
+      transpose_algorithm=transpose_algorithm,
+  )
   result_batch_dim = batching.shape_as_bdim(
       result_stack_dim,
       _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers))
   return batched_out, result_batch_dim
 
+
 def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
   # There are three kinds of dimensions in a dot_general:
   # - contraction dimensions appear in lhs and rhs but not the result
@@ -3356,8 +3373,35 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc:
 
 dot_general_p = standard_primitive(_dot_general_shape_rule,
                                    _dot_general_dtype_rule, 'dot_general')
+
+
+def _dot_general_batch_unpack_args(batch_args):
+  lhs, rhs = batch_args
+  return (lhs, rhs)
+
+
+def _dot_general_batch_unpack_dims(batch_dims):
+  lbd, rbd = batch_dims
+  return (lbd, rbd)
+
+# DotDimensionNumbers used in the dot_general call for ragged_dot().
+_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
+    ([2, 0], [1, 0]),
+    ([], []),
+)
+_RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
+    ([3, 1], [2, 1]),
+    ([0], [0]),
+)
+
 ad.defbilinear(dot_general_p,
                _dot_general_transpose_lhs, _dot_general_transpose_rhs)
+_dot_general_batch_rule = functools.partial(
+    _dot_batch_rule,
+    _dot_general_batch_unpack_args,
+    _dot_general_batch_unpack_dims,
+    dot_general,
+)
 batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
 pe.padding_rules[dot_general_p] = _dot_general_padding_rule
 core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule
@@ -3461,6 +3505,34 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
 
 
 def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape:
+  if len(lhs.shape) == 3:
+    # Batched case
+    b, m, k = lhs.shape
+    b2, group_count, rk, n = rhs.shape
+    b3 = group_sizes.shape[0]
+    if b != b2:
+      raise TypeError(
+          f'ragged_dot requires that lhs.shape[0] == rhs.shape[0]: got {b} and'
+          f' {b2}.'
+      )
+    if b3 != b:
+      raise TypeError(
+          'ragged_dot requires that group_sizes.shape[0] == lhs.shape[0]: got'
+          f' {b3} and {b}.'
+      )
+    if k != rk:
+      raise TypeError(
+          f'ragged_dot requires that lhs.shape[2] == rhs.shape[2]: got {k} and'
+          f' {rk}.'
+      )
+    num_groups = group_sizes.shape[1]
+    if group_count != num_groups:
+      raise TypeError(
+          'ragged_dot requires that rhs.shape[1] == group_sizes.shape[1]: got'
+          f' {group_count} and {num_groups}.'
+      )
+    return (b, m, n)
+
   m, k = lhs.shape
   group_count, rk, n = rhs.shape
   if k != rk:
@@ -3470,9 +3542,6 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
     raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.")
   return (m, n)
 
-# DotDimensionNumbers used in the dot_general call for ragged_dot().
-_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (([2, 0], [1, 0]), ([], []))
-
 def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array, precision,
                            preferred_element_type: DTypeLike | None, **_) -> np.dtype:
   if not dtypes.issubdtype(group_sizes.dtype, np.integer):
@@ -3584,11 +3653,68 @@ def _ragged_dot_transpose_rule(
   return grad_x, grad_y, None
 
 
+def _ragged_dot_batch_unpack_args(batched_args):
+  lhs, rhs, _ = batched_args
+  return (lhs, rhs)
+
+
+def _ragged_dot_batch_unpack_dims(batch_dims):
+  if not all(dim == 0 for dim in batch_dims):
+    raise NotImplementedError('ragged_dot batching over dims other than 0 is not implemented')
+  lbd, rbd, _ = batch_dims
+  return (lbd, rbd)
+
+
+def _ragged_dot_invoke_prim(
+    group_sizes,
+    lhs,
+    rhs,
+    new_dimension_numbers,
+    precision,
+    preferred_element_type,
+    algorithm,
+    transpose_algorithm,
+):
+  assert algorithm is None
+  assert transpose_algorithm is None
+
+  return ragged_dot(
+      lhs,
+      rhs,
+      group_sizes,
+      precision=precision,
+      preferred_element_type=preferred_element_type,
+  )
+
+
+def _ragged_dot_batch_rule(
+    batched_args,
+    batch_dims,
+    *,
+    precision,
+    preferred_element_type: DTypeLike | None,
+    **_,
+):
+  invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2])
+
+  return _dot_batch_rule(
+      _ragged_dot_batch_unpack_args,
+      _ragged_dot_batch_unpack_dims,
+      invoke,
+      batched_args,
+      batch_dims,
+      dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
+      precision=precision,
+      preferred_element_type=preferred_element_type,
+  )
+
+
 ragged_dot_p = standard_primitive(_ragged_dot_shape_rule,
                                   _ragged_dot_dtype_rule, 'ragged_dot')
 ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p))
 ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule
 ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule
+batching.primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule
 
 def _ragged_dot_impl(
     lhs: Array,
     rhs: Array,
     group_sizes: Array,
     precision,
     preferred_element_type: DTypeLike | None,
     group_offset: Array | None,
 ) -> Array:
   if group_offset is not None:
     raise NotImplementedError("Unimplemented group_offset support.")
-  lhs = _ragged_to_dense(lhs, rhs, group_sizes=group_sizes)
+
+  if len(lhs.shape) == 3:
+    ragged_dot_dims = _RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS
+    ragged_to_dense = api.vmap(_ragged_to_dense, in_axes=(0, 0, 0))
+  else:
+    ragged_dot_dims = _RAGGED_DOT_DOT_DIMENSION_NUMBERS
+    ragged_to_dense = _ragged_to_dense
+
+  lhs = ragged_to_dense(lhs, rhs, group_sizes)
+
   return dot_general(
       lhs,
       rhs,
-      dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
+      dimension_numbers=ragged_dot_dims,
       precision=precision,
       preferred_element_type=preferred_element_type,
   )
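
Usage sketch (not part of the patch; the concrete shapes below are illustrative assumptions): with this change applied, jax.vmap of lax.ragged_dot over the leading axis of every operand should dispatch to the _ragged_dot_batch_rule registered above. Shapes follow the batched branch of _ragged_dot_shape_rule: lhs (b, m, k), rhs (b, num_groups, k, n), group_sizes (b, num_groups), producing an output of shape (b, m, n); only axis-0 batching is supported, per _ragged_dot_batch_unpack_dims.

import jax
import jax.numpy as jnp
from jax import lax

b, m, k, n, num_groups = 2, 16, 8, 4, 3
lhs = jnp.ones((b, m, k), dtype=jnp.float32)
rhs = jnp.ones((b, num_groups, k, n), dtype=jnp.float32)
# Per-batch group sizes along the ragged m dimension; each row sums to m.
group_sizes = jnp.array([[4, 4, 8], [2, 6, 8]], dtype=jnp.int32)

# Map over axis 0 of all three operands; other axes raise NotImplementedError.
out = jax.vmap(lax.ragged_dot, in_axes=(0, 0, 0))(lhs, rhs, group_sizes)
assert out.shape == (b, m, n)

Before this change there was no entry for ragged_dot_p in batching.primitive_batchers, so a vmap like the one above had no batching rule to use.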