Skip to content

Commit

Permalink
[FEAT][SigmoidAttention]
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Sep 9, 2024
1 parent 1dfc128 commit afb1556
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 3 deletions.
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,32 @@ loss.backward()
```


## Sigmoid Attention

Attention 18% faster with sigmoid instead of attention

- replace traditional softmax in attention with a sigmoid and
- a constant (not learned) scalar bias based on the sequence length.


```python
import torch
from zeta import SigmoidAttention
from loguru import logger

batch_size = 32
seq_len = 128
dim = 512
heads = 8

x = torch.rand(batch_size, seq_len, dim)
mask = torch.ones(batch_size, seq_len, seq_len) # Example mask

sigmoid_attn = SigmoidAttention(dim, heads, seq_len)
output = sigmoid_attn(x, mask)
print(output.shape)
```



# Documentation
Expand Down
15 changes: 15 additions & 0 deletions examples/modules/sigmoid_attn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch
from zeta import SigmoidAttention
from loguru import logger

batch_size = 32
seq_len = 128
dim = 512
heads = 8

x = torch.rand(batch_size, seq_len, dim)
mask = torch.ones(batch_size, seq_len, seq_len) # Example mask

sigmoid_attn = SigmoidAttention(dim, heads, seq_len)
output = sigmoid_attn(x, mask)
logger.info(f"Final output shape: {output.shape}")
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "2.7.1"
version = "2.7.2"
description = "Rapidly Build, Optimize, and Train SOTA AI Models"
authors = ["Zeta Team <kye@apac.ai>"]
license = "MIT"
Expand Down Expand Up @@ -99,4 +99,3 @@ preview = true
# [tool.poetry.scripts]
# zeta = 'zeta.cli.main:main'


3 changes: 2 additions & 1 deletion zeta/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@
from zeta.nn.modules.adaptive_gating import AdaptiveGating
from zeta.nn.modules.crome_adapter import CROMEAdapter
from zeta.nn.modules.cog_vlm_two_adapter import CogVLMTwoAdapter

from zeta.nn.modules.sigmoid_attn import SigmoidAttention
# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
# from zeta.nn.modules.scaled_sinusoidal import ScaledSinuosidalEmbedding
Expand Down Expand Up @@ -455,4 +455,5 @@
"AdaptiveGating",
"CROMEAdapter",
"CogVLMTwoAdapter",
"SigmoidAttention",
]
137 changes: 137 additions & 0 deletions zeta/nn/modules/sigmoid_attn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from loguru import logger
from typing import Optional


class SigmoidAttention(nn.Module):
"""
Implements Sigmoid Attention Mechanism.
This replaces the traditional softmax in attention with a sigmoid function.
Additionally, a constant scalar bias based on the sequence length is introduced.
Args:
dim (int): Dimension of the model (input size).
heads (int): Number of attention heads.
seq_len (int): The length of the input sequence.
dropout (float, optional): Dropout rate. Default is 0.1.
bias (bool, optional): Whether to include bias in linear layers. Default is True.
"""

def __init__(
self,
dim: int,
heads: int,
seq_len: int,
dropout: float = 0.1,
bias: bool = True,
) -> None:
super(SigmoidAttention, self).__init__()

logger.info(
f"Initializing SigmoidAttention with dim={dim}, heads={heads}, seq_len={seq_len}, dropout={dropout}, bias={bias}"
)
self.dim = dim
self.heads = heads
self.seq_len = seq_len
self.head_dim = dim // heads

assert self.head_dim * heads == dim, "dim must be divisible by heads"
logger.debug(f"Each attention head has {self.head_dim} dimensions.")

self.query = nn.Linear(dim, dim, bias=bias)
self.key = nn.Linear(dim, dim, bias=bias)
self.value = nn.Linear(dim, dim, bias=bias)

self.out = nn.Linear(dim, dim)
self.dropout = nn.Dropout(dropout)

# Create a constant scalar bias based on the sequence length
self.bias = nn.Parameter(
torch.ones(1) * math.sqrt(self.seq_len), requires_grad=False
)
logger.debug(
f"Scalar bias initialized as {self.bias.item()} based on sequence length."
)

def forward(
self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Forward pass of the Sigmoid Attention mechanism.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
mask (Optional[torch.Tensor], optional): Mask tensor to prevent attention to certain positions.
Should be of shape (batch_size, seq_len, seq_len).
Returns:
torch.Tensor: Output tensor of shape (batch_size, seq_len, dim).
"""
logger.info(f"Running forward pass with input shape {x.shape}")
batch_size, seq_len, _ = x.size()

# Linear projections for query, key, and value
Q = (
self.query(x)
.view(batch_size, seq_len, self.heads, self.head_dim)
.transpose(1, 2)
)
K = (
self.key(x)
.view(batch_size, seq_len, self.heads, self.head_dim)
.transpose(1, 2)
)
V = (
self.value(x)
.view(batch_size, seq_len, self.heads, self.head_dim)
.transpose(1, 2)
)

logger.debug(f"Q, K, V shapes: {Q.shape}, {K.shape}, {V.shape}")

# Scaled dot-product attention with sigmoid instead of softmax
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores = scores / self.bias # Apply the constant scalar bias
attn = torch.sigmoid(scores)
logger.debug(f"Attention scores computed with sigmoid: {attn.shape}")

# Apply the mask (optional)
if mask is not None:
logger.debug(f"Original mask shape: {mask.shape}")
# Expand the mask to match the attention scores shape
mask = mask.unsqueeze(1) # Adds dimension for heads
logger.debug(f"Expanded mask shape: {mask.shape}")
attn = attn.masked_fill(mask == 0, -1e9)
logger.debug("Mask applied to attention scores.")

attn = self.dropout(attn)
output = torch.matmul(attn, V)
output = (
output.transpose(1, 2)
.contiguous()
.view(batch_size, seq_len, self.dim)
)

logger.info(f"Output shape: {output.shape}")
return self.out(output)


# # Example usage
# if __name__ == "__main__":
# import torch
# from zeta import SigmoidAttention
# batch_size = 32
# seq_len = 128
# dim = 512
# heads = 8

# x = torch.rand(batch_size, seq_len, dim)
# mask = torch.ones(batch_size, seq_len, seq_len) # Example mask

# sigmoid_attn = SigmoidAttention(dim, heads, seq_len)
# output = sigmoid_attn(x, mask)
# logger.info(f"Final output shape: {output.shape}")

0 comments on commit afb1556

Please sign in to comment.