Migrate multi_head_jagged_flash_attention SLL ops to OSS (#3485)
Summary:
X-link: facebookresearch/FBGEMM#569

Pull Request resolved: #3485

- Migrate `multi_head_jagged_flash_attention` SLL ops to OSS

Reviewed By: brad-mengchi

Differential Revision: D66972360

fbshipit-source-id: 42d9548ad3a3e9390ff1f3d464804065bf9481a3
q10 authored and facebook-github-bot committed Dec 13, 2024
1 parent d70a19e commit 6e9b083
Showing 4 changed files with 963 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/scripts/fbgemm_gpu_test.bash
@@ -89,6 +89,7 @@ __configure_fbgemm_gpu_test_cpu () {
./sll/jagged_flash_attention_basic_test.py
./sll/jagged_jagged_bmm_jagged_out_test.py
./sll/jagged_dense_flash_attention_test.py
./sll/multi_head_jagged_flash_attention_test.py
)
}

22 changes: 22 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/__init__.py
@@ -48,6 +48,7 @@
jagged_jagged_bmm,
jagged_jagged_bmm_jagged_out,
jagged_softmax,
multi_head_jagged_flash_attention,
triton_jagged_self_substraction_jagged_out,
)

@@ -263,6 +264,19 @@ def register_sll_op(op_name: str, functors: Dict[str, Callable]) -> None:
"""
)

if "fbgemm::sll_multi_head_jagged_flash_attention" not in torch.library._defs:
lib.define(
"""sll_multi_head_jagged_flash_attention(
Tensor q_weights,
Tensor k_weights,
Tensor v_weights,
Tensor offsets,
int max_seq_len,
bool allow_tf32=True
) -> Tensor
"""
)

# NOTE: here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same function
# however, this is not ideal because in the inference case, we don't need the autograd forward
# to save the context because we don't need to do backward.
@@ -396,3 +410,11 @@ def register_sll_op(op_name: str, functors: Dict[str, Callable]) -> None:
"AutogradCPU": cpu_jagged_dense_flash_attention,
},
)

register_sll_op(
"sll_multi_head_jagged_flash_attention",
{
"CUDA": multi_head_jagged_flash_attention,
"AutogradCUDA": multi_head_jagged_flash_attention,
},
)
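For reference, below is a minimal usage sketch of the op registered in this diff. It assumes fbgemm_gpu is built with the SLL ops and a CUDA device is available; importing fbgemm_gpu.sll triggers the lib.define/register_sll_op calls shown above. The tensor layout (a [num_heads, total_len, head_dim] jagged batch with per-sequence offsets) and dtype are illustrative assumptions, not taken from this commit; consult ./sll/multi_head_jagged_flash_attention_test.py for the exact layout the kernel expects.

    # Illustrative sketch only -- shapes and dtype below are assumptions; see
    # the multi_head_jagged_flash_attention test for the layout the Triton
    # kernel actually expects.
    import torch
    import fbgemm_gpu.sll  # noqa: F401  (registers the fbgemm::sll_* ops)

    num_heads, head_dim, max_seq_len = 2, 64, 16

    # Hypothetical jagged batch: two sequences of lengths 5 and 9.
    lengths = torch.tensor([5, 9], device="cuda")
    offsets = torch.cat(
        [torch.zeros(1, dtype=lengths.dtype, device="cuda"), lengths.cumsum(0)]
    )
    total_len = int(lengths.sum())

    # Assumed [num_heads, total_len, head_dim] layout for q/k/v (hypothetical).
    q = torch.rand(num_heads, total_len, head_dim, device="cuda")
    k = torch.rand(num_heads, total_len, head_dim, device="cuda")
    v = torch.rand(num_heads, total_len, head_dim, device="cuda")

    # Dispatches to the CUDA / AutogradCUDA implementation registered above.
    out = torch.ops.fbgemm.sll_multi_head_jagged_flash_attention(
        q, k, v, offsets, max_seq_len, allow_tf32=True
    )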