[FEAT][LinearizedAttention][Mask]
Kye committed May 13, 2024
1 parent 45d4120 commit 7c5df32
Showing 6 changed files with 52 additions and 57 deletions.
55 changes: 17 additions & 38 deletions multi_head_latent_attention.py
@@ -1,8 +1,9 @@
 import torch
 from torch import nn, Tensor
 from zeta.nn.embeddings.rope import RotaryEmbedding
 from zeta.nn.attention.multiquery_attention import MultiQueryAttention
 
 
 class MultiHeadLatentAttention(nn.Module):
     def __init__(
         self,
@@ -13,9 +14,8 @@ def __init__(
         rope_scale_base: int = 512,
         batch_size: int = 1,
         seqlen: int = 10000,
-
-        *args,
-        **kwargs
+        *args,
+        **kwargs,
     ):
         super().__init__(*args, **kwargs)
         self.dim = dim
@@ -25,49 +25,28 @@ def __init__(
         self.rope_scale_base = rope_scale_base
         self.batch_size = batch_size
         self.seqlen = seqlen
 
         # Rotary Embedding
         self.rope = RotaryEmbedding(
-            dim,
-            use_xpos=True,
-            scale_base=rope_scale_base,
-            *args,
-            **kwargs
+            dim, use_xpos=True, scale_base=rope_scale_base, *args, **kwargs
         )
 
         # Attention
-        self.attn = MultiQueryAttention(
-            dim,
-            heads,
-            *args,
-            **kwargs
-        )
-
-        #
-        self.latent_q = nn.Parameter(
-            torch.randn(
-                batch_size,
-                seqlen,
-                dim
-            )
-        )
-
+        self.attn = MultiQueryAttention(dim, heads, *args, **kwargs)
+
+        #
+        self.latent_q = nn.Parameter(torch.randn(batch_size, seqlen, dim))
+
         # KV
-        self.latent_kv = nn.Parameter(
-            torch.randn(
-                batch_size,
-                seqlen,
-                dim
-            )
-        )
-
+        self.latent_kv = nn.Parameter(torch.randn(batch_size, seqlen, dim))
+
     def forward(self, x: Tensor) -> Tensor:
         device = x.device
         k_r_t, scale = self.rope(self.seqlen, device)
         print(k_r_t)
         x = k_r_t + x
 
 
 # # Example
 # x = torch.randn(1, 100, 10)
 
@@ -79,4 +58,4 @@ def forward(self, x: Tensor) -> Tensor:

 # # Apply the model
 # out = model(x)
-# print(out.shape)
\ No newline at end of file
+# print(out.shape)
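The hunk above only shows part of the commented-out example at the bottom of the file, so here is a minimal usage sketch in the same spirit. It is hypothetical: it assumes zetascale is installed, that `dim`, `heads`, `batch_size`, and `seqlen` are constructor arguments as the diff suggests, and that `batch_size`/`seqlen` must mirror the input shape, since the rotary output for `self.seqlen` positions is added directly to `x`.

# Hypothetical usage sketch for MultiHeadLatentAttention; the argument
# choices below are assumptions made to keep shapes consistent with x,
# not values confirmed by the commit.
import torch
from multi_head_latent_attention import MultiHeadLatentAttention

x = torch.randn(1, 100, 10)  # (batch, seq_len, dim)
model = MultiHeadLatentAttention(
    dim=10,
    heads=2,
    batch_size=1,
    seqlen=100,
)
out = model(x)
print(out.shape)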
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "zetascale"
-version = "2.4.8"
+version = "2.4.9"
 description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
 authors = ["Zeta Team <kye@apac.ai>"]
 license = "MIT"
1 change: 0 additions & 1 deletion zeta/nn/attention/__init__.py
@@ -26,7 +26,6 @@
 from zeta.nn.attention.linearized_attention import LinearizedAttention
 
 
-
 __all__ = [
     "Attend",
     "FlashAttention",
47 changes: 31 additions & 16 deletions zeta/nn/attention/linearized_attention.py
@@ -1,4 +1,4 @@
 import torch
 from torch import nn, Tensor
 
 
@@ -7,8 +7,11 @@ def __init__(
         self,
         dim: int,
         heads: int = 8,
-        seqlen: int = 10000,
+        seqlen: int = 1000,
         groups: int = 1,
+        mask_on: bool = False,
+        *args,
+        **kwargs
     ):
         """
         Linearized Attention module.
@@ -24,17 +27,21 @@
         self.heads = heads
         self.seqlen = seqlen
         self.groups = groups
 
+        self.mask_on = mask_on
 
         # Projection
         self.proj = nn.Linear(dim, dim)
 
         # RELU
         self.act = nn.ReLU()
 
         # Groupnorm
         self.norm = nn.GroupNorm(groups, dim)
 
-    def forward(self, x: Tensor) -> Tensor:
+        # Mask Tensor
+        self.mask_tensor = torch.zeros(1, seqlen).bool()
+
+    def forward(self, x: Tensor, mask: bool = None) -> Tensor:
         """
         Forward pass of the LinearizedAttention module.
@@ -48,21 +55,29 @@ def forward(self, x: Tensor) -> Tensor:
         q = self.proj(x)
         k = self.proj(x)
         v = self.proj(x)
 
         # Projected again
         q_p = self.proj(q)
         q_k = self.proj(k)
 
         # Apply the relu
         q_acted = self.act(q_p)
         k_acted = self.act(q_k)
 
         # Groupnorm
-        return nn.GroupNorm(self.groups, s)(q_acted + k_acted + v)
-
-
-
-# x = torch.randn(1, 100, 512)
-# model = LinearizedAttention(512, 8)
+        output = nn.GroupNorm(self.groups, s)(q_acted + k_acted + v)
+
+        # Apply mask
+        if mask is not None:
+            if self.mask_on is True:
+                mask = self.mask_tensor
+            else:
+                output = output.masked_fill(mask.unsqueeze(-1), float('-inf'))
+                print(output.shape)
+
+        return output
+
+# x = torch.randn(1, 10, 20)
+# model = LinearizedAttention(20, 8, mask_on=True)
 # print(model(x))
-# # torch.Size([1, 100, 512])
+# # torch.Size([1, 10, 20])
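A short usage sketch of the new `mask` argument. It assumes the part of `forward` not shown in the diff defines `s` (most plausibly the sequence length unpacked from `x.shape`, which is what `nn.GroupNorm(self.groups, s)` needs to match) and that `mask` is a boolean tensor of shape `(1, seq_len)` with `True` marking positions to suppress. As written, an explicit mask is applied via `masked_fill` only when `mask_on` is `False`; with `mask_on=True` the supplied mask is swapped for the module's internal all-False `mask_tensor` and the fill branch is skipped.

# Hypothetical usage of the new mask path; shapes follow the commented
# example at the bottom of the file, and the `s`-is-seq-len assumption
# above is not confirmed by the visible hunks.
import torch
from zeta.nn.attention.linearized_attention import LinearizedAttention

x = torch.randn(1, 10, 20)        # (batch, seq_len, dim)
mask = torch.zeros(1, 10).bool()  # boolean padding mask
mask[:, 5:] = True                # True = positions to fill with -inf

model = LinearizedAttention(20, 8, seqlen=10)  # mask_on left at False
out = model(x, mask)              # explicit mask applied via masked_fill
print(out.shape)                  # expected: torch.Size([1, 10, 20])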
1 change: 1 addition & 0 deletions zeta/nn/modules/__init__.py
@@ -218,6 +218,7 @@
 from zeta.nn.modules.fractoral_norm import FractoralNorm
 from zeta.nn.modules.kv_cache_update import kv_cache_with_update
 from zeta.nn.modules.expand import expand
+
 # from zeta.nn.modules.img_reshape import image_reshape
 # from zeta.nn.modules.flatten_features import flatten_features
 # from zeta.nn.modules.scaled_sinusoidal import ScaledSinuosidalEmbedding
3 changes: 2 additions & 1 deletion zeta/nn/modules/expand.py
@@ -1,4 +1,5 @@
 from einops import repeat
 
+
 def expand(*args, **kwargs):
-    return repeat(*args, **kwargs)
\ No newline at end of file
+    return repeat(*args, **kwargs)
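Since `expand` simply forwards to `einops.repeat`, it accepts the same pattern string and axis sizes; a minimal sketch:

# Minimal sketch; expand is a thin alias for einops.repeat, so the
# pattern/axis-size semantics are einops' own.
import torch
from zeta.nn.modules import expand

x = torch.randn(2, 3)
y = expand(x, "h w -> b h w", b=4)  # tile x along a new leading batch axis
print(y.shape)                      # torch.Size([4, 2, 3])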
