[ops] Add a sharded attention operation for SDPA (#381)
The existing implementation invoked the `torch` sdpa operator directly.
It is now rewired to go through the `ops` system so that a sharded sdpa operation can be dispatched.
This includes a sharded implementation and tests that compare the sharded and unsharded versions.
rsuderman authored Oct 31, 2024
1 parent 3058af9 commit 2d46caa
Showing 5 changed files with 102 additions and 13 deletions.
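
For orientation, a minimal sketch of the new call path, distilled from the tests added in this commit (the `from sharktank import ops` and `from sharktank.types import SplitPrimitiveTensor` import paths are assumed here rather than taken from the diff):

# Minimal usage sketch based on the tests below; import paths are assumed.
import torch
from sharktank import ops
from sharktank.types import SplitPrimitiveTensor

q = torch.rand(4, 32, 16, dtype=torch.float32)
k = torch.rand(4, 32, 16, dtype=torch.float32)
v = torch.rand(4, 32, 16, dtype=torch.float32)

# Plain tensors dispatch to the default torch-backed implementation.
expected = ops.scaled_dot_product_attention(q, k, v, a=None, is_causal=True)

# Splitting q/k/v along the batch dimension dispatches to the new sharded
# override, which runs SDPA shard-by-shard and returns a SplitPrimitiveTensor.
qs = SplitPrimitiveTensor(shard_dim=0, ts=q.split(1, dim=0))
ks = SplitPrimitiveTensor(shard_dim=0, ts=k.split(1, dim=0))
vs = SplitPrimitiveTensor(shard_dim=0, ts=v.split(1, dim=0))
sharded = ops.scaled_dot_product_attention(qs, ks, vs, a=None, is_causal=True)

# Reassembling the shards should reproduce the unsharded result.
torch.testing.assert_close(ops.sharded_cat(sharded), expected)
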
11 changes: 5 additions & 6 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -171,12 +171,11 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
             ) # (bs, heads, slen, head_dim)
         else:
             is_causal = attention_mask is None and batch_seq_len == 1
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query=xq, # [bs, ..., sl, dim]
-                key=keys, # [bs, ..., sl, dim]
-                value=values, # [bs, ..., sl, dim]
-                attn_mask=attention_mask, # [bs, ..., sl, sl]
-                dropout_p=0.0,
+            attn_output = ops.scaled_dot_product_attention(
+                q=xq, # [bs, ..., sl, dim]
+                k=keys, # [bs, ..., sl, dim]
+                v=values, # [bs, ..., sl, dim]
+                a=attention_mask, # [bs, ..., sl, sl]
                 is_causal=is_causal, # assumes causal masking when true
                 scale=None, # defaults to 1/sqrt(dim)
             )
8 changes: 3 additions & 5 deletions sharktank/sharktank/ops/default_impls.py
@@ -359,10 +359,8 @@ def matmul_default(lhs, rhs, *, transpose_rhs: bool) -> Tensor:


 # Scaled dot product attention
-@scaled_dot_product_attention.override(
-    Tensor, Tensor, Tensor, Optional[Tensor], auto_dequant=True
-)
-def scaled_dot_product_attention(q, k, v, a) -> Tensor:
+@scaled_dot_product_attention.override(Tensor, Tensor, Tensor, None)
+def scaled_dot_product_attention_torch(q, k, v, a, is_causal, scale) -> Tensor:
     q = unbox_tensor(q)
     k = unbox_tensor(k)
     v = unbox_tensor(v)
@@ -371,7 +369,7 @@ def scaled_dot_product_attention(q, k, v, a) -> Tensor:

     # TODO: plumb dropout and is_causal through ops
     return torch.nn.functional.scaled_dot_product_attention(
-        q, k, v, attn_mask=a, dropout_p=0.0, is_causal=False
+        q, k, v, attn_mask=a, dropout_p=0.0, is_causal=is_causal, scale=scale
     )


36 changes: 36 additions & 0 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -801,6 +801,42 @@ def matmul_split(
     assert False, "Sharding configuration not supported"


+# Scaled dot product attention
+@scaled_dot_product_attention.override(
+    SplitPrimitiveTensor,
+    SplitPrimitiveTensor,
+    SplitPrimitiveTensor,
+    Optional[ReplicatedTensor],
+)
+def scaled_dot_product_attention_sharded(q, k, v, a, is_causal, scale) -> Tensor:
+    if q.shard_count != k.shard_count or q.shard_count != v.shard_count:
+        raise ValueError("Incompatible number of shards for qkv")
+
+    if a and q.shard_count != a.shard_count:
+        raise ValueError(
+            f"Incompatible number of shards for a ({a.shard_count}) should be ({q.shard_count})"
+        )
+
+    if q.shard_dim != k.shard_dim or q.shard_dim != v.shard_dim:
+        raise ValueError("Incompatible shard dim across qkv")
+
+    if q.shard_dim > len(q.shards[0].shape) - 2:
+        raise ValueError("Sharding must occur as batch dimension")
+
+    a_shards = [None] * q.shard_count
+    if a is not None:
+        a_shards = a.shards
+
+    output_shards = []
+    for q_s, k_s, v_s, a_s in zip(q.shards, k.shards, v.shards, a_shards):
+        o_s = scaled_dot_product_attention(
+            q_s, k_s, v_s, a_s, is_causal=is_causal, scale=scale
+        )
+        output_shards.append(o_s)
+
+    return SplitPrimitiveTensor(ts=output_shards, shard_dim=q.shard_dim)
+
+
 @mean.override(ReplicatedTensor)
 def mean_replicated(
     x: ReplicatedTensor,
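
For context on the "Sharding must occur as batch dimension" check above: attention mixes information across the sequence dimension within each batch/head slice, so computing SDPA shard-by-shard is only exact when the split dimension indexes independent slices. A small plain-torch check of that property (illustration only, not part of the commit):

import torch
import torch.nn.functional as F

q = torch.rand(4, 32, 16)
k = torch.rand(4, 32, 16)
v = torch.rand(4, 32, 16)

# Full batch vs. per-batch-element computation: the results match because
# each batch element attends only within its own sequence.
full = F.scaled_dot_product_attention(q, k, v)
per_shard = torch.cat(
    [
        F.scaled_dot_product_attention(qc, kc, vc)
        for qc, kc, vc in zip(q.split(1), k.split(1), v.split(1))
    ]
)
torch.testing.assert_close(per_shard, full)
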
6 changes: 4 additions & 2 deletions sharktank/sharktank/ops/signatures.py
@@ -784,7 +784,7 @@ def _replicate_trampoline(

 @overridable
 def scaled_dot_product_attention(
-    q: AnyTensor, k: AnyTensor, v: AnyTensor, a: Optional[AnyTensor]
+    q: AnyTensor, k: AnyTensor, v: AnyTensor, a: Optional[AnyTensor], is_causal: bool
 ) -> AnyTensor:
     """Computes the scaled dot product attention using QKV."""
     raise NotImplementedError
@@ -797,10 +797,12 @@ def _scaled_dot_product_attention(
     k: AnyTensor,
     v: AnyTensor,
     a: Optional[AnyTensor],
+    is_causal: bool = False,
+    scale: Optional[float] = None,
 ):
     tensors = (q, k, v, a)
     for override in d.find_overrides(tensors):
-        result = override(q, k, v, a)
+        result = override(q, k, v, a, is_causal=is_causal, scale=scale)
         if result is not NotImplemented:
             return override, result
     else:
54 changes: 54 additions & 0 deletions sharktank/tests/ops/sharded_test.py
@@ -588,6 +588,60 @@ def testShardedPrimitiveTensorPermute(self):
         assert ops.equal(expected_result, result)


+class AttentionTest(unittest.TestCase):
+    def testAttentionShardedBatch(self):
+        q = torch.rand(4, 32, 16, dtype=torch.float32)
+        k = torch.rand(4, 32, 16, dtype=torch.float32)
+        v = torch.rand(4, 32, 16, dtype=torch.float32)
+
+        qs = SplitPrimitiveTensor(shard_dim=0, ts=q.split(4, dim=0))
+        ks = SplitPrimitiveTensor(shard_dim=0, ts=k.split(4, dim=0))
+        vs = SplitPrimitiveTensor(shard_dim=0, ts=v.split(4, dim=0))
+
+        expected_result = ops.scaled_dot_product_attention(q, k, v, a=None)
+        sharded_result = ops.scaled_dot_product_attention(qs, ks, vs, a=None)
+        unsharded_result = ops.sharded_cat(sharded_result)
+        torch.testing.assert_close(unsharded_result, expected_result)
+
+    def testAttentionShardedBatchCausal(self):
+        q = torch.rand(4, 32, 16, dtype=torch.float32)
+        k = torch.rand(4, 32, 16, dtype=torch.float32)
+        v = torch.rand(4, 32, 16, dtype=torch.float32)
+
+        qs = SplitPrimitiveTensor(shard_dim=0, ts=q.split(4, dim=0))
+        ks = SplitPrimitiveTensor(shard_dim=0, ts=k.split(4, dim=0))
+        vs = SplitPrimitiveTensor(shard_dim=0, ts=v.split(4, dim=0))
+
+        expected_result = ops.scaled_dot_product_attention(
+            q, k, v, a=None, is_causal=True
+        )
+        sharded_result = ops.scaled_dot_product_attention(
+            qs, ks, vs, a=None, is_causal=True
+        )
+        unsharded_result = ops.sharded_cat(sharded_result)
+        torch.testing.assert_close(unsharded_result, expected_result)
+
+    def testAttentionShardedBatchMask(self):
+        q = torch.rand(4, 32, 16, dtype=torch.float32)
+        k = torch.rand(4, 32, 16, dtype=torch.float32)
+        v = torch.rand(4, 32, 16, dtype=torch.float32)
+        a = torch.rand(1, 32, 32, dtype=torch.float32) > 0.5
+
+        q_s = SplitPrimitiveTensor(shard_dim=0, ts=q.split(1, dim=0))
+        k_s = SplitPrimitiveTensor(shard_dim=0, ts=k.split(1, dim=0))
+        v_s = SplitPrimitiveTensor(shard_dim=0, ts=v.split(1, dim=0))
+        a_s = ReplicatedTensor(ts=a, shard_count=4)
+
+        expected_result = ops.scaled_dot_product_attention(
+            q, k, v, a=a, is_causal=False
+        )
+        sharded_result = ops.scaled_dot_product_attention(
+            q_s, k_s, v_s, a=a_s, is_causal=False
+        )
+        unsharded_result = ops.sharded_cat(sharded_result)
+        torch.testing.assert_close(unsharded_result, expected_result)
+
+
 class MatmulTest(unittest.TestCase):
     def testTorchRHSColumnShardedTransposed(self):
         t1 = torch.rand(4, 32, 16, dtype=torch.float32)
