Skip to content

Commit

Permalink
Fix (nn/sdpa): formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Nov 18, 2024
1 parent d54c832 commit b7cda03
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/brevitas/graph/standardize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from brevitas.fx import GraphModule
from brevitas.fx import immutable_dict
from brevitas.fx import Node
from brevitas.nn.quant_sdpa import ScaledDotProductAttention
from brevitas.nn.quant_sdpa import ScaledDotProductAttention

from .base import FnToModule
from .base import GraphTransform
Expand Down
60 changes: 46 additions & 14 deletions src/brevitas/nn/quant_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,17 @@


class ScaledDotProductAttention(Module):
def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False):

def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attn_mask: Optional[Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False):
r"""
Args:
query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
Expand All @@ -71,10 +81,10 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional
scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set
to :math:`\frac{1}{\sqrt{E}}`.
enable_gqa (bool): Ignored to make calling interface compatible with PyTorch >v2.5. Always set to False.
Returns:
output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`.
Shape legend:
- :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
- :math:`S: \text{Source sequence length}`
Expand All @@ -84,30 +94,52 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional
- :math:`Hq: \text{Number of heads of query}`
- :math:`H: \text{Number of heads of key and value}`
"""
return F.scaled_dot_product_attention(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
return F.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale)


class QuantScaledDotProductAttention(Module):
def __init__(self, query_quant=Int8ActPerTensorFloat, key_quant=Int8ActPerTensorFloat, value_quant=Int8ActPerTensorFloat, softmax_input_quant=Int8ActPerTensorFloat, softmax_output_quant=Uint8ActPerTensorFloat, attn_output_quant=None, **kwargs) -> None:

def __init__(
self,
query_quant=Int8ActPerTensorFloat,
key_quant=Int8ActPerTensorFloat,
value_quant=Int8ActPerTensorFloat,
softmax_input_quant=Int8ActPerTensorFloat,
softmax_output_quant=Uint8ActPerTensorFloat,
attn_output_quant=None,
**kwargs) -> None:
super(QuantScaledDotProductAttention, self).__init__()

def filter_kwargs(prefix):
return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}

self.query_quant = QuantIdentity(
act_quant=query_quant, **filter_kwargs('query_'))
self.key_quant = QuantIdentity(
act_quant=key_quant, **filter_kwargs('key_'))
self.value_quant = QuantIdentity(
act_quant=value_quant, **filter_kwargs('value_'))
self.query_quant = QuantIdentity(act_quant=query_quant, **filter_kwargs('query_'))
self.key_quant = QuantIdentity(act_quant=key_quant, **filter_kwargs('key_'))
self.value_quant = QuantIdentity(act_quant=value_quant, **filter_kwargs('value_'))
self.softmax_input_quant = QuantIdentity(
act_quant=softmax_input_quant, **filter_kwargs('softmax_input_'))
self.softmax_output_quant = QuantIdentity(
act_quant=softmax_output_quant, **filter_kwargs('softmax_output_'))
self.attn_output_quant = QuantIdentity(
act_quant=attn_output_quant, **filter_kwargs('attn_output_'))

def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False):
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attn_mask: Optional[Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False):
r"""
Args:
query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
Expand All @@ -125,10 +157,10 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional
scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set
to :math:`\frac{1}{\sqrt{E}}`.
enable_gqa (bool): Ignored to make calling interface compatible with PyTorch >v2.5. Always set to False.
Returns:
output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`.
Shape legend:
- :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
- :math:`S: \text{Source sequence length}`
Expand Down

0 comments on commit b7cda03

Please sign in to comment.