Commit
Update flash attention op (#616)
This commit updates the flash attention op to match the is_causal and scale
arguments added in commit 2d46caa.
Without this change, we see an fp8 attention export failure.
saienduri authored Nov 27, 2024
1 parent cdb4ccd commit d6be43f
Showing 1 changed file with 1 addition and 1 deletion.

sharktank/sharktank/ops/attention_impls.py
@@ -47,7 +47,7 @@ def _extract_linear_scale(t):
     return unbox_tensor(t), None


-def flash_attention(q, k, v, a):
+def flash_attention(q, k, v, a, is_causal, scale):
     scale = torch.scalar_tensor(1.0 / math.sqrt(q.shape[-1]), dtype=torch.float32)

     q, qscale = _extract_linear_scale(q)
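
For context, a minimal sketch (hypothetical, not the repository's code) of how an op with the updated signature might consume the new arguments: honor a caller-provided scale, fall back to 1/sqrt(head_dim) otherwise, and forward is_causal to the underlying kernel, with PyTorch's scaled_dot_product_attention standing in for the fused flash path.

import math
import torch

def flash_attention_sketch(q, k, v, a, is_causal, scale):
    # Hypothetical sketch: honor the caller-provided scale, otherwise
    # default to 1/sqrt(head_dim) as the original op did.
    if scale is None:
        scale = 1.0 / math.sqrt(q.shape[-1])
    # attn_mask and is_causal are mutually exclusive in SDPA, so only
    # pass the mask when the causal path is not requested.
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v,
        attn_mask=None if is_causal else a,
        is_causal=is_causal,
        scale=float(scale),
    )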
