-
-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Your Name
committed
Sep 4, 2024
1 parent
95868ae
commit 69138f9
Showing
5 changed files
with
153 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import torch | ||
from torch import nn | ||
cimport cython | ||
|
||
cdef class MultiQueryAttention: | ||
cdef int embed_dim | ||
cdef int num_heads | ||
cdef int head_dim | ||
cdef object query_proj # Treat nn.Linear as a Python object | ||
cdef object key_proj # Treat nn.Linear as a Python object | ||
cdef object value_proj # Treat nn.Linear as a Python object | ||
cdef object out_proj # Treat nn.Linear as a Python object | ||
|
||
def __cinit__(self, int embed_dim, int num_heads): | ||
self.embed_dim = embed_dim | ||
self.num_heads = num_heads | ||
self.head_dim = embed_dim // num_heads | ||
|
||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" | ||
|
||
# Initialize nn.Linear layers as regular Python objects | ||
self.query_proj = nn.Linear(embed_dim, embed_dim) | ||
self.key_proj = nn.Linear(embed_dim, self.head_dim) | ||
self.value_proj = nn.Linear(embed_dim, self.head_dim) | ||
self.out_proj = nn.Linear(embed_dim, embed_dim) | ||
|
||
@cython.boundscheck(False) | ||
@cython.wraparound(False) | ||
def forward(self, query, key, value): | ||
cdef int batch_size, seq_len, _ | ||
|
||
# Assuming the input tensors are torch.Tensor objects | ||
batch_size, seq_len, _ = query.size() | ||
|
||
# Linear projections | ||
queries = self.query_proj(query) | ||
keys = self.key_proj(key) | ||
values = self.value_proj(value) | ||
|
||
# Reshape for multi-head attention | ||
queries = queries.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) | ||
keys = keys.unsqueeze(1).expand(-1, self.num_heads, -1, -1) | ||
values = values.unsqueeze(1).expand(-1, self.num_heads, -1, -1) | ||
|
||
# Scaled dot-product attention | ||
scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5) | ||
attn_weights = torch.nn.functional.softmax(scores, dim=-1) | ||
attn_output = torch.matmul(attn_weights, values) | ||
|
||
# Concatenate and project the output | ||
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim) | ||
output = self.out_proj(attn_output) | ||
|
||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import timeit | ||
import torch | ||
from zeta import MultiQueryAttention as PyTorchMQA | ||
from mqa import MultiQueryAttention as CythonMQA | ||
|
||
# Input parameters | ||
batch_size = 32 | ||
seq_len = 128 | ||
embed_dim = 512 | ||
num_heads = 8 | ||
|
||
# Create sample input tensors | ||
query = torch.randn(batch_size, seq_len, embed_dim) | ||
key = torch.randn(batch_size, seq_len, embed_dim) | ||
value = torch.randn(batch_size, seq_len, embed_dim) | ||
|
||
# Initialize the PyTorch Multi-Query Attention layer (from zeta package) | ||
pytorch_mqa = PyTorchMQA(dim=embed_dim, heads=num_heads) | ||
|
||
# Initialize the Cython Multi-Query Attention layer (from mqa module) | ||
cython_mqa = CythonMQA(embed_dim, num_heads) | ||
|
||
|
||
# Define functions for benchmarking | ||
def run_pytorch_mqa(): | ||
output, _, _ = pytorch_mqa(query) | ||
return output | ||
|
||
|
||
def run_cython_mqa(): | ||
output = cython_mqa.forward(query, key, value) | ||
return output | ||
|
||
|
||
# Warm-up runs (important to avoid cold start issues) | ||
for _ in range(20): | ||
run_pytorch_mqa() | ||
run_cython_mqa() | ||
|
||
# Benchmark PyTorch implementation | ||
pytorch_time = timeit.timeit( | ||
"run_pytorch_mqa()", globals=globals(), number=1000 | ||
) | ||
|
||
# Benchmark Cython implementation | ||
cython_time = timeit.timeit("run_cython_mqa()", globals=globals(), number=1000) | ||
|
||
# Print the results | ||
print(f"PyTorch MQA execution time: {pytorch_time:.6f} seconds") | ||
print(f"Cython MQA execution time: {cython_time:.6f} seconds") | ||
if cython_time < pytorch_time: | ||
print(f"Cython is faster by: {pytorch_time / cython_time:.2f}x") | ||
else: | ||
print(f"PyTorch is faster by: {cython_time / pytorch_time:.2f}x") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import torch | ||
import torch_extension # Import the compiled Cython module | ||
|
||
# Create a PyTorch tensor | ||
input_tensor = torch.tensor([0.0, 1.0, 2.0, 3.0]) | ||
|
||
# Use the Cython function to apply the sin function | ||
output_tensor = torch_extension.apply_sin(input_tensor) | ||
|
||
print(output_tensor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from setuptools import setup, Extension | ||
from torch.utils.cpp_extension import BuildExtension | ||
from Cython.Build import cythonize | ||
|
||
setup( | ||
name="mqa", | ||
ext_modules=cythonize( | ||
Extension( | ||
"mqa", | ||
sources=["mqa.pyx"], | ||
language="c++", | ||
) | ||
), | ||
cmdclass={"build_ext": BuildExtension}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import torch # Use standard Python import for PyTorch | ||
cimport cython | ||
import numpy as np | ||
|
||
@cython.boundscheck(False) | ||
@cython.wraparound(False) | ||
def apply_sin(input_tensor): | ||
cdef int i | ||
cdef int size = input_tensor.numel() | ||
|
||
# Convert the PyTorch tensor to a NumPy array | ||
np_array = input_tensor.numpy() | ||
|
||
# Apply sin element-wise using NumPy | ||
np_output = np.sin(np_array) | ||
|
||
# Convert the NumPy array back to a PyTorch tensor | ||
output_tensor = torch.from_numpy(np_output) | ||
|
||
return output_tensor |