
Commit

Fix (nn/sdpa): Updated argument to match qsdpa
nickfraser committed Nov 18, 2024
1 parent b7cda03 commit 221c822
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions src/brevitas/nn/quant_sdpa.py
@@ -108,25 +108,26 @@ class QuantScaledDotProductAttention(Module):

     def __init__(
             self,
-            query_quant=Int8ActPerTensorFloat,
-            key_quant=Int8ActPerTensorFloat,
-            value_quant=Int8ActPerTensorFloat,
-            softmax_input_quant=Int8ActPerTensorFloat,
-            softmax_output_quant=Uint8ActPerTensorFloat,
+            softmax_input_quant=None,
+            attn_output_weights_quant=Uint8ActPerTensorFloat,
+            q_scaled_quant=Int8ActPerTensorFloat,
+            k_transposed_quant=Int8ActPerTensorFloat,
+            v_quant=Int8ActPerTensorFloat,
             attn_output_quant=None,
             **kwargs) -> None:
         super(QuantScaledDotProductAttention, self).__init__()
 
         def filter_kwargs(prefix):
             return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
 
-        self.query_quant = QuantIdentity(act_quant=query_quant, **filter_kwargs('query_'))
-        self.key_quant = QuantIdentity(act_quant=key_quant, **filter_kwargs('key_'))
-        self.value_quant = QuantIdentity(act_quant=value_quant, **filter_kwargs('value_'))
+        self.q_scaled_quant = QuantIdentity(act_quant=q_scaled_quant, **filter_kwargs('q_scaled_'))
+        self.k_transposed_quant = QuantIdentity(
+            act_quant=k_transposed_quant, **filter_kwargs('k_transposed_'))
+        self.v_quant = QuantIdentity(act_quant=v_quant, **filter_kwargs('v_'))
         self.softmax_input_quant = QuantIdentity(
             act_quant=softmax_input_quant, **filter_kwargs('softmax_input_'))
-        self.softmax_output_quant = QuantIdentity(
-            act_quant=softmax_output_quant, **filter_kwargs('softmax_output_'))
+        self.attn_output_weights_quant = QuantIdentity(
+            act_quant=attn_output_weights_quant, **filter_kwargs('attn_output_weights_'))
         self.attn_output_quant = QuantIdentity(
             act_quant=attn_output_quant, **filter_kwargs('attn_output_'))

@@ -187,12 +188,14 @@ def forward(
                 attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
             else:
                 attn_bias += attn_mask
-        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        q_scaled = self.q_scaled_quant(query * scale_factor)
+        k_transpose = self.k_transposed_quant(key.transpose(-2, -1))
+        attn_weight = q_scaled @ k_transpose
         attn_weight += attn_bias
         attn_weight = self.softmax_input_quant(attn_weight)
         attn_weight = torch.softmax(attn_weight, dim=-1)
         attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
-        attn_weight = self.softmax_output_quant(attn_weight)
-        attn_output = attn_weight @ value
+        attn_weight = self.attn_output_weights_quant(attn_weight)
+        attn_output = attn_weight @ self.v_quant(value)
         attn_output = self.attn_output_quant(attn_output)
         return attn_output
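
For context only (not part of the commit), a minimal usage sketch of the renamed interface. It assumes QuantScaledDotProductAttention is importable from brevitas.nn.quant_sdpa as in the diff, that the quantizer classes live in brevitas.quant, and that a prefixed keyword such as q_scaled_bit_width is routed by filter_kwargs('q_scaled_') to the matching QuantIdentity as bit_width; that override name is an assumption, not something the commit shows.

import torch

from brevitas.nn.quant_sdpa import QuantScaledDotProductAttention
from brevitas.quant import Int8ActPerTensorFloat, Uint8ActPerTensorFloat

# Quantizer arguments follow the names introduced in this commit; the
# 'q_scaled_' prefix is stripped by filter_kwargs('q_scaled_'), so the
# remaining 'bit_width=4' reaches only the q_scaled quantizer
# (assumed override, see the note above).
qsdpa = QuantScaledDotProductAttention(
    q_scaled_quant=Int8ActPerTensorFloat,
    k_transposed_quant=Int8ActPerTensorFloat,
    v_quant=Int8ActPerTensorFloat,
    attn_output_weights_quant=Uint8ActPerTensorFloat,
    q_scaled_bit_width=4)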

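The scale is folded into the query before the matmul because (query * scale_factor) @ key.T is mathematically the same as (query @ key.T) * scale_factor (up to floating-point rounding), so both matmul operands can pass through their own quantizers (q_scaled_quant, k_transposed_quant). Continuing the sketch above, a hypothetical smoke test, assuming the forward signature mirrors torch.nn.functional.scaled_dot_product_attention with query, key and value as the leading arguments and mask/dropout defaults:

# Shapes follow the usual (batch, heads, seq_len, head_dim) layout.
q = torch.randn(2, 8, 64, 32)
k = torch.randn(2, 8, 64, 32)
v = torch.randn(2, 8, 64, 32)

out = qsdpa(q, k, v)  # SDPA with QuantIdentity nodes at the points shown in the diff
assert out.shape == (2, 8, 64, 32)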